PyTorch : Multi-GPU 학습
Introduction
- Single GPU란 GPU가 1개가 존재하는 것이고 Multi GPU란 1개의 컴퓨터에 여러개의 GPU가 존재하는 것을 의미한다.
- 일반적으로 Node라고 표현하면 1대의 컴퓨터를 의미한다.
- 요즘 많은 컴퓨터는 Single Node Multi GPU 방식을 사용한다.
- 대용량 서버실은 Multi Node, Multi GPU방식을 쓴다.
- NVIDIA에서 이런 Multi GPU를 위한 TensorRT라는 기술을 공개하기도 했다.
Parallel
- 다중 GPU에서 학습을 분사하는 방법은 2가지가 있다.
- 모델의 병렬화
- 데이터 병렬화
- 모델을 나누는 방식은 이전부터 많이 활용했었고 대표적으로는 AlexNet이 있다.
요즘은 큰 모델을 생성하는 추세에 맞춰 모델을 나누는 방식을 사용한다.- 하지만 모델 병렬화는 모델의 병목과 파이프라인 구성의 어려움으로 쉽지 않은 과제이다.
Model Parallel
- 모델 병렬화 과정에서는 forward와 backward의 수행과정이 서로 다른 GPU간에 병렬적으로 수행해야한다.
만약 하나의 수행이 진행되는 동안 다른 GPU가 작업을 진행하지 않는다면 이는 병렬화의 의미가 많이 없다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class ModelParallelResNet(ResNet):
def __init__(self, *args, **kwargs):
super(ModelParallelResNet,, self).__init__(
Botleneck, [3, 4, 6, 3], num_classes=num_classes, *args, **kwargs)
self.seq1 = nn.Seqeuntial(....).to('cuda:0')
self.seq2 = nn.Sequential(....).to('cuda:1')
self.fc.to('cuda:1')
def forward(self, x):
'''
your forward function
'''
x = self.seq2(self.seq1(x).to('cuda:1'))
return self.fc(x.view(x.size(0), -1))
seq1
과seq2
를 각각 GPU에 할당해준다.- 이후
forward
과정에서 적절히 forward를 수행하고 한쪽의 값을 다른 GPU에 다시 할당하는 방식을 쓰면 된다.
그러면 모델을 각자 사용해서 연결하는 방식을 쓰는 것인데, 이러면 병목이 발생해서 잘 쓰지 않기도 한다.
Data Parallel
- 모델을 분할하듯이 데이터를 분할해서 각 GPU가 데이터를 처리하는 것이다.
- mini-batch 방식과 유사하지만 한번에 동시에 여러 GPU가 작업을 수행한다.
- PyTorch에서는 2가지 방식을 제공한다.
DataParallel
: 데이터를 분배한 후 다시 하나의 GPU에 합쳐 평균을 계산한다.- 이 방법은 GPU 1개가 떠앉는 overhead가 크다는 문제가 있다. GPU 병목현상이 발생할 수 있다.
DistributedDataParallel
DataParallel
의 단점을 해결하고자 나온 방식이다.- 각 CPU와 GPU를 매핑하여 개별적인 연산을 진행하고 연산결과만을 합쳐 평균을 계산한다.
- 단순 평균치 연산만 이뤄지므로 overhead가 적다.
- 기본적으로
DataParallel
로 하고 개별적 연산의 평균을 낸다.
- 기본 세팅
1
2
3
4
5
6
7
8
9
10
11
import torch
# DataParallel
parallel_model = torch.nn.DataParallel(model)
# DistributedDataParallel
sampler = torch.utils.data.distributed.DistributedSampler(train_data)
shuffle = False
pin_memory = True
train_loader = torch.utils.data.DataLoader(train_data, batch_size=20, shuffle=shuffle,
pin_memory=pin_memory, num_workers=3,sampler=sampler)
- 메인 코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
def main():
n_gpu = torch.cuda.device_count()
torch.multiprocessing.spawn(main_worker, nprocs=n_gpus, args=(n_gpus, ))
def main_worker(gpu, n_gpus):
'''
your settings
'''
batch_size = int(batch_size/n_gpus)
num_worker = int(num_worker/n_gpus)
# set multiprocessing protocol
model = MODEL
torch.cuda.set_device(gpu)
model = model.cuda(gpu)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
from multiprocessing import Pool
def f(x):
return # your function
if __name__ == '__main__':
with Pool(4) as p:
print(p.map(f, [1,2,3]))
Comments powered by Disqus.