torch.nn.Module 뜯어먹기
Introduction
PyTorch로 Machine Learning 모델링과 학습을 한다면 대부분 nn.Module
을 상속해서 사용할 것이다. 그래서 대부분의 기능이 nn.Module
에서 구현된 코드와 연결된 것이 많다. 좀 더 좋은 코드를 만들고 이해하고자 부스트캠프 조원들과 심화 포스팅을 하기로 했는데, 내가 하기로 한 심화 포스팅 주제는 nn.Module
을 뜯어먹기 이다.
많은 함수 분석을 하겠지만 모든 함수를 분석하진 않을 것이다. 그리고 외부에 보이는 메서드만 말고 내부에서 동작하는 메서드도 분석할 것이다.
torch.nn.Module
torch.nn.Module
은 Nueral Network의 base class 역할을 한다.- 우리의 모델은
nn.Module
의 sub class로 구현이 된다. Module
의 클래스에는 변수로dump_patches
와_version
,training
,_is_full_backward_hook
이 선언되어있다. 이 attribute들은 이후 다른 메서드들에서 사용이 된다.
1. __init__(self)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def __init__(self) -> None:
torch._C._log_api_usage_once("python.nn_module")
self.training = True
self._parameters: Dict[str, Optional[Parameter]] = OrderedDict()
self._buffers: Dict[str, Optional[Tensor]] = OrderedDict()
self._non_persistent_buffers_set: Set[str] = set()
self._backward_hooks: Dict[int, Callable] = OrderedDict()
self._is_full_backward_hook = None
self._forward_hooks: Dict[int, Callable] = OrderedDict()
self._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
self._state_dict_hooks: Dict[int, Callable] = OrderedDict()
self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
self._modules: Dict[str, Optional['Module']] = OrderedDict()
- 개인적으로 여러군데에 사용되는 데 정확한 역할이 뭔지 찾지 못한 것이 있는데,
torch._C
에 있는 함수들의 원형을 찾기가 어려웠다. - 이 외에는 대부분은 변수를 설정하는 역할을 한다.
self.training
은 이후eval()
함수와train()
함수에서 사용되는 훈련 세팅 여부를 결정한다. - 우리가 세팅한
nn.Linear
,Conv
같은 레이어들은nn.Module
기반이므로self._modules
에{레이어 변수명 : 레이어 종류}
형태로 저장한다. self._modules
와 같이self._parameters
도Parameter
변수가 자동적으로 저장된다.
2. forward()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
forward: Callable[..., Any] = _forward_unimplemented
def _forward_unimplemented(self, *input: Any) -> None:
"""
.. note::
Although the recipe for forward pass needs to be defined within
this function, one should call the :class:`Module` instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
"""
raise NotImplementedError
def _call_impl(self, *input, **kwargs):
forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
...
forward
는 반드시 모든 subclass에서 오버라이딩을 통해 재현을 해줘야한다. 만약 구현이 되지 않는다면NotImplementedError
를 발생시킨다._forward_unimplemented
에 적힌 추가적인 설명에 따르면 우리가 정의한 함수는 등록된 hook들을 신경쓰며 수행하지만 정의한 함수가 Module instance를 호출하게되는 경우 hook을 무시하게된다고 한다.
영어 자체가 해석하기 어렵게 적혀있어서…. 틀린 이해일 수 있다._forward_unimplemented
는 forward를 정의하지 않을 경우에만 호출된다.
_call_impl
함수에 따르면 가장 첫단계로torch._C.get_tracing_state()
를 통해 조건을 확인하고self.forward
의 구현부를 가져온다.
3. apply(self, fn)
1
2
3
4
5
def apply(self: T, fn: Callable[['Module'], None]) -> T:
for module in self.children():
module.apply(fn)
fn(self)
return self
- self의
children()
을 통해named_children
을 호출하고 이는yield
구문을 통해 module과 name을 반환해주는데, 이를 활용해 후위순회로fn
을 모듈에 적용한다.
4. dump_patches, _version
dump_patches
와_version
은 module의 변화 상태를 기록하는 역할을 하는 것으로 보인다.- 새로운 parameter와 buffer가 module에 추가/제거되면 충돌을 일으키는 역할을
dump_patches
가 하는데, 이때_load_from_state_dict
가_version
의 번호를 비교하여 적절한 수행을 한다.
5. eval(), train()
1
2
3
4
5
6
7
8
9
10
def eval(self: T) -> T:
return self.train(False)
def train(self: T, mode: bool = True) -> T:
if not isinstance(mode, bool):
raise ValueError("training mode is expected to be boolean")
self.training = mode
for module in self.children():
module.train(mode)
return self
eval
과train
은 현재 모델의 훈련 상태를 설정한다.eval
은 모든 module의self.training
을False
로 만들고train
은 인자로 들어온 상태에 대해self.training
을 세팅하는 역할을 한다.
apply
도 그렇고 의외로 재귀 구문으로 module에 함수를 적용하는 패턴이 많이 발견된다.
6. extra_repr(), __repr__()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def extra_repr(self) -> str:
return ''
def __repr__(self):
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = self.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split('\n')
child_lines = []
for key, module in self._modules.items():
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append('(' + key + '): ' + mod_str)
lines = extra_lines + child_lines
main_str = self._get_name() + '('
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
return main_str
extra_repr
그 자체는 큰 의미가 있는 함수는 아니다. 이를 분석하기 위해서는__repr__
을 분석해야한다.__repr__
에서 시작하자마자self.extra_repr
을 우선적으로 호출한다.
그 후 내용이 있다면\n
을 기준으로 분리를 하고 일련의 과정을 통해 출력문을 설정한다.
7. register_forward(_pre)_hook(self, hook)
register_forward_hook
과register_forward_pre_hook
은 둘다 객체의 attribute인self._forward_hooks
와self._forward_pre_hooks
에 입력으로 들어온 hook을 저장한다.forawrd_hook
- 코드 설명에 따르면 매
forward()
호출 후 출력에 대해 hook을 수행한다고 한다. - output, input을 모두 수정할 수 있지만
forward
가 호출된 이후이므로 input의 수정이 영향을 미치지 않는다.
(원문 : The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called afterforward
is called.) - 실제 자세한 구동은 뒤에 서술할
_call_impl
에서 과정을 설명하겠다.
- 코드 설명에 따르면 매
forward_pre_hook
- 코드 설명에 따르면
forward()
호출 이전에 invoke된다. - input 수정이 가능하고 return이 가능하다. 단일 value를 return 하는 경우 tuple로 값을 wrapping한다.
- 코드 설명에 따르면
8. register_full_backward_hook(self, hook)
- 객체의 attribute인
self._is_full_backward_hook
을 True로 변경한다. 이후self._backward_hooks
에hook
을 등록한다. - 입력에대한 gradient가 계산될 때마다 hook을 호출한다.
- hook은 argument들을 수정하면 안되지만 선택적으로 (Optionally)
grad_input
대신 사용할 new gradient의 반환이 가능하다. input
과output
을 수정할 경우 에러를 발생시킨다.
(원문 : Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.)
9. __call__, _call_impl(*input, **kwargs)
- 객체가 호출될 때 수행되는 magic method 함수이다.
- 이 부분 코드에 의해 단순히 객체가 call됨에도 forward 연산이 수행된다.
- 세세하게 뜯어보긴하겠지만 모든 코드를 해석하진 않을 것이다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
__call__ : Callable[..., Any] = _call_impl
def _call_impl(self, *input, **kwargs):
forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
# If we don't have any hooks, we want to skip the rest of the logic in
# this function, and just call forward.
if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks
or _global_backward_hooks or _global_forward_hooks
or _global_forward_pre_hooks):
return forward_call(*input, **kwargs)
# Do not call functions when jit is used
full_backward_hooks, non_full_backward_hooks = [], []
if self._backward_hooks or _global_backward_hooks:
full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
if _global_forward_pre_hooks or self._forward_pre_hooks:
for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
bw_hook = None
if full_backward_hooks:
bw_hook = hooks.BackwardHook(self, full_backward_hooks)
input = bw_hook.setup_input_hook(input)
result = forward_call(*input, **kwargs)
if _global_forward_hooks or self._forward_hooks:
for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result
if bw_hook:
result = bw_hook.setup_output_hook(result)
# Handle the non-full backward hooks
if non_full_backward_hooks:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in non_full_backward_hooks:
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
return result
torch._C._get_tracing_state()
의 상태가False
면self.forwad()
를 사용- hook이 없다면 앞서 할당한
foward_call
attribute를 바로 반환한다.1 2 3 4
if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks or _global_forward_hooks or _global_forward_pre_hooks): return forward_call(*input, **kwargs)
- hook이 등록된 경우의 순서는 다음과 같다.
- 우선 backward와 관련된 hook들을 미리 확인하여 저장한다.
1 2 3
full_backward_hooks, non_full_backward_hooks = [], [] if self._backward_hooks or _global_backward_hooks: full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
- pre_hook들을 확인하고 존재한다면 input에 hook을 적용한다.
1 2 3 4 5 6 7
if _global_forward_pre_hooks or self._forward_pre_hooks: for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()): result = hook(self, input) if result is not None: if not isinstance(result, tuple): result = (result,) input = result
- bw_hook으로 input에 hook을 설정하고
forward
를 진행한다.1 2 3 4 5 6
bw_hook = None if full_backward_hooks: bw_hook = hooks.BackwardHook(self, full_backward_hooks) input = bw_hook.setup_input_hook(input) result = forward_call(*input, **kwargs)
- 그 후
forward_hook
과 관련된 내용들을 확인한다. 만약 값들이 존재한다면 앞서 연산한 forward의 결과에forward_hook
들을 적용한다. 이런 이유로 input을 변경해도forward
값에 영향을 주지 못하는 것으로 보인다.1 2 3 4 5
if _global_forward_hooks or self._forward_hooks: for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()): hook_result = hook(self, input, result) if hook_result is not None: result = hook_result
- 그 후 앞서 backward_hook을 확인해서 저장하는 bw_hook에 값이 있다면 앞서 계산한 forward의 연산 결과에
setup_output_hook
을 사용해 설정을 해주는 것으로 보인다.
해당 메서드에 대한 설명이 자세히 나와있지 않아서 자세한 설명이 어렵다…
- 우선 backward와 관련된 hook들을 미리 확인하여 저장한다.
Comments powered by Disqus.