
Use pytorch-lightning as Elegantly as Keras


          2021-01-27 19:40

Hi, I'm Yun-ge (云哥). In this article I introduce a tool that helps you do deep learning research elegantly: pytorch-lightning.

Reply with the keyword 源碼 (source code) in the official account backend to get the source code for this article!

pytorch-lightning is a high-level model interface built on top of pytorch: pytorch-lightning is to pytorch what keras is to tensorflow.

For a complete introduction to pytorch-lightning, see my other article:

Do Deep Learning Research Beautifully with pytorch-lightning (使用pytorch-lightning漂亮地進行深度學習研究)

With roughly 80 lines of code I wrapped pytorch-lightning a little further, so that users unfamiliar with it can work in a Keras-like style and easily get the following out of the box:

• Model training (CPU, GPU, multi-GPU)

• Model evaluation (custom metrics)

• Saving the best model weights (ModelCheckpoint)

• Custom learning-rate schedules (lr_scheduler)

• Plotting nice loss and metric curves

It can even be a bit simpler and more pleasant to use than Keras.

The wrapper class, LightModel, has been added to my open-source repository torchkeras and can be installed with pip; a rough sketch of the idea behind it follows the install command below.

pip install -U torchkeras
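For reference, here is a rough, hypothetical sketch of the idea behind such a wrapper (this is not the actual ~80-line torchkeras.LightModel source): a pl.LightningModule that holds a plain nn.Module and routes both training and validation through a single shared_step that subclasses override.

import pytorch_lightning as pl
from torch import nn

class KerasStyleModel(pl.LightningModule):
    # hypothetical sketch, not the real torchkeras.LightModel
    def __init__(self, net):
        super().__init__()
        self.net = net  # the plain torch network being wrapped

    def forward(self, x):
        return self.net(x)

    def shared_step(self, batch):
        # subclasses override this and return a dict containing a "loss" key
        raise NotImplementedError

    def training_step(self, batch, batch_idx):
        dic = self.shared_step(batch)
        self.log_dict(dic, prog_bar=True)
        return dic["loss"]

    def validation_step(self, batch, batch_idx):
        dic = self.shared_step(batch)
        self.log_dict({"val_" + k: v for k, v in dic.items()}, prog_bar=True)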

Below is a complete example of using a DNN for binary classification through LightModel.

At the end of this example, I will show you a little trick called "the golden cicada sheds its shell" (金蟬脫殼). Don't go away.

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,TensorDataset
import datetime

# attention: these two imports
import pytorch_lightning as pl
import torchkeras

1. Prepare the Data

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

# number of samples
n_positive,n_negative = 2000,2000

# positive samples
r_p = 5.0 + torch.normal(0.0,1.0,size = [n_positive,1])
theta_p = 2*np.pi*torch.rand([n_positive,1])
Xp = torch.cat([r_p*torch.cos(theta_p),r_p*torch.sin(theta_p)],axis = 1)
Yp = torch.ones_like(r_p)

# negative samples
r_n = 8.0 + torch.normal(0.0,1.0,size = [n_negative,1])
theta_n = 2*np.pi*torch.rand([n_negative,1])
Xn = torch.cat([r_n*torch.cos(theta_n),r_n*torch.sin(theta_n)],axis = 1)
Yn = torch.zeros_like(r_n)

# concatenate positive and negative samples
X = torch.cat([Xp,Xn],axis = 0)
Y = torch.cat([Yp,Yn],axis = 0)


# visualize the samples
plt.figure(figsize = (6,6))
plt.scatter(Xp[:,0],Xp[:,1],c = "r")
plt.scatter(Xn[:,0],Xn[:,1],c = "g")
plt.legend(["positive","negative"]);


# split samples into train and valid data
ds = TensorDataset(X,Y)
ds_train,ds_valid = torch.utils.data.random_split(ds,[int(len(ds)*0.7),len(ds)-int(len(ds)*0.7)])
dl_train = DataLoader(ds_train,batch_size = 100,shuffle=True,num_workers=4)
dl_valid = DataLoader(ds_valid,batch_size = 100,num_workers=4)

2. Define the Model

# define the network as a plain torch module
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2,6)
        self.fc2 = nn.Linear(6,12)
        self.fc3 = nn.Linear(12,1)

    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        y = nn.Sigmoid()(self.fc3(x))
        return y

class Model(torchkeras.LightModel):
    def shared_step(self,batch):
        x, y = batch
        prediction = self(x)
        loss = nn.BCELoss()(prediction,y)
        preds = torch.where(prediction>0.5,torch.ones_like(prediction),torch.zeros_like(prediction))
        acc = pl.metrics.functional.accuracy(preds, y)
        # attention: the returned dict must contain a "loss" key
        dic = {"loss":loss,"acc":acc}
        return dic

    # optimizer, and an optional lr_scheduler
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.0001)
        return {"optimizer":optimizer,"lr_scheduler":lr_scheduler}

Note that below we wrap the network net inside the shell of a Model.

pl.seed_everything(123)

# wrap the network into a Model
net = Net()
model = Model(net)

torchkeras.summary(model,input_shape = (2,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                    [-1, 6]              18
            Linear-2                   [-1, 12]              84
            Linear-3                    [-1, 1]              13
================================================================
Total params: 115
Trainable params: 115
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.000008
Forward/backward pass size (MB): 0.000145
Params size (MB): 0.000439
Estimated Total Size (MB): 0.000591
----------------------------------------------------------------

3. Train the Model


ckpt_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min'
)

# gpus=0 trains on the CPU; gpus=1 uses 1 GPU; gpus=2 uses 2 GPUs; gpus=-1 uses all GPUs;
# gpus=[0,1] uses GPUs 0 and 1; gpus="0,1,2,3" uses GPUs 0, 1, 2 and 3;
# tpu_cores=1 trains on 1 TPU core
trainer = pl.Trainer(max_epochs=10,gpus=0,callbacks = [ckpt_callback])

# resume training from a checkpoint
#trainer = pl.Trainer(resume_from_checkpoint='./lightning_logs/version_31/checkpoints/epoch=02-val_loss=0.05.ckpt')

trainer.fit(model,dl_train,dl_valid)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name | Type | Params
------------------------------
0 | net  | Net  | 115
------------------------------
115       Trainable params
0         Non-trainable params
115       Total params

================================================================================2021-01-24 20:47:39
epoch =  0
{'val_loss': 0.6492899060249329, 'val_acc': 0.6033333539962769}
{'acc': 0.5374999642372131, 'loss': 0.6766871809959412}

================================================================================2021-01-24 20:47:40
epoch =  1
{'val_loss': 0.5390750765800476, 'val_acc': 0.763333261013031}
{'acc': 0.676428496837616, 'loss': 0.5993633270263672}

================================================================================2021-01-24 20:47:41
epoch =  2
{'val_loss': 0.3617284595966339, 'val_acc': 0.8608333468437195}
{'acc': 0.8050000071525574, 'loss': 0.4533742070198059}

================================================================================2021-01-24 20:47:42
epoch =  3
{'val_loss': 0.21798092126846313, 'val_acc': 0.9158334732055664}
{'acc': 0.8910714387893677, 'loss': 0.28334707021713257}

================================================================================2021-01-24 20:47:43
epoch =  4
{'val_loss': 0.18157465755939484, 'val_acc': 0.9208333492279053}
{'acc': 0.926428496837616, 'loss': 0.20261192321777344}

================================================================================2021-01-24 20:47:44
epoch =  5
{'val_loss': 0.17406059801578522, 'val_acc': 0.9300000071525574}
{'acc': 0.9203571677207947, 'loss': 0.1980973333120346}

================================================================================2021-01-24 20:47:45
epoch =  6
{'val_loss': 0.16323940455913544, 'val_acc': 0.935833215713501}
{'acc': 0.9242857694625854, 'loss': 0.1862144023180008}

================================================================================2021-01-24 20:47:46
epoch =  7
{'val_loss': 0.16635416448116302, 'val_acc': 0.9300000071525574}
{'acc': 0.925000011920929, 'loss': 0.18595384061336517}

================================================================================2021-01-24 20:47:47
epoch =  8
{'val_loss': 0.1665605753660202, 'val_acc': 0.9258332848548889}
{'acc': 0.9267856478691101, 'loss': 0.18308643996715546}

================================================================================2021-01-24 20:47:48
epoch =  9
{'val_loss': 0.1757962554693222, 'val_acc': 0.9300000071525574}
{'acc': 0.9246429204940796, 'loss': 0.18282662332057953}

# visualize the results
fig, (ax1,ax2) = plt.subplots(nrows=1,ncols=2,figsize = (12,5))
ax1.scatter(Xp[:,0],Xp[:,1], c="r")
ax1.scatter(Xn[:,0],Xn[:,1],c = "g")
ax1.legend(["positive","negative"]);
ax1.set_title("y_true")

Xp_pred = X[torch.squeeze(model.forward(X)>=0.5)]
Xn_pred = X[torch.squeeze(model.forward(X)<0.5)]

ax2.scatter(Xp_pred[:,0],Xp_pred[:,1],c = "r")
ax2.scatter(Xn_pred[:,0],Xn_pred[:,1],c = "g")
ax2.legend(["positive","negative"]);
ax2.set_title("y_pred")

4. Evaluate the Model

import pandas as pd

history = model.history
dfhistory = pd.DataFrame(history)
dfhistory

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import matplotlib.pyplot as plt

def plot_metric(dfhistory, metric):
    train_metrics = dfhistory[metric]
    val_metrics = dfhistory['val_'+metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo--')
    plt.plot(epochs, val_metrics, 'ro-')
    plt.title('Training and validation '+ metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_"+metric, 'val_'+metric])
    plt.show()

plot_metric(dfhistory,"loss")

plot_metric(dfhistory,"acc")


results = trainer.test(model, test_dataloaders=dl_valid, verbose = False)
print(results[0])

{'test_loss': 0.15939873456954956, 'test_acc': 0.9599999785423279}

5. Use the Model

def predict(model,dl):
    model.eval()
    result = torch.cat([model.forward(t[0].to(model.device)) for t in dl])
    return(result.data)

result = predict(model,dl_valid)

result
tensor([[9.8850e-01],
        [2.3642e-03],
        [1.2128e-04],
        ...,
        [9.9002e-01],
        [9.6689e-01],
        [1.5238e-02]])
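Since predict returns sigmoid probabilities, you can turn them into hard 0/1 class labels with the same 0.5 threshold used in shared_step, for example:

# threshold the probabilities: 1 = positive, 0 = negative
labels = (result > 0.5).int()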

6. Save the Model

By default the best model is saved under trainer.checkpoint_callback.best_model_path, from where it can be loaded directly.

print(trainer.checkpoint_callback.best_model_path)
print(trainer.checkpoint_callback.best_model_score)

model_clone = Model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
trainer_clone = pl.Trainer(max_epochs=10)
results = trainer_clone.test(model_clone, test_dataloaders=dl_valid, verbose = False)
print(results[0])

{'test_loss': 0.20505842566490173, 'test_acc': 0.9399999976158142}

Finally, the promised "golden cicada sheds its shell" trick.

After training inside the LightModel shell, you can discard the shell and directly save the weights of the underlying network net by hand.

best_net = model.net
torch.save(best_net.state_dict(),"best_net.pt")

# load the saved weights into a fresh Net
net_clone = Net()
net_clone.load_state_dict(torch.load("best_net.pt"))
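One caveat: model.net holds the weights as they were at the end of training. If you want the weights selected by ModelCheckpoint instead, a small sketch (assuming the reloaded model_clone above also exposes the wrapped network as .net, as the Trainer log suggests) is to shed the shell from model_clone:

# sketch: shed the shell from the model reloaded from the best checkpoint
best_net = model_clone.net
torch.save(best_net.state_dict(), "best_net.pt")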


data,label = next(iter(dl_valid))
with torch.no_grad():
    preds  = model(data)
    preds_clone = net_clone(data)

print("model prediction:\n",preds[0:10],"\n")
print("net_clone prediction:\n",preds_clone[0:10])

model prediction:
 tensor([[9.8850e-01],
        [2.3642e-03],
        [1.2128e-04],
        [1.0022e-04],
        [9.3577e-01],
        [4.9769e-02],
        [9.8537e-01],
        [9.9940e-01],
        [4.1117e-04],
        [9.4009e-01]])

net_clone prediction:
 tensor([[9.8850e-01],
        [2.3642e-03],
        [1.2128e-04],
        [1.0022e-04],
        [9.3577e-01],
        [4.9769e-02],
        [9.8537e-01],
        [9.9940e-01],
        [4.1117e-04],
        [9.4009e-01]])


That's all.


If you would like to discuss anything in this article further with the author, feel free to leave a message on the official account 算法美食屋. My time and energy are limited, so I will reply as I am able.


You can also reply with the keyword 加群 (join group) in the official account backend to join the reader discussion group.


Original content takes real effort. Please support Yun-ge with the triple combo: like, "Wow" (在看), and share. Thank you.
