<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          【小白學(xué)習(xí)PyTorch教程】四、基于nn.Module類實(shí)現(xiàn)線性回歸模型

          共 3993字,需瀏覽 8分鐘

           ·

          2021-08-10 02:55

          「@Author:Runsen」

          上次介紹了順序模型,但是在大多數(shù)情況下,我們基本都是以類的形式實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò)。

          大多數(shù)情況下創(chuàng)建一個(gè)繼承自 Pytorch 中的 nn.Module 的類,這樣可以使用 Pytorch 提供的許多高級(jí) API,而無需自己實(shí)現(xiàn)。

          下面展示了一個(gè)可以從nn.Module創(chuàng)建的最簡單的神經(jīng)網(wǎng)絡(luò)類的示例?;?nn.Module的類的最低要求是覆蓋__init__()方法和forward()方法。

          在這個(gè)類中,定義了一個(gè)簡單的線性網(wǎng)絡(luò),具有兩個(gè)輸入和一個(gè)輸出,并使用 Sigmoid()函數(shù)作為網(wǎng)絡(luò)的激活函數(shù)。

          import torch
          from torch import nn

          class LinearRegression(nn.Module):
              def __init__(self):
                  #繼承父類構(gòu)造函數(shù)
                  super(LinearRegression, self).__init__() 
                  #輸入和輸出的維度都是1
                  self.linear = nn.Linear(11
              def forward(self, x):
                  out = self.linear(x)
                  return out
                              

          現(xiàn)在讓我們測試一下模型。

          # 創(chuàng)建LinearRegression()的實(shí)例
          model = LinearRegression()
          print(model) 
          # 輸出如下
          LinearRegression(
            (linear): Linear(in_features=1, out_features=1, bias=True)
          )

          現(xiàn)在讓定義一個(gè)損失函數(shù)和優(yōu)化函數(shù)。

          model = LinearRegression()#實(shí)例化對(duì)象
          num_epochs = 1000#迭代次數(shù)
          learning_rate = 1e-2#學(xué)習(xí)率0.01
          Loss = torch.nn.MSELoss()#損失函數(shù)
          optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)#優(yōu)化函數(shù)

          我們創(chuàng)建一個(gè)由方程產(chǎn)生的數(shù)據(jù)集,并通過函數(shù)制造噪音

          import torch 
          from matplotlib import pyplot as plt
          from torch.autograd import Variable
          from torch import nn
          # 創(chuàng)建數(shù)據(jù)集  unsqueeze 相當(dāng)于
          x = Variable(torch.unsqueeze(torch.linspace(-11100), dim=1))
          y = Variable(x * 2 + 0.2 + torch.rand(x.size()))
          plt.scatter(x.data.numpy(),y.data.numpy())
          plt.show()

          關(guān)于torch.unsqueeze函數(shù)解讀。

          >>> x = torch.tensor([1234])
          >>> torch.unsqueeze(x, 0)
          tensor([[ 1,  2,  3,  4]])
          >>> torch.unsqueeze(x, 1)
          tensor([[ 1],
                  [ 2],
                  [ 3],
                  [ 4]])

          遍歷每次epoch,計(jì)算出loss,反向傳播計(jì)算梯度,不斷的更新梯度,使用梯度下降進(jìn)行優(yōu)化。

          for epoch in range(num_epochs):
              # 預(yù)測
              y_pred= model(x)
              # 計(jì)算loss
              loss = Loss(y_pred, y)
              #清空上一步參數(shù)值
              optimizer.zero_grad()
              #反向傳播
              loss.backward()
              #更新參數(shù)
              optimizer.step()
              if epoch % 200 == 0:
                  print("[{}/{}] loss:{:.4f}".format(epoch+1, num_epochs, loss))

          plt.scatter(x.data.numpy(), y.data.numpy())
          plt.plot(x.data.numpy(), y_pred.data.numpy(), 'r-',lw=5)
          plt.text(0.50,'Loss=%.4f' % loss.data.item(), fontdict={'size'20'color':  'red'})
          plt.show()
          ####結(jié)果如下####
          [1/1000] loss:4.2052
          [201/1000] loss:0.2690
          [401/1000] loss:0.0925
          [601/1000] loss:0.0810
          [801/1000] loss:0.0802
          [w, b] = model.parameters()
          print(w,b)
          # Parameter containing:
          tensor([[2.0036]], requires_grad=True) Parameter containing:
          tensor([0.7006], requires_grad=True)

          這里的b=0.7,等于0.2 + torch.rand(x.size()),經(jīng)過大量的訓(xùn)練torch.rand()一般約等于0.5。

          往期精彩回顧




          本站qq群851320808,加入微信群請(qǐng)掃碼:
          瀏覽 28
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  人妻无码高清 | 国产黄视频在线免费看 | 国产精品嫩草影院欧美成人精品a | 日本无码一区二区三三 | 大香操逼网 |