<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          詳解PyTorch中的ModuleList和Sequential

          共 10160字,需瀏覽 21分鐘

           ·

          2021-04-23 23:20

          ↑ 點(diǎn)擊藍(lán)字 關(guān)注極市平臺(tái)

          作者丨小占同學(xué)@知乎(已授權(quán))
          來(lái)源丨h(huán)ttps://zhuanlan.zhihu.com/p/75206669
          編輯丨極市平臺(tái)

          極市導(dǎo)讀

           

          本文詳細(xì)講解了PyTorch中的nn.Sequential和nn.ModuleList兩個(gè)模塊。 >>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿

          在使用PyTorch的時(shí)候,經(jīng)常遇到nn.Sequential和nn.ModuleList,今天將這兩個(gè)模塊認(rèn)真區(qū)分了一下,總結(jié)如下。PyTorch版本為1.0.0。本文也會(huì)隨著本人逐漸深入Torch和有新的體會(huì)時(shí),會(huì)進(jìn)行更新。

          本人才疏學(xué)淺,希望各位看官不吝賜教。

          一、官方文檔

          首先看官方文檔的解釋,僅列出了容器(Containers)中幾個(gè)比較常用的CLASS。

          CLASS torch.nn.Module

          Base class for all neural network modules.

          Your models should also subclass this class.

          import torch.nn as nnimport torch.nn.functional as F
          class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5)
          def forward(self, x): x = F.relu(conv1(x)) return F.relu(conv2(x))

          CLASS torch.nn.Sequential(*args)

          A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in.

          # Example of using Sequentialmodel = nn.Sequential(    nn.Conv2d(1, 20, 5),    nn.ReLU(),    nn.Conv2d(20, 64, 5),    nn.ReLU()    )# Example of using Sequential with OrderedDictmodel = nn.Sequential(OrderedDict([    ('conv1', nn.Conv2d(1, 20, 5)),    ('ReLU1', nn.ReLU()),    ('conv2', nn.Conv2d(20, 64, 5)),    ('ReLU2', nn.ReLU())    ]))

          CLASS torch.nn.ModuleList(modules=None)

          Holds submodules in a list.

          [ModuleList] can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by all [Module] methods.

          ModuleList:https://pytorch.org/docs/stable/nn.html#torch.nn.ModuleList

          Module:https://pytorch.org/docs/stable/nn.html#torch.nn.Module

          class MyModel(nn.Module):    def __init__(self):        super(MyModel, self).__init__()        self.linears = nn.ModuleList([nn.linear for i in range(10)])
          # ModuleList can act as an iterable, or be indexed using ints def forward(self, x): for i, l in enumerate(self.linears): x = self.linears[i // 2](x) + l(x) return x

          二、nn.Sequential與nn.ModuleList簡(jiǎn)介

          nn.Sequential

          nn.Sequential里面的模塊按照順序進(jìn)行排列的,所以必須確保前一個(gè)模塊的輸出大小和下一個(gè)模塊的輸入大小是一致的。如下面的例子所示:

          #首先導(dǎo)入torch相關(guān)包import torchimport torch.nn as nnimport torch.nn.functional as Fclass net_seq(nn.Module):    def __init__(self):        super(net2, self).__init__()        self.seq = nn.Sequential(                        nn.Conv2d(1,20,5),                         nn.ReLU(),                          nn.Conv2d(20,64,5),                       nn.ReLU()                       )          def forward(self, x):        return self.seq(x)net_seq = net_seq()print(net_seq)#net_seq(#  (seq): Sequential(#    (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))#    (1): ReLU()#    (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))#    (3): ReLU()#  )#)

          nn.Sequential中可以使用OrderedDict來(lái)指定每個(gè)module的名字,而不是采用默認(rèn)的命名方式(按序號(hào) 0,1,2,3...)。例子如下:

          from collections import OrderedDict
          class net_seq(nn.Module): def __init__(self): super(net_seq, self).__init__() self.seq = nn.Sequential(OrderedDict([ ('conv1', nn.Conv2d(1,20,5)), ('relu1', nn.ReLU()), ('conv2', nn.Conv2d(20,64,5)), ('relu2', nn.ReLU()) ])) def forward(self, x): return self.seq(x)net_seq = net_seq()print(net_seq)#net_seq(# (seq): Sequential(# (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (relu1): ReLU()# (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (relu2): ReLU()# )#)
          nn.ModuleList

          nn.ModuleList,它是一個(gè)儲(chǔ)存不同 module,并自動(dòng)將每個(gè) module 的 parameters 添加到網(wǎng)絡(luò)之中的容器。你可以把任意 nn.Module 的子類 (比如 nn.Conv2d, nn.Linear 之類的) 加到這個(gè) list 里面,方法和 Python 自帶的 list 一樣,無(wú)非是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是會(huì)自動(dòng)注冊(cè)到整個(gè)網(wǎng)絡(luò)上的,同時(shí) module 的 parameters 也會(huì)自動(dòng)添加到整個(gè)網(wǎng)絡(luò)中。若使用python的list,則會(huì)出問題。下面看一個(gè)例子:

          class net_modlist(nn.Module):    def __init__(self):        super(net_modlist, self).__init__()        self.modlist = nn.ModuleList([                       nn.Conv2d(1, 20, 5),                       nn.ReLU(),                        nn.Conv2d(20, 64, 5),                        nn.ReLU()                        ])
          def forward(self, x): for m in self.modlist: x = m(x) return x
          net_modlist = net_modlist()print(net_modlist)#net_modlist(# (modlist): ModuleList(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()# )#)
          for param in net_modlist.parameters(): print(type(param.data), param.size())#<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])#<class 'torch.Tensor'> torch.Size([20])#<class 'torch.Tensor'> torch.Size([64, 20, 5, 5])#<class 'torch.Tensor'> torch.Size([64])

          可以看到,這個(gè)網(wǎng)絡(luò)權(quán)重 (weithgs) 和偏置 (bias) 都在這個(gè)網(wǎng)絡(luò)之內(nèi)。接下來(lái)看看另一個(gè)作為對(duì)比的網(wǎng)絡(luò),它使用 Python 自帶的 list:

          class net_modlist(nn.Module):    def __init__(self):        super(net_modlist, self).__init__()        self.modlist = [                       nn.Conv2d(1, 20, 5),                       nn.ReLU(),                        nn.Conv2d(20, 64, 5),                        nn.ReLU()                        ]
          def forward(self, x): for m in self.modlist: x = m(x) return x
          net_modlist = net_modlist()print(net_modlist)#net_modlist()for param in net_modlist.parameters(): print(type(param.data), param.size())#None

          顯然,使用 Python 的 list 添加的卷積層和它們的 parameters 并沒有自動(dòng)注冊(cè)到我們的網(wǎng)絡(luò)中。當(dāng)然,我們還是可以使用 forward 來(lái)計(jì)算輸出結(jié)果。但是如果用其實(shí)例化的網(wǎng)絡(luò)進(jìn)行訓(xùn)練的時(shí)候,因?yàn)檫@些層的parameters不在整個(gè)網(wǎng)絡(luò)之中,所以其網(wǎng)絡(luò)參數(shù)也不會(huì)被更新,也就是無(wú)法訓(xùn)練。

          三、nn.Sequential與nn.ModuleList的區(qū)別

          不同點(diǎn)1:

          nn.Sequential內(nèi)部實(shí)現(xiàn)了forward函數(shù),因此可以不用寫forward函數(shù)。而nn.ModuleList則沒有實(shí)現(xiàn)內(nèi)部forward函數(shù)。

          對(duì)于nn.Sequential:

          #例1:這是來(lái)自官方文檔的例子seq = nn.Sequential(          nn.Conv2d(1,20,5),          nn.ReLU(),          nn.Conv2d(20,64,5),          nn.ReLU()        )print(seq)# Sequential(#   (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))#   (1): ReLU()#   (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))#   (3): ReLU()# )
          #對(duì)上述seq進(jìn)行輸入input = torch.randn(16, 1, 20, 20)print(seq(input))#torch.Size([16, 64, 12, 12])
          #例2:或者繼承nn.Module類的話,就要寫出forward函數(shù)class net1(nn.Module): def __init__(self): super(net1, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): return self.seq(x)
          #注意:按照下面這種利用for循環(huán)的方式也是可以得到同樣結(jié)果的 #def forward(self, x): # for s in self.seq: # x = s(x) # return x
          #對(duì)net1進(jìn)行輸入input = torch.randn(16, 1, 20, 20)net1 = net1()print(net1(input).shape)#torch.Size([16, 64, 12, 12])

          而對(duì)于nn.ModuleList:

          #例1:若按照下面這么寫,則會(huì)產(chǎn)生錯(cuò)誤modlist = nn.ModuleList([         nn.Conv2d(1, 20, 5),         nn.ReLU(),         nn.Conv2d(20, 64, 5),         nn.ReLU()         ])print(modlist)#ModuleList(#  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))#  (1): ReLU()#  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))#  (3): ReLU()#)
          input = torch.randn(16, 1, 20, 20)print(modlist(input))#產(chǎn)生NotImplementedError
          #例2:寫出forward函數(shù)class net2(nn.Module): def __init__(self): super(net2, self).__init__() self.modlist = nn.ModuleList([ nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU() ])
          #這里若按照這種寫法則會(huì)報(bào)NotImplementedError錯(cuò) #def forward(self, x): # return self.modlist(x)
          #注意:只能按照下面利用for循環(huán)的方式 def forward(self, x): for m in self.modlist: x = m(x) return x
          input = torch.randn(16, 1, 20, 20)net2 = net2()print(net2(input).shape)#torch.Size([16, 64, 12, 12])

          如果完全直接用 nn.Sequential,確實(shí)是可以的,但這么做的代價(jià)就是失去了部分靈活性,不能自己去定制 forward 函數(shù)里面的內(nèi)容了。

          一般情況下 nn.Sequential 的用法是來(lái)組成卷積塊 (block),然后像拼積木一樣把不同的 block 拼成整個(gè)網(wǎng)絡(luò),讓代碼更簡(jiǎn)潔,更加結(jié)構(gòu)化。

          不同點(diǎn)2:

          nn.Sequential可以使用OrderedDict對(duì)每層進(jìn)行命名,上面已經(jīng)闡述過了;

          不同點(diǎn)3:

          nn.Sequential里面的模塊按照順序進(jìn)行排列的,所以必須確保前一個(gè)模塊的輸出大小和下一個(gè)模塊的輸入大小是一致的。而nn.ModuleList 并沒有定義一個(gè)網(wǎng)絡(luò),它只是將不同的模塊儲(chǔ)存在一起,這些模塊之間并沒有什么先后順序可言。見下面代碼:

          class net3(nn.Module):    def __init__(self):        super(net3, self).__init__()        self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])    def forward(self, x):        x = self.linears[2](x)        x = self.linears[0](x)        x = self.linears[1](x)
          return x
          net3 = net3()print(net3)#net3(# (linears): ModuleList(# (0): Linear(in_features=10, out_features=20, bias=True)# (1): Linear(in_features=20, out_features=30, bias=True)# (2): Linear(in_features=5, out_features=10, bias=True)# )#)
          input = torch.randn(32, 5)print(net3(input).shape)#torch.Size([32, 30])

          根據(jù) net5 的結(jié)果,可以看出來(lái)這個(gè) ModuleList 里面的順序不能決定什么,網(wǎng)絡(luò)的執(zhí)行順序是根據(jù) forward 函數(shù)來(lái)決定的。若將forward函數(shù)中幾行代碼互換,使輸入輸出之間的大小不一致,則程序會(huì)報(bào)錯(cuò)。此外,為了使代碼具有更高的可讀性,最好把ModuleList和forward中的順序保持一致。

          不同點(diǎn)4:

          有的時(shí)候網(wǎng)絡(luò)中有很多相似或者重復(fù)的層,我們一般會(huì)考慮用 for 循環(huán)來(lái)創(chuàng)建它們,而不是一行一行地寫,比如:

          layers = [nn.Linear(10, 10) for i in range(5)]

          那么這里我們使用ModuleList:

          class net4(nn.Module):    def __init__(self):        super(net4, self).__init__()        layers = [nn.Linear(10, 10) for i in range(5)]        self.linears = nn.ModuleList(layers)
          def forward(self, x): for layer in self.linears: x = layer(x) return x
          net = net4()print(net)# net4(# (linears): ModuleList(# (0): Linear(in_features=10, out_features=10, bias=True)# (1): Linear(in_features=10, out_features=10, bias=True)# (2): Linear(in_features=10, out_features=10, bias=True)# )# )

          參考:

          1. 官方文檔: Container(https://pytorch.org/docs/stable/nn.html#containers)
          2. PyTorch 中的 ModuleList 和 Sequential: 區(qū)別和使用場(chǎng)景(https://zhuanlan.zhihu.com/p/64990232)
          如果覺得有用,就請(qǐng)分享到朋友圈吧!

          △點(diǎn)擊卡片關(guān)注極市平臺(tái),獲取最新CV干貨


          推薦閱讀


          實(shí)操教程|PyTorch自定義CUDA算子教程與運(yùn)行時(shí)間分析

          2021-04-19

          實(shí)操教程|PyTorch AutoGrad C++層實(shí)現(xiàn)

          2021-04-13

          PyTorch 源碼解讀之即時(shí)編譯篇

          2021-04-08



          # CV技術(shù)社群邀請(qǐng)函 #

          △長(zhǎng)按添加極市小助手
          添加極市小助手微信(ID : cvmart2)

          備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)


          即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群


          每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與 10000+來(lái)自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動(dòng)交流~




          覺得有用麻煩給個(gè)在看啦~  
          瀏覽 34
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  亚洲欧美在线免费观看 | 亚洲黄色视屏 | 91精品干 | 99久久99视频 | 亚洲天堂在线观看成人 |