torch.nn.Module 뜯어먹기
Introduction
PyTorch로 Machine Learning 모델링과 학습을 한다면 대부분 nn.Module을 상속해서 사용할 것이다. 그래서 대부분의 기능이 nn.Module에서 구현된 코드와 연결된 것이 많다. 좀 더 좋은 코드를 만들고 이해하고자 부스트캠프 조원들과 심화 포스팅을 하기로 했는데, 내가 하기로 한 심화 포스팅 주제는 nn.Module을 뜯어먹기 이다.
많은 함수 분석을 하겠지만 모든 함수를 분석하진 않을 것이다. 그리고 외부에 보이는 메서드만 말고 내부에서 동작하는 메서드도 분석할 것이다.
torch.nn.Module
torch.nn.Module은 Nueral Network의 base class 역할을 한다.- 우리의 모델은
nn.Module의 sub class로 구현이 된다. Module의 클래스에는 변수로dump_patches와_version,training,_is_full_backward_hook이 선언되어있다. 이 attribute들은 이후 다른 메서드들에서 사용이 된다.
1. __init__(self)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def __init__(self) -> None:
torch._C._log_api_usage_once("python.nn_module")
self.training = True
self._parameters: Dict[str, Optional[Parameter]] = OrderedDict()
self._buffers: Dict[str, Optional[Tensor]] = OrderedDict()
self._non_persistent_buffers_set: Set[str] = set()
self._backward_hooks: Dict[int, Callable] = OrderedDict()
self._is_full_backward_hook = None
self._forward_hooks: Dict[int, Callable] = OrderedDict()
self._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
self._state_dict_hooks: Dict[int, Callable] = OrderedDict()
self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
self._modules: Dict[str, Optional['Module']] = OrderedDict()
- 개인적으로 여러군데에 사용되는 데 정확한 역할이 뭔지 찾지 못한 것이 있는데,
torch._C에 있는 함수들의 원형을 찾기가 어려웠다. - 이 외에는 대부분은 변수를 설정하는 역할을 한다.
self.training은 이후eval()함수와train()함수에서 사용되는 훈련 세팅 여부를 결정한다. - 우리가 세팅한
nn.Linear,Conv같은 레이어들은nn.Module기반이므로self._modules에{레이어 변수명 : 레이어 종류}형태로 저장한다. self._modules와 같이self._parameters도Parameter변수가 자동적으로 저장된다.

2. forward()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
forward: Callable[..., Any] = _forward_unimplemented
def _forward_unimplemented(self, *input: Any) -> None:
"""
.. note::
Although the recipe for forward pass needs to be defined within
this function, one should call the :class:`Module` instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
"""
raise NotImplementedError
def _call_impl(self, *input, **kwargs):
forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
...
forward는 반드시 모든 subclass에서 오버라이딩을 통해 재현을 해줘야한다. 만약 구현이 되지 않는다면NotImplementedError를 발생시킨다._forward_unimplemented에 적힌 추가적인 설명에 따르면 우리가 정의한 함수는 등록된 hook들을 신경쓰며 수행하지만 정의한 함수가 Module instance를 호출하게되는 경우 hook을 무시하게된다고 한다.
영어 자체가 해석하기 어렵게 적혀있어서…. 틀린 이해일 수 있다._forward_unimplemented는 forward를 정의하지 않을 경우에만 호출된다.
_call_impl함수에 따르면 가장 첫단계로torch._C.get_tracing_state()를 통해 조건을 확인하고self.forward의 구현부를 가져온다.
3. apply(self, fn)
1
2
3
4
5
def apply(self: T, fn: Callable[['Module'], None]) -> T:
for module in self.children():
module.apply(fn)
fn(self)
return self
- self의
children()을 통해named_children을 호출하고 이는yield구문을 통해 module과 name을 반환해주는데, 이를 활용해 후위순회로fn을 모듈에 적용한다.
4. dump_patches, _version
dump_patches와_version은 module의 변화 상태를 기록하는 역할을 하는 것으로 보인다.- 새로운 parameter와 buffer가 module에 추가/제거되면 충돌을 일으키는 역할을
dump_patches가 하는데, 이때_load_from_state_dict가_version의 번호를 비교하여 적절한 수행을 한다.
5. eval(), train()
1
2
3
4
5
6
7
8
9
10
def eval(self: T) -> T:
return self.train(False)
def train(self: T, mode: bool = True) -> T:
if not isinstance(mode, bool):
raise ValueError("training mode is expected to be boolean")
self.training = mode
for module in self.children():
module.train(mode)
return self
eval과train은 현재 모델의 훈련 상태를 설정한다.eval은 모든 module의self.training을False로 만들고train은 인자로 들어온 상태에 대해self.training을 세팅하는 역할을 한다.
apply도 그렇고 의외로 재귀 구문으로 module에 함수를 적용하는 패턴이 많이 발견된다.
6. extra_repr(), __repr__()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def extra_repr(self) -> str:
return ''
def __repr__(self):
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = self.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split('\n')
child_lines = []
for key, module in self._modules.items():
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append('(' + key + '): ' + mod_str)
lines = extra_lines + child_lines
main_str = self._get_name() + '('
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
return main_str
extra_repr그 자체는 큰 의미가 있는 함수는 아니다. 이를 분석하기 위해서는__repr__을 분석해야한다.__repr__에서 시작하자마자self.extra_repr을 우선적으로 호출한다.
그 후 내용이 있다면\n을 기준으로 분리를 하고 일련의 과정을 통해 출력문을 설정한다.
7. register_forward(_pre)_hook(self, hook)
register_forward_hook과register_forward_pre_hook은 둘다 객체의 attribute인self._forward_hooks와self._forward_pre_hooks에 입력으로 들어온 hook을 저장한다.forawrd_hook- 코드 설명에 따르면 매
forward()호출 후 출력에 대해 hook을 수행한다고 한다. - output, input을 모두 수정할 수 있지만
forward가 호출된 이후이므로 input의 수정이 영향을 미치지 않는다.
(원문 : The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called afterforwardis called.) - 실제 자세한 구동은 뒤에 서술할
_call_impl에서 과정을 설명하겠다.
- 코드 설명에 따르면 매
forward_pre_hook- 코드 설명에 따르면
forward()호출 이전에 invoke된다. - input 수정이 가능하고 return이 가능하다. 단일 value를 return 하는 경우 tuple로 값을 wrapping한다.
- 코드 설명에 따르면
8. register_full_backward_hook(self, hook)
- 객체의 attribute인
self._is_full_backward_hook을 True로 변경한다. 이후self._backward_hooks에hook을 등록한다. - 입력에대한 gradient가 계산될 때마다 hook을 호출한다.
- hook은 argument들을 수정하면 안되지만 선택적으로 (Optionally)
grad_input대신 사용할 new gradient의 반환이 가능하다. input과output을 수정할 경우 에러를 발생시킨다.
(원문 : Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.)
9. __call__, _call_impl(*input, **kwargs)
- 객체가 호출될 때 수행되는 magic method 함수이다.
- 이 부분 코드에 의해 단순히 객체가 call됨에도 forward 연산이 수행된다.
- 세세하게 뜯어보긴하겠지만 모든 코드를 해석하진 않을 것이다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
__call__ : Callable[..., Any] = _call_impl
def _call_impl(self, *input, **kwargs):
forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
# If we don't have any hooks, we want to skip the rest of the logic in
# this function, and just call forward.
if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks
or _global_backward_hooks or _global_forward_hooks
or _global_forward_pre_hooks):
return forward_call(*input, **kwargs)
# Do not call functions when jit is used
full_backward_hooks, non_full_backward_hooks = [], []
if self._backward_hooks or _global_backward_hooks:
full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
if _global_forward_pre_hooks or self._forward_pre_hooks:
for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
bw_hook = None
if full_backward_hooks:
bw_hook = hooks.BackwardHook(self, full_backward_hooks)
input = bw_hook.setup_input_hook(input)
result = forward_call(*input, **kwargs)
if _global_forward_hooks or self._forward_hooks:
for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result
if bw_hook:
result = bw_hook.setup_output_hook(result)
# Handle the non-full backward hooks
if non_full_backward_hooks:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in non_full_backward_hooks:
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
return result
torch._C._get_tracing_state()의 상태가False면self.forwad()를 사용- hook이 없다면 앞서 할당한
foward_callattribute를 바로 반환한다.1 2 3 4
if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks or _global_forward_hooks or _global_forward_pre_hooks): return forward_call(*input, **kwargs)
- hook이 등록된 경우의 순서는 다음과 같다.
- 우선 backward와 관련된 hook들을 미리 확인하여 저장한다.
1 2 3
full_backward_hooks, non_full_backward_hooks = [], [] if self._backward_hooks or _global_backward_hooks: full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
- pre_hook들을 확인하고 존재한다면 input에 hook을 적용한다.
1 2 3 4 5 6 7
if _global_forward_pre_hooks or self._forward_pre_hooks: for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()): result = hook(self, input) if result is not None: if not isinstance(result, tuple): result = (result,) input = result
- bw_hook으로 input에 hook을 설정하고
forward를 진행한다.1 2 3 4 5 6
bw_hook = None if full_backward_hooks: bw_hook = hooks.BackwardHook(self, full_backward_hooks) input = bw_hook.setup_input_hook(input) result = forward_call(*input, **kwargs)
- 그 후
forward_hook과 관련된 내용들을 확인한다. 만약 값들이 존재한다면 앞서 연산한 forward의 결과에forward_hook들을 적용한다. 이런 이유로 input을 변경해도forward값에 영향을 주지 못하는 것으로 보인다.1 2 3 4 5
if _global_forward_hooks or self._forward_hooks: for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()): hook_result = hook(self, input, result) if hook_result is not None: result = hook_result
- 그 후 앞서 backward_hook을 확인해서 저장하는 bw_hook에 값이 있다면 앞서 계산한 forward의 연산 결과에
setup_output_hook을 사용해 설정을 해주는 것으로 보인다.
해당 메서드에 대한 설명이 자세히 나와있지 않아서 자세한 설명이 어렵다…
- 우선 backward와 관련된 hook들을 미리 확인하여 저장한다.
Comments powered by Disqus.