從零實(shí)現(xiàn)深度學(xué)習(xí)框架(十三)動(dòng)手實(shí)現(xiàn)邏輯回歸

橫屏觀看,效果更佳!更多文章請(qǐng)關(guān)注公眾號(hào)!
引言
本著“凡我不能創(chuàng)造的,我就不能理解”的思想,本系列文章會(huì)基于純Python以及NumPy從零創(chuàng)建自己的深度學(xué)習(xí)框架,該框架類似PyTorch能實(shí)現(xiàn)自動(dòng)求導(dǎo)。
要深入理解深度學(xué)習(xí),從零開(kāi)始創(chuàng)建的經(jīng)驗(yàn)非常重要,從自己可以理解的角度出發(fā),盡量不適用外部完備的框架前提下,實(shí)現(xiàn)我們想要的模型。本系列文章的宗旨就是通過(guò)這樣的過(guò)程,讓大家切實(shí)掌握深度學(xué)習(xí)底層實(shí)現(xiàn),而不是僅做一個(gè)調(diào)包俠。
上篇文章對(duì)邏輯回歸進(jìn)行了簡(jiǎn)單的介紹,本文我們就來(lái)從零實(shí)現(xiàn)邏輯回歸。
實(shí)現(xiàn)Sigmoid函數(shù)
首先實(shí)現(xiàn)邏輯回歸中的邏輯函數(shù)。
def?sigmoid(x:?Tensor)?->?Tensor:
????return?1?/?(1?+?(-x).exp())
正如使用PyTorch一樣,我們也只需要實(shí)現(xiàn)前向傳播。
實(shí)現(xiàn)交叉熵?fù)p失函數(shù)
def?binary_cross_entropy(input:?Tensor,?target:?Tensor,?reduction:?str?=?"mean")?->?Tensor:
????errors?=?-(target?*?input.log()?+?(1?-?target)?*?(1?-?input).log())
????N?=?len(target)
????if?reduction?==?"mean":
????????loss?=?errors.sum()?/?N
????elif?reduction?==?"sum":
????????loss?=?errors.sum()
????else:
????????loss?=?errors
????return?loss
這里的input是經(jīng)過(guò)Sigmoid函數(shù)的輸出,target是真實(shí)輸出。我們先定義這樣一個(gè)方法。然后再實(shí)現(xiàn)損失類:
class?BCELoss(_Loss):
????def?__init__(self,?reduction:?str?=?"mean")?->?None:
????????super().__init__(reduction)
????def?forward(self,?input:?Tensor,?target:?Tensor)?->?Tensor:
????????return?F.binary_cross_entropy(input,?target,?self.reduction)
實(shí)現(xiàn)邏輯回歸
有了激活函數(shù)、損失函數(shù)。我們就可以來(lái)實(shí)現(xiàn)邏輯回歸了。
class?LogisticRegression(Module):
????def?__init__(self,?input_dim,?output_dim):
????????self.linear?=?Linear(input_dim,?output_dim)
????def?forward(self,?x:?Tensor)?->?Tensor:
????????return?F.sigmoid(self.linear(x))
代碼非常簡(jiǎn)單,首先是一個(gè)線性回歸,然后經(jīng)過(guò)sigmoid即可。有了自動(dòng)求導(dǎo)工具,我們不必操心反向傳播。
下面通過(guò)一個(gè)實(shí)例來(lái)應(yīng)用我們實(shí)現(xiàn)的邏輯回歸。
學(xué)院錄取預(yù)測(cè)
數(shù)據(jù)集來(lái)自吳恩達(dá)老師的課程,橫坐標(biāo)表示第一次考試的成績(jī),縱坐標(biāo)表示第二次考試的成績(jī)。藍(lán)點(diǎn)表示被學(xué)院錄取,橙點(diǎn)表示沒(méi)有被錄取。如下圖所示:

我們要從中間畫(huà)一根線,代表決策邊界。在下方的樣本點(diǎn)判斷為沒(méi)有被錄取,在上方的樣本點(diǎn)判斷為被學(xué)院錄取。
首先定義加載數(shù)據(jù)集的函數(shù):
def?load_data(path,?draw_picture=False):
????data?=?pd.read_csv(path)
????X?=?data.iloc[:,?:-1]
????y?=?data.iloc[:,?-1]
????y?=?y[:,?np.newaxis]
????return?Tensor(X),?Tensor(y)
然后編寫(xiě)訓(xùn)練過(guò)程:
import?matplotlib.pyplot?as?plt
import?numpy?as?np
import?pandas?as?pd
from?tqdm?import?tqdm
import?metagrad.functions?as?F
from?metagrad.loss?import?BCELoss
from?metagrad.module?import?Module,?Linear
from?metagrad.optim?import?SGD
from?metagrad.tensor?import?Tensor
??
??
if?__name__?==?'__main__':
????X,?y?=?load_data("./data/marks.txt",?draw_picture=True)
????epochs?=?200_000?#?迭代20萬(wàn)次
????model?=?LogisticRegression(2,?1)?#?輸入有2個(gè)維度,輸出通過(guò)的概率。
????optimizer?=?SGD(model.parameters(),?lr=1e-3)
????loss?=?BCELoss()
????losses?=?[]
????for?epoch?in?tqdm(range(int(epochs))):?#?顯示進(jìn)度條
????????optimizer.zero_grad()
????????outputs?=?model(X)
????????l?=?loss(outputs,?y)
????????optimizer.zero_grad()
????????l.backward()
????????optimizer.step()
????????if?(epoch?+?1)?%?10000?==?0:
????????????total?=?0
????????????correct?=?0
????????????total?+=?len(y)
????????????correct?+=?np.sum(outputs.numpy().round()?==?y.numpy())?#?計(jì)算準(zhǔn)確率
????????????accuracy?=?100?*?correct?/?total
????????????losses.append(l.item())
????????????print(f"Train?-??Loss:?{l.item()}.?Accuracy:?{accuracy}\n")
??5%|▌?????????|?10024/200000?[00:04<01:23,?2271.00it/s]Train?-??Loss:?0.5822722315788269.?Accuracy:?60.60606060606061
?10%|▉?????????|?19855/200000?[00:08<01:19,?2253.23it/s]Train?-??Loss:?0.5447849631309509.?Accuracy:?65.65656565656566
?15%|█▍????????|?29889/200000?[00:13<01:15,?2262.13it/s]Train?-??Loss:?0.5130925178527832.?Accuracy:?67.67676767676768
?20%|█▉????????|?39947/200000?[00:17<01:10,?2271.09it/s]Train?-??Loss:?0.48615676164627075.?Accuracy:?75.75757575757575
?25%|██▌???????|?50040/200000?[00:22<01:05,?2302.86it/s]Train?-??Loss:?0.46311360597610474.?Accuracy:?78.78787878787878
?30%|██▉???????|?59910/200000?[00:26<01:00,?2300.01it/s]Train?-??Loss:?0.44325879216194153.?Accuracy:?80.8080808080808
?35%|███▍??????|?69928/200000?[00:30<01:00,?2150.51it/s]Train?-??Loss:?0.426025390625.?Accuracy:?84.84848484848484
?40%|███▉??????|?79903/200000?[00:35<00:53,?2227.82it/s]Train?-??Loss:?0.41095882654190063.?Accuracy:?85.85858585858585
?45%|████▍?????|?89928/200000?[00:39<00:48,?2287.55it/s]Train?-??Loss:?0.3976950943470001.?Accuracy:?88.88888888888889
?50%|████▉?????|?99961/200000?[00:44<00:44,?2269.85it/s]Train?-??Loss:?0.38594162464141846.?Accuracy:?89.8989898989899
?55%|█████▍????|?109821/200000?[00:48<00:40,?2223.05it/s]Train?-??Loss:?0.3754624128341675.?Accuracy:?90.9090909090909
?60%|█████▉????|?119952/200000?[00:53<00:34,?2305.71it/s]Train?-??Loss:?0.3660658299922943.?Accuracy:?91.91919191919192
?65%|██████▍???|?129877/200000?[00:57<00:30,?2295.31it/s]Train?-??Loss:?0.35759538412094116.?Accuracy:?90.9090909090909
?70%|███████???|?140024/200000?[01:01<00:26,?2275.92it/s]Train?-??Loss:?0.3499223589897156.?Accuracy:?90.9090909090909
?75%|███████▍??|?149947/200000?[01:06<00:21,?2294.55it/s]Train?-??Loss:?0.34294024109840393.?Accuracy:?91.91919191919192
?80%|███████▉??|?159852/200000?[01:10<00:17,?2280.32it/s]Train?-??Loss:?0.3365602195262909.?Accuracy:?91.91919191919192
?85%|████████▌?|?170017/200000?[01:14<00:12,?2310.93it/s]Train?-??Loss:?0.3307078182697296.?Accuracy:?91.91919191919192
?90%|████████▉?|?179960/200000?[01:19<00:08,?2313.34it/s]Train?-??Loss:?0.32532015442848206.?Accuracy:?91.91919191919192
?95%|█████████▍|?189875/200000?[01:23<00:04,?2276.93it/s]Train?-??Loss:?0.320343941450119.?Accuracy:?91.91919191919192
100%|██████████|?200000/200000?[01:27<00:00,?2273.90it/s]
Train?-??Loss:?0.31573355197906494.?Accuracy:?91.91919191919192
由于還沒(méi)有利用GPU,因此耗時(shí)了2分鐘左右,后續(xù)我們的求導(dǎo)工具也可以支持GPU加速。
最后得到的準(zhǔn)確率有91.9%,還行。
最后我們要畫(huà)出決策邊界,由于我們的數(shù)據(jù)只有兩個(gè)特征,回顧一下Sigmoid的函數(shù)圖像。當(dāng)的取值大于時(shí),就判斷為正例,否則判斷為負(fù)例。

因此,我們可以令,得到?jīng)Q策邊界的函數(shù):
把看成自變量,看因變量,就可以繪制出如下的圖像:

完整代碼
完整代碼筆者上傳到了程序員最大交友網(wǎng)站上去了,地址: [?? ?https://github.com/nlp-greyfoss/metagrad]
最后一句:BUG,走你!


Markdown筆記神器Typora配置Gitee圖床
不會(huì)真有人覺(jué)得聊天機(jī)器人難吧(一)
Spring Cloud學(xué)習(xí)筆記(一)
沒(méi)有人比我更懂Spring Boot(一)
入門(mén)人工智能必備的線性代數(shù)基礎(chǔ)
1.看到這里了就點(diǎn)個(gè)在看支持下吧,你的在看是我創(chuàng)作的動(dòng)力。
2.關(guān)注公眾號(hào),每天為您分享原創(chuàng)或精選文章!
3.特殊階段,帶好口罩,做好個(gè)人防護(hù)。
