詳解PyTorch中的ModuleList和Sequential
點(diǎn)擊上方“視學(xué)算法”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時間送達(dá)
導(dǎo)讀
本文詳細(xì)講解了PyTorch中的nn.Sequential和nn.ModuleList兩個模塊。
在使用PyTorch的時候,經(jīng)常遇到nn.Sequential和nn.ModuleList,今天將這兩個模塊認(rèn)真區(qū)分了一下,總結(jié)如下。PyTorch版本為1.0.0。本文也會隨著本人逐漸深入Torch和有新的體會時,會進(jìn)行更新。
本人才疏學(xué)淺,希望各位看官不吝賜教。
一、官方文檔
首先看官方文檔的解釋,僅列出了容器(Containers)中幾個比較常用的CLASS。
CLASS torch.nn.Module
Base class for all neural network modules.
Your models should also subclass this class.
import torch.nn as nnimport torch.nn.functional as Fclass Model(nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):x = F.relu(conv1(x))return F.relu(conv2(x))
CLASS torch.nn.Sequential(*args)
A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in.
# Example of using Sequentialmodel = nn.Sequential(nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU())# Example of using Sequential with OrderedDictmodel = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1, 20, 5)),('ReLU1', nn.ReLU()),('conv2', nn.Conv2d(20, 64, 5)),('ReLU2', nn.ReLU())]))
CLASS torch.nn.ModuleList(modules=None)
Holds submodules in a list.
[ModuleList] can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by all [Module] methods.
ModuleList:https://pytorch.org/docs/stable/nn.html#torch.nn.ModuleList
Module:https://pytorch.org/docs/stable/nn.html#torch.nn.Module
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linears = nn.ModuleList([nn.linear for i in range(10)])# ModuleList can act as an iterable, or be indexed using intsdef forward(self, x):for i, l in enumerate(self.linears):x = self.linears[i // 2](x) + l(x)return x
二、nn.Sequential與nn.ModuleList簡介
nn.Sequential
nn.Sequential里面的模塊按照順序進(jìn)行排列的,所以必須確保前一個模塊的輸出大小和下一個模塊的輸入大小是一致的。如下面的例子所示:
#首先導(dǎo)入torch相關(guān)包import torchimport torch.nn as nnimport torch.nn.functional as Fclass net_seq(nn.Module):def __init__(self):super(net2, self).__init__()self.seq = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())def forward(self, x):return self.seq(x)net_seq = net_seq()print(net_seq)#net_seq(# (seq): Sequential(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()# )#)
nn.Sequential中可以使用OrderedDict來指定每個module的名字,而不是采用默認(rèn)的命名方式(按序號 0,1,2,3...)。例子如下:
from collections import OrderedDictclass net_seq(nn.Module):def __init__(self):super(net_seq, self).__init__()self.seq = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1,20,5)),('relu1', nn.ReLU()),('conv2', nn.Conv2d(20,64,5)),('relu2', nn.ReLU())]))def forward(self, x):return self.seq(x)net_seq = net_seq()print(net_seq)#net_seq(# (seq): Sequential(# (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (relu1): ReLU()# (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (relu2): ReLU()# )#)nn.ModuleList
nn.ModuleList,它是一個儲存不同 module,并自動將每個 module 的 parameters 添加到網(wǎng)絡(luò)之中的容器。你可以把任意 nn.Module 的子類 (比如 nn.Conv2d, nn.Linear 之類的) 加到這個 list 里面,方法和 Python 自帶的 list 一樣,無非是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是會自動注冊到整個網(wǎng)絡(luò)上的,同時 module 的 parameters 也會自動添加到整個網(wǎng)絡(luò)中。若使用python的list,則會出問題。下面看一個例子:
class net_modlist(nn.Module):def __init__(self):super(net_modlist, self).__init__()self.modlist = nn.ModuleList([nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()])def forward(self, x):for m in self.modlist:x = m(x)return xnet_modlist = net_modlist()print(net_modlist)#net_modlist(# (modlist): ModuleList(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()# )#)for param in net_modlist.parameters():print(type(param.data), param.size())#<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])#<class 'torch.Tensor'> torch.Size([20])#<class 'torch.Tensor'> torch.Size([64, 20, 5, 5])#<class 'torch.Tensor'> torch.Size([64])
可以看到,這個網(wǎng)絡(luò)權(quán)重 (weithgs) 和偏置 (bias) 都在這個網(wǎng)絡(luò)之內(nèi)。接下來看看另一個作為對比的網(wǎng)絡(luò),它使用 Python 自帶的 list:
class net_modlist(nn.Module):def __init__(self):super(net_modlist, self).__init__()self.modlist = [nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()]def forward(self, x):for m in self.modlist:x = m(x)return xnet_modlist = net_modlist()print(net_modlist)#net_modlist()for param in net_modlist.parameters():print(type(param.data), param.size())#None
顯然,使用 Python 的 list 添加的卷積層和它們的 parameters 并沒有自動注冊到我們的網(wǎng)絡(luò)中。當(dāng)然,我們還是可以使用 forward 來計(jì)算輸出結(jié)果。但是如果用其實(shí)例化的網(wǎng)絡(luò)進(jìn)行訓(xùn)練的時候,因?yàn)檫@些層的parameters不在整個網(wǎng)絡(luò)之中,所以其網(wǎng)絡(luò)參數(shù)也不會被更新,也就是無法訓(xùn)練。
三、nn.Sequential與nn.ModuleList的區(qū)別
不同點(diǎn)1:
nn.Sequential內(nèi)部實(shí)現(xiàn)了forward函數(shù),因此可以不用寫forward函數(shù)。而nn.ModuleList則沒有實(shí)現(xiàn)內(nèi)部forward函數(shù)。
對于nn.Sequential:
#例1:這是來自官方文檔的例子seq = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())print(seq)# Sequential(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()# )#對上述seq進(jìn)行輸入input = torch.randn(16, 1, 20, 20)print(seq(input))#torch.Size([16, 64, 12, 12])#例2:或者繼承nn.Module類的話,就要寫出forward函數(shù)class net1(nn.Module):def __init__(self):super(net1, self).__init__()self.seq = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())def forward(self, x):return self.seq(x)#注意:按照下面這種利用for循環(huán)的方式也是可以得到同樣結(jié)果的#def forward(self, x):# for s in self.seq:# x = s(x)# return x#對net1進(jìn)行輸入input = torch.randn(16, 1, 20, 20)net1 = net1()print(net1(input).shape)#torch.Size([16, 64, 12, 12])
而對于nn.ModuleList:
#例1:若按照下面這么寫,則會產(chǎn)生錯誤modlist = nn.ModuleList([nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()])print(modlist)#ModuleList(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()#)input = torch.randn(16, 1, 20, 20)print(modlist(input))#產(chǎn)生NotImplementedError#例2:寫出forward函數(shù)class net2(nn.Module):def __init__(self):super(net2, self).__init__()self.modlist = nn.ModuleList([nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()])#這里若按照這種寫法則會報NotImplementedError錯#def forward(self, x):# return self.modlist(x)#注意:只能按照下面利用for循環(huán)的方式def forward(self, x):for m in self.modlist:x = m(x)return xinput = torch.randn(16, 1, 20, 20)net2 = net2()print(net2(input).shape)#torch.Size([16, 64, 12, 12])
如果完全直接用 nn.Sequential,確實(shí)是可以的,但這么做的代價就是失去了部分靈活性,不能自己去定制 forward 函數(shù)里面的內(nèi)容了。
一般情況下 nn.Sequential 的用法是來組成卷積塊 (block),然后像拼積木一樣把不同的 block 拼成整個網(wǎng)絡(luò),讓代碼更簡潔,更加結(jié)構(gòu)化。
不同點(diǎn)2:
nn.Sequential可以使用OrderedDict對每層進(jìn)行命名,上面已經(jīng)闡述過了;
不同點(diǎn)3:
nn.Sequential里面的模塊按照順序進(jìn)行排列的,所以必須確保前一個模塊的輸出大小和下一個模塊的輸入大小是一致的。而nn.ModuleList 并沒有定義一個網(wǎng)絡(luò),它只是將不同的模塊儲存在一起,這些模塊之間并沒有什么先后順序可言。見下面代碼:
class net3(nn.Module):def __init__(self):super(net3, self).__init__()self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])def forward(self, x):x = self.linears[2](x)x = self.linears[0](x)x = self.linears[1](x)return xnet3 = net3()print(net3)#net3(# (linears): ModuleList(# (0): Linear(in_features=10, out_features=20, bias=True)# (1): Linear(in_features=20, out_features=30, bias=True)# (2): Linear(in_features=5, out_features=10, bias=True)# )#)input = torch.randn(32, 5)print(net3(input).shape)#torch.Size([32, 30])
根據(jù) net5 的結(jié)果,可以看出來這個 ModuleList 里面的順序不能決定什么,網(wǎng)絡(luò)的執(zhí)行順序是根據(jù) forward 函數(shù)來決定的。若將forward函數(shù)中幾行代碼互換,使輸入輸出之間的大小不一致,則程序會報錯。此外,為了使代碼具有更高的可讀性,最好把ModuleList和forward中的順序保持一致。
不同點(diǎn)4:
有的時候網(wǎng)絡(luò)中有很多相似或者重復(fù)的層,我們一般會考慮用 for 循環(huán)來創(chuàng)建它們,而不是一行一行地寫,比如:
layers = [nn.Linear(10, 10) for i in range(5)]
那么這里我們使用ModuleList:
class net4(nn.Module):def __init__(self):super(net4, self).__init__()layers = [nn.Linear(10, 10) for i in range(5)]self.linears = nn.ModuleList(layers)def forward(self, x):for layer in self.linears:x = layer(x)return xnet = net4()print(net)# net4(# (linears): ModuleList(# (0): Linear(in_features=10, out_features=10, bias=True)# (1): Linear(in_features=10, out_features=10, bias=True)# (2): Linear(in_features=10, out_features=10, bias=True)# )# )
參考:
官方文檔: Container(https://pytorch.org/docs/stable/nn.html#containers) PyTorch 中的 ModuleList 和 Sequential: 區(qū)別和使用場景(https://zhuanlan.zhihu.com/p/64990232)

點(diǎn)個在看 paper不斷!
