<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          從零實(shí)現(xiàn)深度學(xué)習(xí)框架(十三)動(dòng)手實(shí)現(xiàn)邏輯回歸

          共 2852字,需瀏覽 6分鐘

           ·

          2022-01-21 00:06

          橫屏觀看,效果更佳!更多文章請(qǐng)關(guān)注公眾號(hào)!

          更多精彩推薦,請(qǐng)關(guān)注我們

          引言

          本著“凡我不能創(chuàng)造的,我就不能理解”的思想,本系列文章會(huì)基于純Python以及NumPy從零創(chuàng)建自己的深度學(xué)習(xí)框架,該框架類似PyTorch能實(shí)現(xiàn)自動(dòng)求導(dǎo)。

          要深入理解深度學(xué)習(xí),從零開(kāi)始創(chuàng)建的經(jīng)驗(yàn)非常重要,從自己可以理解的角度出發(fā),盡量不適用外部完備的框架前提下,實(shí)現(xiàn)我們想要的模型。本系列文章的宗旨就是通過(guò)這樣的過(guò)程,讓大家切實(shí)掌握深度學(xué)習(xí)底層實(shí)現(xiàn),而不是僅做一個(gè)調(diào)包俠。

          上篇文章對(duì)邏輯回歸進(jìn)行了簡(jiǎn)單的介紹,本文我們就來(lái)從零實(shí)現(xiàn)邏輯回歸。

          實(shí)現(xiàn)Sigmoid函數(shù)

          首先實(shí)現(xiàn)邏輯回歸中的邏輯函數(shù)。

          def?sigmoid(x:?Tensor)?->?Tensor:
          ????return?1?/?(1?+?(-x).exp())

          正如使用PyTorch一樣,我們也只需要實(shí)現(xiàn)前向傳播。

          實(shí)現(xiàn)交叉熵?fù)p失函數(shù)

          def?binary_cross_entropy(input:?Tensor,?target:?Tensor,?reduction:?str?=?"mean")?->?Tensor:
          ????errors?=?-(target?*?input.log()?+?(1?-?target)?*?(1?-?input).log())

          ????N?=?len(target)

          ????if?reduction?==?"mean":
          ????????loss?=?errors.sum()?/?N
          ????elif?reduction?==?"sum":
          ????????loss?=?errors.sum()
          ????else:
          ????????loss?=?errors
          ????return?loss

          這里的input是經(jīng)過(guò)Sigmoid函數(shù)的輸出,target是真實(shí)輸出。我們先定義這樣一個(gè)方法。然后再實(shí)現(xiàn)損失類:

          class?BCELoss(_Loss):
          ????def?__init__(self,?reduction:?str?=?"mean")?->?None:
          ????????super().__init__(reduction)

          ????def?forward(self,?input:?Tensor,?target:?Tensor)?->?Tensor:
          ????????return?F.binary_cross_entropy(input,?target,?self.reduction)

          實(shí)現(xiàn)邏輯回歸

          有了激活函數(shù)、損失函數(shù)。我們就可以來(lái)實(shí)現(xiàn)邏輯回歸了。

          class?LogisticRegression(Module):
          ????def?__init__(self,?input_dim,?output_dim):
          ????????self.linear?=?Linear(input_dim,?output_dim)

          ????def?forward(self,?x:?Tensor)?->?Tensor:
          ????????return?F.sigmoid(self.linear(x))

          代碼非常簡(jiǎn)單,首先是一個(gè)線性回歸,然后經(jīng)過(guò)sigmoid即可。有了自動(dòng)求導(dǎo)工具,我們不必操心反向傳播。

          下面通過(guò)一個(gè)實(shí)例來(lái)應(yīng)用我們實(shí)現(xiàn)的邏輯回歸。

          學(xué)院錄取預(yù)測(cè)

          數(shù)據(jù)集來(lái)自吳恩達(dá)老師的課程,橫坐標(biāo)表示第一次考試的成績(jī),縱坐標(biāo)表示第二次考試的成績(jī)。藍(lán)點(diǎn)表示被學(xué)院錄取,橙點(diǎn)表示沒(méi)有被錄取。如下圖所示:

          數(shù)據(jù)分布

          我們要從中間畫(huà)一根線,代表決策邊界。在下方的樣本點(diǎn)判斷為沒(méi)有被錄取,在上方的樣本點(diǎn)判斷為被學(xué)院錄取。

          首先定義加載數(shù)據(jù)集的函數(shù):

          def?load_data(path,?draw_picture=False):
          ????data?=?pd.read_csv(path)

          ????X?=?data.iloc[:,?:-1]
          ????y?=?data.iloc[:,?-1]

          ????y?=?y[:,?np.newaxis]

          ????return?Tensor(X),?Tensor(y)

          然后編寫(xiě)訓(xùn)練過(guò)程:

          import?matplotlib.pyplot?as?plt
          import?numpy?as?np
          import?pandas?as?pd
          from?tqdm?import?tqdm

          import?metagrad.functions?as?F
          from?metagrad.loss?import?BCELoss
          from?metagrad.module?import?Module,?Linear
          from?metagrad.optim?import?SGD
          from?metagrad.tensor?import?Tensor
          ??
          ??

          if?__name__?==?'__main__':

          ????X,?y?=?load_data("./data/marks.txt",?draw_picture=True)

          ????epochs?=?200_000?#?迭代20萬(wàn)次

          ????model?=?LogisticRegression(2,?1)?#?輸入有2個(gè)維度,輸出通過(guò)的概率。

          ????optimizer?=?SGD(model.parameters(),?lr=1e-3)

          ????loss?=?BCELoss()

          ????losses?=?[]

          ????for?epoch?in?tqdm(range(int(epochs))):?#?顯示進(jìn)度條

          ????????optimizer.zero_grad()
          ????????outputs?=?model(X)
          ????????l?=?loss(outputs,?y)
          ????????optimizer.zero_grad()
          ????????l.backward()
          ????????optimizer.step()

          ????????if?(epoch?+?1)?%?10000?==?0:
          ????????????total?=?0
          ????????????correct?=?0
          ????????????total?+=?len(y)
          ????????????correct?+=?np.sum(outputs.numpy().round()?==?y.numpy())?#?計(jì)算準(zhǔn)確率
          ????????????accuracy?=?100?*?correct?/?total
          ????????????losses.append(l.item())

          ????????????print(f"Train?-??Loss:?{l.item()}.?Accuracy:?{accuracy}\n")
          ??5%|▌?????????|?10024/200000?[00:04<01:23,?2271.00it/s]Train?-??Loss:?0.5822722315788269.?Accuracy:?60.60606060606061

          ?10%|▉?????????|?19855/200000?[00:08<01:19,?2253.23it/s]Train?-??Loss:?0.5447849631309509.?Accuracy:?65.65656565656566

          ?15%|█▍????????|?29889/200000?[00:13<01:15,?2262.13it/s]Train?-??Loss:?0.5130925178527832.?Accuracy:?67.67676767676768

          ?20%|█▉????????|?39947/200000?[00:17<01:10,?2271.09it/s]Train?-??Loss:?0.48615676164627075.?Accuracy:?75.75757575757575

          ?25%|██▌???????|?50040/200000?[00:22<01:05,?2302.86it/s]Train?-??Loss:?0.46311360597610474.?Accuracy:?78.78787878787878

          ?30%|██▉???????|?59910/200000?[00:26<01:00,?2300.01it/s]Train?-??Loss:?0.44325879216194153.?Accuracy:?80.8080808080808

          ?35%|███▍??????|?69928/200000?[00:30<01:00,?2150.51it/s]Train?-??Loss:?0.426025390625.?Accuracy:?84.84848484848484

          ?40%|███▉??????|?79903/200000?[00:35<00:53,?2227.82it/s]Train?-??Loss:?0.41095882654190063.?Accuracy:?85.85858585858585

          ?45%|████▍?????|?89928/200000?[00:39<00:48,?2287.55it/s]Train?-??Loss:?0.3976950943470001.?Accuracy:?88.88888888888889

          ?50%|████▉?????|?99961/200000?[00:44<00:44,?2269.85it/s]Train?-??Loss:?0.38594162464141846.?Accuracy:?89.8989898989899

          ?55%|█████▍????|?109821/200000?[00:48<00:40,?2223.05it/s]Train?-??Loss:?0.3754624128341675.?Accuracy:?90.9090909090909

          ?60%|█████▉????|?119952/200000?[00:53<00:34,?2305.71it/s]Train?-??Loss:?0.3660658299922943.?Accuracy:?91.91919191919192

          ?65%|██████▍???|?129877/200000?[00:57<00:30,?2295.31it/s]Train?-??Loss:?0.35759538412094116.?Accuracy:?90.9090909090909

          ?70%|███████???|?140024/200000?[01:01<00:26,?2275.92it/s]Train?-??Loss:?0.3499223589897156.?Accuracy:?90.9090909090909

          ?75%|███████▍??|?149947/200000?[01:06<00:21,?2294.55it/s]Train?-??Loss:?0.34294024109840393.?Accuracy:?91.91919191919192

          ?80%|███████▉??|?159852/200000?[01:10<00:17,?2280.32it/s]Train?-??Loss:?0.3365602195262909.?Accuracy:?91.91919191919192

          ?85%|████████▌?|?170017/200000?[01:14<00:12,?2310.93it/s]Train?-??Loss:?0.3307078182697296.?Accuracy:?91.91919191919192

          ?90%|████████▉?|?179960/200000?[01:19<00:08,?2313.34it/s]Train?-??Loss:?0.32532015442848206.?Accuracy:?91.91919191919192

          ?95%|█████████▍|?189875/200000?[01:23<00:04,?2276.93it/s]Train?-??Loss:?0.320343941450119.?Accuracy:?91.91919191919192

          100%|██████████|?200000/200000?[01:27<00:00,?2273.90it/s]
          Train?-??Loss:?0.31573355197906494.?Accuracy:?91.91919191919192

          由于還沒(méi)有利用GPU,因此耗時(shí)了2分鐘左右,后續(xù)我們的求導(dǎo)工具也可以支持GPU加速。

          最后得到的準(zhǔn)確率有91.9%,還行。

          最后我們要畫(huà)出決策邊界,由于我們的數(shù)據(jù)只有兩個(gè)特征,回顧一下Sigmoid的函數(shù)圖像。當(dāng)的取值大于時(shí),就判斷為正例,否則判斷為負(fù)例。

          Sigmoid函數(shù)圖像

          因此,我們可以令,得到?jīng)Q策邊界的函數(shù):

          看成自變量,看因變量,就可以繪制出如下的圖像:

          決策邊界

          完整代碼

          完整代碼筆者上傳到了程序員最大交友網(wǎng)站上去了,地址: [?? ?https://github.com/nlp-greyfoss/metagrad]

          最后一句:BUG,走你!

          Markdown筆記神器Typora配置Gitee圖床
          不會(huì)真有人覺(jué)得聊天機(jī)器人難吧(一)
          Spring Cloud學(xué)習(xí)筆記(一)
          沒(méi)有人比我更懂Spring Boot(一)
          入門(mén)人工智能必備的線性代數(shù)基礎(chǔ)

          1.看到這里了就點(diǎn)個(gè)在看支持下吧,你的在看是我創(chuàng)作的動(dòng)力。
          2.關(guān)注公眾號(hào),每天為您分享原創(chuàng)或精選文章!
          3.特殊階段,帶好口罩,做好個(gè)人防護(hù)。

          瀏覽 46
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  亚洲熟女一区二区 | 成人亚洲欧美 | 人人草人人草人人草 | 影音先锋 自拍 | 天堂a√8蜜桃 |