PyTorch 深度剖析:并行訓(xùn)練的 DP 和 DDP 分別在什么情況下使用及實例

極市導(dǎo)讀
?這篇文章從應(yīng)用的角度出發(fā),介紹 DP 和 DDP 分別在什么情況下使用,以及各自的使用方法。以及 DDP 的保存和加載模型的策略,和如何同時使用 DDP 和模型并行 (model parallel)。?>>加入極市CV技術(shù)交流群,走在計算機(jī)視覺的最前沿
目錄
1 DP 和 DDP 分別在什么情況下使用
1.1 幾種并行訓(xùn)練的選項
1.2 DP 和 DDP 的比較2 Data Parallel 介紹
2.1 簡介
2.2 用法實例3 Distributed Data Parallel 介紹
3.1 簡介
3.2 用法實例
3.3 保存和加載模型
3.4 與模型并行的結(jié)合 (DDP + model parallel)
這篇文章從應(yīng)用的角度出發(fā),介紹 DP 和 DDP 分別在什么情況下使用,以及各自的使用方法。以及 DDP 的保存和加載模型的策略,和如何同時使用 DDP 和模型并行 (model parallel)。
1 DP 和 DDP 分別在什么情況下使用
1.1 幾種并行訓(xùn)練的選項
PyTorch 提供了幾種并行訓(xùn)練的選項。
如果:(1) 訓(xùn)練速度無所謂。(2) 模型和數(shù)據(jù)能夠 fit 進(jìn)一個 GPU 里面:這種情況建議不要分布式訓(xùn)練。 如果:(1) 想提升訓(xùn)練速度。(2) 非常不想過多地修改代碼。(3) 有1臺機(jī)器 (machine 或者叫做 node) (只能在單機(jī)上使用,俗稱 "單機(jī)多卡"),機(jī)器上有多張 GPU:這種情況建議使用 Data Parallel 分布式訓(xùn)練。 如果:(1) 想進(jìn)一步提升訓(xùn)練速度。(2) 可以適當(dāng)多地修改代碼。(3) 有1臺或者多臺的機(jī)器 (machine 或者叫做 node) (可以在多機(jī)上使用,俗稱 "多機(jī)多卡"),機(jī)器上有多張 GPU:這種情況建議使用 Distributed Data Parallel 分布式訓(xùn)練。
1.2 DP 和 DDP 的比較
Data Parallel:單進(jìn)程,多線程,只能適用于1臺機(jī)器的情況。Distributed Data Parallel:多進(jìn)程,可以適用于多臺機(jī)器的情況。 當(dāng)模型太大,一個 GPU 放不下時,Data Parallel:不能結(jié)合模型并行的方法。Distributed Data Parallel:可以結(jié)合模型并行的方法。
2 Data Parallel 介紹
2.1 簡介
Data Parallel 這種方法允許我們以最小的代碼修改代價實現(xiàn)有1臺機(jī)器上的多張 GPU 的訓(xùn)練。只需要修改1行代碼。但是盡管 Data Parallel 這種方法使用方便,但是 Data Parallel 的性能卻不是最好的。我們先介紹下 torch.nn.DataParallel 這個 PyTorch class。
定義:
CLASStorch.nn.DataParallel(module,device_ids=None,output_device=None,dim=0)
在 module 層面實現(xiàn)數(shù)據(jù)并行。
torch.nn.DataParallel 要輸入一個module ,在前向傳播過程中,這個module會在每個 device 上面復(fù)制一份。同時輸入數(shù)據(jù)在 batch 這個維度被分塊,這些數(shù)據(jù)會被按塊分配在不同的 device 上面。最后形成的局面就是:所有的 GPU 上面都有一樣的module,每個 GPU 都有單獨的數(shù)據(jù)。在反向傳播過程中,每一個 GPU 上得到的 gradient 會匯總到主 GPU (server) 上面。主 GPU (server) 更新參數(shù)之后,還會把新的參數(shù)模型參數(shù) broadcast 到每個其它的 GPU 上面。
DP 使用的是 Parameter Server (PS) 架構(gòu)。 Parameter Server 架構(gòu) (PS 模式) 由 server 節(jié)點和 worker 節(jié)點組成,server 節(jié)點的主要功能是初始化和保存模型參數(shù)、接受 worker 節(jié)點計算出的局部梯度、匯總計算全局梯度,并更新模型參數(shù)。
worker 節(jié)點的主要功能是各自保存部分訓(xùn)練數(shù)據(jù),初始化模型,從 server 節(jié)點拉取最新的模型參數(shù) (pull),再讀取參數(shù),根據(jù)訓(xùn)練數(shù)據(jù)計算局部梯度,上傳給 server 節(jié)點 (push)。
PS 模式下的 DP,會造成負(fù)載不均衡,因為充當(dāng) server 的 GPU 需要一定的顯存用來保存 worker 節(jié)點計算出的局部梯度;另外 server 還需要將更新后的模型參數(shù) broadcast 到每個 worker,server 的帶寬就成了 server 與worker 之間的通信瓶頸,server 與 worker 之間的通信成本會隨著 worker 數(shù)目的增加而線性增加。
所以讀完了以上的分析,自然而然的2個要求就是:
訓(xùn)練的 batch size 要能夠被 GPU 數(shù)量整除。 在使用 DataParallel 之前,輸入的 module必須首先已經(jīng)在device_ids[0]上面了。
下面是2條重要的注意信息:
每次 Forward 的時候, module會在每個 device 上面被淺復(fù)制。也就是說,DataParellel 保證了 device[0] 上的這個 replica (參數(shù)和 buffer) 和其他 device 上的 replica (參數(shù)和 buffer) 擁有著相同的存儲位置。也就是說,只有那些 in-place 的操作才能夠?qū)崿F(xiàn)牽一發(fā)而動全身的效果,即:in-place 操作改變 device[0] 上的某個參數(shù),會改變其他所有 device 上的參數(shù)。常見的 in-place 操作,比如有:[BatchNorm2d](https://link.zhihu.com/?target=https%3A//pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html%23torch.nn.BatchNorm2d)和[spectral_norm()](https://link.zhihu.com/?target=https%3A//pytorch.org/docs/stable/generated/torch.nn.utils.spectral_norm.html%23torch.nn.utils.spectral_norm)。module內(nèi)部定義的 Forward 和 backward hooks,一共會被激活len(device_ids)次。每次激活時輸入就依照當(dāng)前 device 上的 input 執(zhí)行。而且 hook 注冊和激活的順序無法控制。只能保證在當(dāng)前 GPU 上面,[register_forward_pre_hook()](https://link.zhihu.com/?target=https%3A//pytorch.org/docs/stable/generated/torch.nn.Module.html%23torch.nn.Module.register_forward_pre_hook)先于[forward()](https://link.zhihu.com/?target=https%3A//pytorch.org/docs/stable/generated/torch.nn.Module.html%23torch.nn.Module.forward)被執(zhí)行,而無法保證它先于所有的[forward()](https://link.zhihu.com/?target=https%3A//pytorch.org/docs/stable/generated/torch.nn.Module.html%23torch.nn.Module.forward)被執(zhí)行。
參數(shù)定義:
module (Module) – module to be parallelized device_ids (list of python:int or torch.device) – CUDA devices (default: all devices) output_device (int or torch.device) – device location of output (default: device_ids[0])
使用:
net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
output = net(input_var) # input_var can be on any device, including CPU
2.2 用法示例
這一節(jié)通過具體的例子展示 DataParallel 的用法。
1) 首先 Import PyTorch modules 和超參數(shù)。
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# Parameters and DataLoaders
input_size = 5
output_size = 2
batch_size = 30
data_size = 100
2) 設(shè)置 device。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
3) 制作一個dummy (random) dataset,這里我們只需要實現(xiàn) getitem 方法。
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
batch_size=batch_size, shuffle=True)
4) 制作一個示例模型。
class Model(nn.Module):
# Our model
def __init__(self, input_size, output_size):
super(Model, self).__init__()
self.fc = nn.Linear(input_size, output_size)
def forward(self, input):
output = self.fc(input)
print("\tIn Model: input size", input.size(),
"output size", output.size())
return output
5) 創(chuàng)建 Model 和 DataParallel,首先要把模型實例化,再檢查下我們是否有多塊 GPU。最后是 put model on device:
model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
# dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
model = nn.DataParallel(model)
model.to(device)
輸出:
Let's use 2 GPUs!
6) Run the Model:
for data in rand_loader:
input = data.to(device)
output = model(input)
print("Outside: input size", input.size(),
"output_size", output.size())
輸出:
# on 2 GPUs
Let's use 2 GPUs!
In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
In Model: input size torch.Size([5, 5]) output size torch.Size([5, 2])
In Model: input size torch.Size([5, 5]) output size torch.Size([5, 2])
Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2])
以上就是 DataParellel 的極簡示例,注意我們并沒有告訴程序我們要使用多少塊 GPU,因為 torch.cuda.device_count() 會自動地計算出當(dāng)前的所有可用的 GPU 數(shù),假設(shè)電腦里面是8塊,那么輸出就會是:
Let's use 8 GPUs!
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2])
3 Distributed Data Parallel 介紹
3.1 簡介
Distributed Data Parallel 這種方法允許我們在有1臺或者多臺的機(jī)器上分布式訓(xùn)練。與 Data Parallel 的不同之處是:
需要啟動這一步:init_process_group(https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) 模型在創(chuàng)建的時候就已經(jīng)復(fù)制到各個 GPU 上面,而不是在 Forward 函數(shù)里面復(fù)制的。
我們先介紹下 torch.nn.parallel.DistributedDataParallel 這個 PyTorch class。
定義:
CLASStorch.nn.parallel.DistributedDataParallel(module,device_ids=None,output_device=None,dim=0,broadcast_buffers=True,process_group=None,bucket_cap_mb=25,find_unused_parameters=False,check_reduction=False,gradient_as_bucket_view=False)
在 module 層面實現(xiàn)分布式數(shù)據(jù)并行。
torch.nn.DistributedDataParallel
torch.nn.DataParallel 要輸入一個 module ,在模型構(gòu)建的過程中,這個 module會在每個 device 上面復(fù)制一份。同時輸入數(shù)據(jù)在 batch 這個維度被分塊,這些數(shù)據(jù)會被按塊分配在不同的 device 上面。最后形成的局面就是:所有的 GPU 上面都有一樣的 module,每個 GPU 都有單獨的數(shù)據(jù)。在反向傳播過程中,每一個 GPU 上得到的 gradient 會被平均。
使用這個 class 需要torch.distributed的初始化,所以需要調(diào)用 [torch.distributed.init_process_group()](https://link.zhihu.com/?target=https%3A//pytorch.org/docs/stable/distributed.html%23torch.distributed.init_process_group) 。
如果想在一個有 N 個 GPU 的設(shè)備上面使用 DistributedDataParallel,則需要 spawn up N 個進(jìn)程,每個進(jìn)程對應(yīng)0-N-1 的一個 GPU。這可以通過下面的語句實現(xiàn):
torch.cuda.set_device(i)
i from 0-N-1,每個進(jìn)程中都需要:
torch.distributed.init_process_group(
backend='nccl', world_size=N, init_method='...'
)
model = DistributedDataParallel(model, device_ids=[i], output_device=i)
為了在每臺設(shè)備 (節(jié)點) 上建立多個進(jìn)程,我們可以使用torch.distributed.launch或者torch.multiprocessing.spawn。
如果你在一個進(jìn)程中使用 torch.save 來保存模型,并在其他一些進(jìn)程中使用 torch.load 來加載模型,請確保每個進(jìn)程的 map_location 都配置正確。如果沒有 map_location,torch.load 會將從保存的設(shè)備上加載模型。
幾點注意:
減少優(yōu)化器顯存: DistributedDataParallel 可以搭配 [torch.distributed.optim.ZeroRedundancyOptimizer](https://link.zhihu.com/?target=https%3A//pytorch.org/docs/stable/distributed.optim.html%23torch.distributed.optim.ZeroRedundancyOptimizer)一起使用來減少 optimizer states memory,具體這里就不過多介紹,可以參考下面鏈接:
封裝模型: 在用 DistributedDataParallel 封裝模型之后,千萬不要試圖改變你的模型的參數(shù)。因為,當(dāng)用DistributedDataParallel 包裝模型時,DistributedDataParallel 的構(gòu)造函數(shù)會在構(gòu)造時對模型本身的所有參數(shù)注冊額外的梯度還原函數(shù) (gradient reduction functions)。如果你事后改變了模型的參數(shù),梯度還原函數(shù)就沒法再與正確的參數(shù)集匹配。 梯度同步的機(jī)制: DistributedDataParallel 在 module 層面實現(xiàn)了數(shù)據(jù)并行,可以在多臺機(jī)器上運行。使用 DDP 的應(yīng)用程序應(yīng)該 spawn up 多個進(jìn)程,并在每個進(jìn)程中創(chuàng)建一個 DDP 實例。DDP 使用 Torch.distributed 包中的 collective communications 來同步梯度和緩沖區(qū) (synchronize gradients and buffers)。更具體地說,DDP 為model.parameters() 給出的每個參數(shù)注冊了一個 autograd hook,當(dāng)在反向傳播中計算出相應(yīng)的梯度時,該 hook 將被觸發(fā)。然后 DDP 使用該信號來觸發(fā)跨進(jìn)程的梯度同步。
參數(shù)定義:
module (Module) – module to be parallelized device_ids (list of python:int or torch.device) –CUDA devices. 1) For single-device modules, device_idscan contain exactly one device id, which represents the only CUDA device where the input module corresponding to this process resides. Alternatively,device_idscan also beNone.2) For multi-device modules and CPU modules, device_idsmust beNone.
Whendevice_idsisNonefor both cases, both the input data for the forward pass and the actual module must be placed on the correct device. (default:None)output_device (int or torch.device) – Device location of output for single-device CUDA modules. For multi-device modules and CPU modules, it must be None, and the module itself dictates the output location. (default:device_ids[0]for single-device modules)broadcast_buffers (bool) – Flag that enables syncing (broadcasting) buffers of the module at beginning of the forwardfunction. (default:True)process_group – The process group to be used for distributed data all-reduction. If None, the default process group, which is created by[torch.distributed.init_process_group()](https://link.zhihu.com/?target=https%3A//pytorch.org/docs/stable/distributed.html%23torch.distributed.init_process_group), will be used. (default:None)bucket_cap_mb – DistributedDataParallelwill bucket parameters into multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation.bucket_cap_mbcontrols the bucket size in MegaBytes (MB). (default: 25)find_unused_parameters (bool) – Traverse the autograd graph from all tensors contained in the return value of the wrapped module’s forwardfunction. Parameters that don’t receive gradients as part of this graph are preemptively marked as being ready to be reduced. In addition, parameters that may have been used in the wrapped module’sforwardfunction but were not part of loss computation and thus would also not receive gradients are preemptively marked as ready to be reduced. (default:False)check_reduction – This argument is deprecated. gradient_as_bucket_view (bool) – When set to True, gradients will be views pointing to different offsets ofallreducecommunication buckets. This can reduce peak memory usage, where the saved memory size will be equal to the total gradients size. Moreover, it avoids the overhead of copying between gradients andallreducecommunication buckets. When gradients are views,detach_()cannot be called on the gradients. If hitting such errors, please fix it by referring to the[zero_grad()](https://link.zhihu.com/?target=https%3A//pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html%23torch.optim.Optimizer.zero_grad)function intorch/optim/optimizer.pyas a solution.
3.2 用法示例
這一節(jié)通過具體的例子展示 DistributedDataParallel 的用法,這個例子假設(shè)我們有一個8卡 GPU。
1) 首先初始化進(jìn)程:
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
# On Windows platform, the torch.distributed package only
# supports Gloo backend, FileStore and TcpStore.
# For FileStore, set init_method parameter in init_process_group
# to a local file. Example as follow:
# init_method="file:///f:/libtmp/some_file"
# dist.init_process_group(
# "gloo",
# rank=rank,
# init_method=init_method,
# world_size=world_size)
# For TcpStore, same way as on Linux.
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
2) 創(chuàng)建一個 toy module,叫它 ToyModel,用 DDP 去包裹它。注意,由于 DDP 在構(gòu)造函數(shù)中把模型狀態(tài)從第rank 0 的進(jìn)程廣播給所有其他進(jìn)程,所以我們無需擔(dān)心不同的 DDP 進(jìn)程從不同的參數(shù)初始值啟動。PyTorch提供了mp.spawn來在一個節(jié)點啟動該節(jié)點所有進(jìn)程,每個進(jìn)程運行train(i, args),其中i從0到args.gpus \- 1。所以有以下 code。
執(zhí)行代碼時,GPU 數(shù)和進(jìn)程數(shù)都是 world_size。
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic(rank, world_size):
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
# create model and move it to GPU with id rank
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(rank)
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
def run_demo(demo_fn, world_size):
mp.spawn(demo_fn,
args=(world_size,),
nprocs=world_size,
join=True)
3.3 保存和加載模型
當(dāng)使用 DDP 時,我們只在一個進(jìn)程中保存模型,然后將其加載到所有進(jìn)程中,以減少寫的開銷。這也很好理解,因為所有進(jìn)程從相同的參數(shù)開始,梯度在后向傳遞中是同步的,因此,所有進(jìn)程的梯度是相同的。所以讀者請確保所有進(jìn)程在保存完成之前不要開始加載。此外,在加載模塊時,我們需要提供一個適當(dāng)?shù)?map_location 參數(shù),以防止一個 process 踏入其他進(jìn)程的設(shè)備。如果缺少 map_location,torch.load 將首先把 module 加載到 CPU,然后把每個參數(shù)復(fù)制到它被保存的地方,這將導(dǎo)致同一臺機(jī)器上的所有進(jìn)程使用同一組設(shè)備。
def demo_checkpoint(rank, world_size):
print(f"Running DDP checkpoint example on rank {rank}.")
setup(rank, world_size)
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
if rank == 0:
# All processes should see same parameters as they all start from same
# random parameters and gradients are synchronized in backward passes.
# Therefore, saving it in one process is sufficient.
torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
# Use a barrier() to make sure that process 1 loads the model after process
# 0 saves it.
dist.barrier()
# configure map_location properly
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
ddp_model.load_state_dict(
torch.load(CHECKPOINT_PATH, map_location=map_location))
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(rank)
loss_fn = nn.MSELoss()
loss_fn(outputs, labels).backward()
optimizer.step()
# Not necessary to use a dist.barrier() to guard the file deletion below
# as the AllReduce ops in the backward pass of DDP already served as
# a synchronization.
if rank == 0:
os.remove(CHECKPOINT_PATH)
cleanup()
3.4 與模型并行的結(jié)合 (DDP + model parallel)
有關(guān)模型并行的介紹可以參考:
DDP 也適用于 multi-GPU 模型。DDP 包裹著 multi-GPU 模型,在用海量數(shù)據(jù)訓(xùn)練大型模型時特別有幫助。
class ToyMpModel(nn.Module):
def __init__(self, dev0, dev1):
super(ToyMpModel, self).__init__()
self.dev0 = dev0
self.dev1 = dev1
self.net1 = torch.nn.Linear(10, 10).to(dev0)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(10, 5).to(dev1)
def forward(self, x):
x = x.to(self.dev0)
x = self.relu(self.net1(x))
x = x.to(self.dev1)
return self.net2(x)
當(dāng)把一個 multi-GPU 模型傳遞給 DDP 時,device_ids 和 output_device 不能被設(shè)置。輸入和輸出數(shù)據(jù)將被應(yīng)用程序或模型 forward() 方法放在適當(dāng)?shù)脑O(shè)備中。
def demo_model_parallel(rank, world_size):
print(f"Running DDP with model parallel example on rank {rank}.")
setup(rank, world_size)
# setup mp_model and devices for this process
dev0 = (rank * 2) % world_size
dev1 = (rank * 2 + 1) % world_size
mp_model = ToyMpModel(dev0, dev1)
ddp_mp_model = DDP(mp_model)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)
optimizer.zero_grad()
# outputs will be on dev1
outputs = ddp_mp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(dev1)
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
world_size = n_gpus
run_demo(demo_basic, world_size)
run_demo(demo_checkpoint, world_size)
run_demo(demo_model_parallel, world_size)
參考:
https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
https://pytorch.org/docs/stable/notes/ddp.html
如果覺得有用,就請分享到朋友圈吧!
公眾號后臺回復(fù)“transformer”獲取最新Transformer綜述論文下載~

#?極市平臺簽約作者#
科技猛獸
知乎:科技猛獸
清華大學(xué)自動化系19級碩士
研究領(lǐng)域:AI邊緣計算 (Efficient AI with Tiny Resource):專注模型壓縮,搜索,量化,加速,加法網(wǎng)絡(luò),以及它們與其他任務(wù)的結(jié)合,更好地服務(wù)于端側(cè)設(shè)備。
作品精選


