從零開始深度學(xué)習(xí)Pytorch筆記(7)—— 使用Pytorch實現(xiàn)線性回歸


前文傳送門:
從零開始深度學(xué)習(xí)Pytorch筆記(1)——安裝Pytorch
從零開始深度學(xué)習(xí)Pytorch筆記(2)——張量的創(chuàng)建(上)
從零開始深度學(xué)習(xí)Pytorch筆記(3)——張量的創(chuàng)建(下)
從零開始深度學(xué)習(xí)Pytorch筆記(4)——張量的拼接與切分
從零開始深度學(xué)習(xí)Pytorch筆記(5)——張量的索引與變換
從零開始深度學(xué)習(xí)Pytorch筆記(6)——張量的數(shù)學(xué)運算
在該系列的上一篇,我們介紹了Pytorch中的張量的數(shù)學(xué)運算,本文教會大家使用Pytorch搭建一個線性回歸模型。
說到線性回歸,從某種程度上可以算是最簡單的機器學(xué)習(xí)模型了。具體的理論推導(dǎo)我這里就不多說了,網(wǎng)上隨手一搜就有。
我們著重講講使用Pytorch搭建模型的過程。
首先貼出可實現(xiàn)的代碼:
import?torch
import?matplotlib.pyplot?as?plt
torch.manual_seed(10)#隨機數(shù)種子
lr?=?0.1?#學(xué)習(xí)率
#創(chuàng)建訓(xùn)練數(shù)據(jù)
x?=?torch.rand(20,1)*10?#shape(20,1)
y?=?2*x?+?(5?+?torch.randn(20,1))?#shape(20,1)
#構(gòu)建線性回歸參數(shù)
w?=?torch.randn((1),requires_grad=True)#隨機初始化w,要用到自動梯度求導(dǎo)
b?=?torch.zeros((1),requires_grad=True)#使用0初始化b,要用到自動梯度求導(dǎo)
for?iteration?in?range(1000):
????#前向傳播
????wx?=?torch.mul(w,x) # w*x
????y_pred?=?torch.add(wx,b) # y = w*x + b
????#計算?MSE?loss
????loss?=?(0.5*(y-y_pred)**2).mean()
????#反向傳播
????loss.backward()
????#更新參數(shù)
????b.data.sub_(lr*b.grad)?#?b?=?b?-?lr*b.grad
????w.data.sub_(lr*w.grad)?#?w?=?w?-?lr*w.grad
????#繪圖
????if?iteration?%?20?==?0:
????????plt.scatter(x.data.numpy(),y.data.numpy())
????????plt.plot(x.data.numpy(),y_pred.data.numpy(),'r-',lw=5)
????????plt.text(2,20,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'red'})
????????plt.xlim(1.5,10)
????????plt.ylim(8,28)
????????plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))
????????plt.pause(0.5)
????????if?loss.data.numpy()?1:#停止條件
????????????break
我們來分步驟講講上面的代碼具體的內(nèi)容。
首先導(dǎo)入相關(guān)的庫,設(shè)定學(xué)習(xí)率和隨機數(shù)種子,然后創(chuàng)建隨機數(shù)作為使用的數(shù)據(jù)。
初始化參數(shù) w、b,由于之后需要在模型訓(xùn)練中不斷調(diào)整 w、b 的參數(shù)值,并且會用到相關(guān)求導(dǎo),所以設(shè)置 requires_grad=True,代表需要用到該張量的求導(dǎo)。
之后寫了一個循環(huán),每次循環(huán)先進行前向傳播,計算 y 的預(yù)測值,計算 loss 損失值,然后反向傳播損失,去更新參數(shù) w、b。
之后是一個繪圖操作,繪制數(shù)據(jù)的散點圖和在訓(xùn)練過程中的線性回歸直線。
運行代碼后,我們可以看到如下的幾個訓(xùn)練過程中的可視化圖,當loss損失值小于1時,停止可視化。






歡迎關(guān)注公眾號學(xué)習(xí)之后的深度學(xué)習(xí)連載部分~
喜歡記得點在看哦,證明你來看過~