Home [BoostCamp AI Tech / 심화포스팅] torch.nn.Module 뜯어먹기
Post
Cancel

[BoostCamp AI Tech / 심화포스팅] torch.nn.Module 뜯어먹기

torch.nn.Module 뜯어먹기


Introduction

PyTorch로 Machine Learning 모델링과 학습을 한다면 대부분 nn.Module을 상속해서 사용할 것이다. 그래서 대부분의 기능이 nn.Module에서 구현된 코드와 연결된 것이 많다. 좀 더 좋은 코드를 만들고 이해하고자 부스트캠프 조원들과 심화 포스팅을 하기로 했는데, 내가 하기로 한 심화 포스팅 주제는 nn.Module을 뜯어먹기 이다.
많은 함수 분석을 하겠지만 모든 함수를 분석하진 않을 것이다. 그리고 외부에 보이는 메서드만 말고 내부에서 동작하는 메서드도 분석할 것이다.

torch.nn.Module

  • torch.nn.Module은 Nueral Network의 base class 역할을 한다.
  • 우리의 모델은 nn.Module의 sub class로 구현이 된다.
  • Module의 클래스에는 변수로 dump_patches_version, training, _is_full_backward_hook이 선언되어있다. 이 attribute들은 이후 다른 메서드들에서 사용이 된다.

1. __init__(self)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def __init__(self) -> None:
    torch._C._log_api_usage_once("python.nn_module")

    self.training = True
    self._parameters: Dict[str, Optional[Parameter]] = OrderedDict()
    self._buffers: Dict[str, Optional[Tensor]] = OrderedDict()
    self._non_persistent_buffers_set: Set[str] = set()
    self._backward_hooks: Dict[int, Callable] = OrderedDict()
    self._is_full_backward_hook = None
    self._forward_hooks: Dict[int, Callable] = OrderedDict()
    self._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
    self._state_dict_hooks: Dict[int, Callable] = OrderedDict()
    self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
    self._modules: Dict[str, Optional['Module']] = OrderedDict()
  • 개인적으로 여러군데에 사용되는 데 정확한 역할이 뭔지 찾지 못한 것이 있는데, torch._C에 있는 함수들의 원형을 찾기가 어려웠다.
  • 이 외에는 대부분은 변수를 설정하는 역할을 한다. self.training은 이후 eval()함수와 train()함수에서 사용되는 훈련 세팅 여부를 결정한다.
  • 우리가 세팅한 nn.Linear, Conv같은 레이어들은 nn.Module 기반이므로 self._modules{레이어 변수명 : 레이어 종류}형태로 저장한다.
  • self._modules와 같이 self._parametersParameter변수가 자동적으로 저장된다.

2. forward()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
forward: Callable[..., Any] = _forward_unimplemented

def _forward_unimplemented(self, *input: Any) -> None:
    """
    .. note::
    Although the recipe for forward pass needs to be defined within
    this function, one should call the :class:`Module` instance afterwards
    instead of this since the former takes care of running the
    registered hooks while the latter silently ignores them.
    """

    raise NotImplementedError

def _call_impl(self, *input, **kwargs):
    forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        ...
  • forward는 반드시 모든 subclass에서 오버라이딩을 통해 재현을 해줘야한다. 만약 구현이 되지 않는다면 NotImplementedError를 발생시킨다.
  • _forward_unimplemented에 적힌 추가적인 설명에 따르면 우리가 정의한 함수는 등록된 hook들을 신경쓰며 수행하지만 정의한 함수가 Module instance를 호출하게되는 경우 hook을 무시하게된다고 한다.
    영어 자체가 해석하기 어렵게 적혀있어서…. 틀린 이해일 수 있다.
    • _forward_unimplemented는 forward를 정의하지 않을 경우에만 호출된다.
  • _call_impl함수에 따르면 가장 첫단계로 torch._C.get_tracing_state()를 통해 조건을 확인하고 self.forward의 구현부를 가져온다.

3. apply(self, fn)

1
2
3
4
5
def apply(self: T, fn: Callable[['Module'], None]) -> T:
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self
  • self의 children()을 통해 named_children을 호출하고 이는 yield 구문을 통해 module과 name을 반환해주는데, 이를 활용해 후위순회로 fn을 모듈에 적용한다.

4. dump_patches, _version

  • dump_patches_version은 module의 변화 상태를 기록하는 역할을 하는 것으로 보인다.
  • 새로운 parameter와 buffer가 module에 추가/제거되면 충돌을 일으키는 역할을 dump_patches가 하는데, 이때 _load_from_state_dict_version의 번호를 비교하여 적절한 수행을 한다.

5. eval(), train()

1
2
3
4
5
6
7
8
9
10
def eval(self: T) -> T:
    return self.train(False)

def train(self: T, mode: bool = True) -> T:
    if not isinstance(mode, bool):
        raise ValueError("training mode is expected to be boolean")
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self
  • evaltrain은 현재 모델의 훈련 상태를 설정한다.
  • eval은 모든 module의 self.trainingFalse로 만들고 train은 인자로 들어온 상태에 대해 self.training을 세팅하는 역할을 한다.
    apply도 그렇고 의외로 재귀 구문으로 module에 함수를 적용하는 패턴이 많이 발견된다.

6. extra_repr(), __repr__()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def extra_repr(self) -> str:
    return ''

def __repr__(self):
    # We treat the extra repr like the sub-module, one item per line
    extra_lines = []
    extra_repr = self.extra_repr()
    # empty string will be split into list ['']
    if extra_repr:
        extra_lines = extra_repr.split('\n')
    child_lines = []
    for key, module in self._modules.items():
        mod_str = repr(module)
        mod_str = _addindent(mod_str, 2)
        child_lines.append('(' + key + '): ' + mod_str)
    lines = extra_lines + child_lines

    main_str = self._get_name() + '('
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += '\n  ' + '\n  '.join(lines) + '\n'

    main_str += ')'
    return main_str
  • extra_repr 그 자체는 큰 의미가 있는 함수는 아니다. 이를 분석하기 위해서는 __repr__을 분석해야한다.
  • __repr__에서 시작하자마자 self.extra_repr을 우선적으로 호출한다.
    그 후 내용이 있다면 \n을 기준으로 분리를 하고 일련의 과정을 통해 출력문을 설정한다.

7. register_forward(_pre)_hook(self, hook)

  • register_forward_hookregister_forward_pre_hook은 둘다 객체의 attribute인 self._forward_hooksself._forward_pre_hooks에 입력으로 들어온 hook을 저장한다.
  • forawrd_hook
    • 코드 설명에 따르면 매 forward() 호출 후 출력에 대해 hook을 수행한다고 한다.
    • output, input을 모두 수정할 수 있지만 forward가 호출된 이후이므로 input의 수정이 영향을 미치지 않는다.
      (원문 : The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after forward is called.)
    • 실제 자세한 구동은 뒤에 서술할 _call_impl에서 과정을 설명하겠다.
  • forward_pre_hook
    • 코드 설명에 따르면 forward() 호출 이전에 invoke된다.
    • input 수정이 가능하고 return이 가능하다. 단일 value를 return 하는 경우 tuple로 값을 wrapping한다.

8. register_full_backward_hook(self, hook)

  • 객체의 attribute인 self._is_full_backward_hook을 True로 변경한다. 이후 self._backward_hookshook을 등록한다.
  • 입력에대한 gradient가 계산될 때마다 hook을 호출한다.
  • hook은 argument들을 수정하면 안되지만 선택적으로 (Optionally) grad_input대신 사용할 new gradient의 반환이 가능하다.
  • inputoutput을 수정할 경우 에러를 발생시킨다.
    (원문 : Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.)

9. __call__, _call_impl(*input, **kwargs)

  • 객체가 호출될 때 수행되는 magic method 함수이다.
  • 이 부분 코드에 의해 단순히 객체가 call됨에도 forward 연산이 수행된다.
  • 세세하게 뜯어보긴하겠지만 모든 코드를 해석하진 않을 것이다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
__call__ : Callable[..., Any] = _call_impl

def _call_impl(self, *input, **kwargs):
    forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
    # If we don't have any hooks, we want to skip the rest of the logic in
    # this function, and just call forward.
    if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks 
            or _global_backward_hooks or _global_forward_hooks 
            or _global_forward_pre_hooks):
        return forward_call(*input, **kwargs)
    # Do not call functions when jit is used
    full_backward_hooks, non_full_backward_hooks = [], []
    if self._backward_hooks or _global_backward_hooks:
        full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
    if _global_forward_pre_hooks or self._forward_pre_hooks:
        for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result

    bw_hook = None
    if full_backward_hooks:
        bw_hook = hooks.BackwardHook(self, full_backward_hooks)
        input = bw_hook.setup_input_hook(input)

    result = forward_call(*input, **kwargs)
    if _global_forward_hooks or self._forward_hooks:
        for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result

    if bw_hook:
        result = bw_hook.setup_output_hook(result)

    # Handle the non-full backward hooks
    if non_full_backward_hooks:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in non_full_backward_hooks:
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
            self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

    return result
  • torch._C._get_tracing_state()의 상태가 Falseself.forwad()를 사용
  • hook이 없다면 앞서 할당한 foward_call attribute를 바로 반환한다.
    1
    2
    3
    4
    
      if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks 
              or _global_backward_hooks or _global_forward_hooks 
              or _global_forward_pre_hooks):
          return forward_call(*input, **kwargs)
    
  • hook이 등록된 경우의 순서는 다음과 같다.
    • 우선 backward와 관련된 hook들을 미리 확인하여 저장한다.
      1
      2
      3
      
        full_backward_hooks, non_full_backward_hooks = [], []
        if self._backward_hooks or _global_backward_hooks:
            full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
      
    • pre_hook들을 확인하고 존재한다면 input에 hook을 적용한다.
      1
      2
      3
      4
      5
      6
      7
      
        if _global_forward_pre_hooks or self._forward_pre_hooks:
            for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
                result = hook(self, input)
                if result is not None:
                    if not isinstance(result, tuple):
                        result = (result,)
                    input = result
      
    • bw_hook으로 input에 hook을 설정하고 forward를 진행한다.
      1
      2
      3
      4
      5
      6
      
        bw_hook = None
        if full_backward_hooks:
            bw_hook = hooks.BackwardHook(self, full_backward_hooks)
            input = bw_hook.setup_input_hook(input)
      
        result = forward_call(*input, **kwargs)
      
    • 그 후 forward_hook과 관련된 내용들을 확인한다. 만약 값들이 존재한다면 앞서 연산한 forward의 결과에 forward_hook들을 적용한다. 이런 이유로 input을 변경해도 forward값에 영향을 주지 못하는 것으로 보인다.
      1
      2
      3
      4
      5
      
        if _global_forward_hooks or self._forward_hooks:
            for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
                hook_result = hook(self, input, result)
                if hook_result is not None:
                    result = hook_result
      
    • 그 후 앞서 backward_hook을 확인해서 저장하는 bw_hook에 값이 있다면 앞서 계산한 forward의 연산 결과에 setup_output_hook을 사용해 설정을 해주는 것으로 보인다.
      해당 메서드에 대한 설명이 자세히 나와있지 않아서 자세한 설명이 어렵다…
This post is licensed under CC BY 4.0 by the author.

[BoostCamp AI Tech] Day9

[BoostCamp AI Tech] Day10

Comments powered by Disqus.