詳解PyTorch中的ModuleList和Sequential

極市導(dǎo)讀
本文詳細(xì)講解了PyTorch中的nn.Sequential和nn.ModuleList兩個(gè)模塊。 >>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿
在使用PyTorch的時(shí)候,經(jīng)常遇到nn.Sequential和nn.ModuleList,今天將這兩個(gè)模塊認(rèn)真區(qū)分了一下,總結(jié)如下。PyTorch版本為1.0.0。本文也會(huì)隨著本人逐漸深入Torch和有新的體會(huì)時(shí),會(huì)進(jìn)行更新。
本人才疏學(xué)淺,希望各位看官不吝賜教。
一、官方文檔
首先看官方文檔的解釋,僅列出了容器(Containers)中幾個(gè)比較常用的CLASS。
CLASS torch.nn.Module
Base class for all neural network modules.
Your models should also subclass this class.
import torch.nn as nnimport torch.nn.functional as Fclass Model(nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):x = F.relu(conv1(x))return F.relu(conv2(x))
CLASS torch.nn.Sequential(*args)
A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in.
# Example of using Sequentialmodel = nn.Sequential(nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU())# Example of using Sequential with OrderedDictmodel = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1, 20, 5)),('ReLU1', nn.ReLU()),('conv2', nn.Conv2d(20, 64, 5)),('ReLU2', nn.ReLU())]))
CLASS torch.nn.ModuleList(modules=None)
Holds submodules in a list.
[ModuleList] can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by all [Module] methods.
ModuleList:https://pytorch.org/docs/stable/nn.html#torch.nn.ModuleList
Module:https://pytorch.org/docs/stable/nn.html#torch.nn.Module
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linears = nn.ModuleList([nn.linear for i in range(10)])# ModuleList can act as an iterable, or be indexed using intsdef forward(self, x):for i, l in enumerate(self.linears):x = self.linears[i // 2](x) + l(x)return x
二、nn.Sequential與nn.ModuleList簡(jiǎn)介
nn.Sequential
nn.Sequential里面的模塊按照順序進(jìn)行排列的,所以必須確保前一個(gè)模塊的輸出大小和下一個(gè)模塊的輸入大小是一致的。如下面的例子所示:
#首先導(dǎo)入torch相關(guān)包import torchimport torch.nn as nnimport torch.nn.functional as Fclass net_seq(nn.Module):def __init__(self):super(net2, self).__init__()self.seq = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())def forward(self, x):return self.seq(x)net_seq = net_seq()print(net_seq)#net_seq(# (seq): Sequential(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()# )#)
nn.Sequential中可以使用OrderedDict來(lái)指定每個(gè)module的名字,而不是采用默認(rèn)的命名方式(按序號(hào) 0,1,2,3...)。例子如下:
from collections import OrderedDictclass net_seq(nn.Module):def __init__(self):super(net_seq, self).__init__()self.seq = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1,20,5)),('relu1', nn.ReLU()),('conv2', nn.Conv2d(20,64,5)),('relu2', nn.ReLU())]))def forward(self, x):return self.seq(x)net_seq = net_seq()print(net_seq)#net_seq(# (seq): Sequential(# (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (relu1): ReLU()# (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (relu2): ReLU()# )#)nn.ModuleList
nn.ModuleList,它是一個(gè)儲(chǔ)存不同 module,并自動(dòng)將每個(gè) module 的 parameters 添加到網(wǎng)絡(luò)之中的容器。你可以把任意 nn.Module 的子類 (比如 nn.Conv2d, nn.Linear 之類的) 加到這個(gè) list 里面,方法和 Python 自帶的 list 一樣,無(wú)非是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是會(huì)自動(dòng)注冊(cè)到整個(gè)網(wǎng)絡(luò)上的,同時(shí) module 的 parameters 也會(huì)自動(dòng)添加到整個(gè)網(wǎng)絡(luò)中。若使用python的list,則會(huì)出問題。下面看一個(gè)例子:
class net_modlist(nn.Module):def __init__(self):super(net_modlist, self).__init__()self.modlist = nn.ModuleList([nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()])def forward(self, x):for m in self.modlist:x = m(x)return xnet_modlist = net_modlist()print(net_modlist)#net_modlist(# (modlist): ModuleList(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()# )#)for param in net_modlist.parameters():print(type(param.data), param.size())#<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])#<class 'torch.Tensor'> torch.Size([20])#<class 'torch.Tensor'> torch.Size([64, 20, 5, 5])#<class 'torch.Tensor'> torch.Size([64])
可以看到,這個(gè)網(wǎng)絡(luò)權(quán)重 (weithgs) 和偏置 (bias) 都在這個(gè)網(wǎng)絡(luò)之內(nèi)。接下來(lái)看看另一個(gè)作為對(duì)比的網(wǎng)絡(luò),它使用 Python 自帶的 list:
class net_modlist(nn.Module):def __init__(self):super(net_modlist, self).__init__()self.modlist = [nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()]def forward(self, x):for m in self.modlist:x = m(x)return xnet_modlist = net_modlist()print(net_modlist)#net_modlist()for param in net_modlist.parameters():print(type(param.data), param.size())#None
顯然,使用 Python 的 list 添加的卷積層和它們的 parameters 并沒有自動(dòng)注冊(cè)到我們的網(wǎng)絡(luò)中。當(dāng)然,我們還是可以使用 forward 來(lái)計(jì)算輸出結(jié)果。但是如果用其實(shí)例化的網(wǎng)絡(luò)進(jìn)行訓(xùn)練的時(shí)候,因?yàn)檫@些層的parameters不在整個(gè)網(wǎng)絡(luò)之中,所以其網(wǎng)絡(luò)參數(shù)也不會(huì)被更新,也就是無(wú)法訓(xùn)練。
三、nn.Sequential與nn.ModuleList的區(qū)別
不同點(diǎn)1:
nn.Sequential內(nèi)部實(shí)現(xiàn)了forward函數(shù),因此可以不用寫forward函數(shù)。而nn.ModuleList則沒有實(shí)現(xiàn)內(nèi)部forward函數(shù)。
對(duì)于nn.Sequential:
#例1:這是來(lái)自官方文檔的例子seq = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())print(seq)# Sequential(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()# )#對(duì)上述seq進(jìn)行輸入input = torch.randn(16, 1, 20, 20)print(seq(input))#torch.Size([16, 64, 12, 12])#例2:或者繼承nn.Module類的話,就要寫出forward函數(shù)class net1(nn.Module):def __init__(self):super(net1, self).__init__()self.seq = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())def forward(self, x):return self.seq(x)#注意:按照下面這種利用for循環(huán)的方式也是可以得到同樣結(jié)果的#def forward(self, x):# for s in self.seq:# x = s(x)# return x#對(duì)net1進(jìn)行輸入input = torch.randn(16, 1, 20, 20)net1 = net1()print(net1(input).shape)#torch.Size([16, 64, 12, 12])
而對(duì)于nn.ModuleList:
#例1:若按照下面這么寫,則會(huì)產(chǎn)生錯(cuò)誤modlist = nn.ModuleList([nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()])print(modlist)#ModuleList(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()#)input = torch.randn(16, 1, 20, 20)print(modlist(input))#產(chǎn)生NotImplementedError#例2:寫出forward函數(shù)class net2(nn.Module):def __init__(self):super(net2, self).__init__()self.modlist = nn.ModuleList([nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()])#這里若按照這種寫法則會(huì)報(bào)NotImplementedError錯(cuò)#def forward(self, x):# return self.modlist(x)#注意:只能按照下面利用for循環(huán)的方式def forward(self, x):for m in self.modlist:x = m(x)return xinput = torch.randn(16, 1, 20, 20)net2 = net2()print(net2(input).shape)#torch.Size([16, 64, 12, 12])
如果完全直接用 nn.Sequential,確實(shí)是可以的,但這么做的代價(jià)就是失去了部分靈活性,不能自己去定制 forward 函數(shù)里面的內(nèi)容了。
一般情況下 nn.Sequential 的用法是來(lái)組成卷積塊 (block),然后像拼積木一樣把不同的 block 拼成整個(gè)網(wǎng)絡(luò),讓代碼更簡(jiǎn)潔,更加結(jié)構(gòu)化。
不同點(diǎn)2:
nn.Sequential可以使用OrderedDict對(duì)每層進(jìn)行命名,上面已經(jīng)闡述過了;
不同點(diǎn)3:
nn.Sequential里面的模塊按照順序進(jìn)行排列的,所以必須確保前一個(gè)模塊的輸出大小和下一個(gè)模塊的輸入大小是一致的。而nn.ModuleList 并沒有定義一個(gè)網(wǎng)絡(luò),它只是將不同的模塊儲(chǔ)存在一起,這些模塊之間并沒有什么先后順序可言。見下面代碼:
class net3(nn.Module):def __init__(self):super(net3, self).__init__()self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])def forward(self, x):x = self.linears[2](x)x = self.linears[0](x)x = self.linears[1](x)return xnet3 = net3()print(net3)#net3(# (linears): ModuleList(# (0): Linear(in_features=10, out_features=20, bias=True)# (1): Linear(in_features=20, out_features=30, bias=True)# (2): Linear(in_features=5, out_features=10, bias=True)# )#)input = torch.randn(32, 5)print(net3(input).shape)#torch.Size([32, 30])
根據(jù) net5 的結(jié)果,可以看出來(lái)這個(gè) ModuleList 里面的順序不能決定什么,網(wǎng)絡(luò)的執(zhí)行順序是根據(jù) forward 函數(shù)來(lái)決定的。若將forward函數(shù)中幾行代碼互換,使輸入輸出之間的大小不一致,則程序會(huì)報(bào)錯(cuò)。此外,為了使代碼具有更高的可讀性,最好把ModuleList和forward中的順序保持一致。
不同點(diǎn)4:
有的時(shí)候網(wǎng)絡(luò)中有很多相似或者重復(fù)的層,我們一般會(huì)考慮用 for 循環(huán)來(lái)創(chuàng)建它們,而不是一行一行地寫,比如:
layers = [nn.Linear(10, 10) for i in range(5)]
那么這里我們使用ModuleList:
class net4(nn.Module):def __init__(self):super(net4, self).__init__()layers = [nn.Linear(10, 10) for i in range(5)]self.linears = nn.ModuleList(layers)def forward(self, x):for layer in self.linears:x = layer(x)return xnet = net4()print(net)# net4(# (linears): ModuleList(# (0): Linear(in_features=10, out_features=10, bias=True)# (1): Linear(in_features=10, out_features=10, bias=True)# (2): Linear(in_features=10, out_features=10, bias=True)# )# )
參考:
官方文檔: Container(https://pytorch.org/docs/stable/nn.html#containers) PyTorch 中的 ModuleList 和 Sequential: 區(qū)別和使用場(chǎng)景(https://zhuanlan.zhihu.com/p/64990232)
△點(diǎn)擊卡片關(guān)注極市平臺(tái),獲取最新CV干貨
推薦閱讀
2021-04-19
2021-04-13
2021-04-08

# CV技術(shù)社群邀請(qǐng)函 #
備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)
即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群
每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與 10000+來(lái)自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動(dòng)交流~

