<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          我用 PyTorch 復(fù)現(xiàn)了 LeNet-5 神經(jīng)網(wǎng)絡(luò)(CIFAR10 數(shù)據(jù)集篇)!

          共 8201字,需瀏覽 17分鐘

           ·

          2022-01-01 23:23

          ↑ 點(diǎn)擊藍(lán)字?關(guān)注極市平臺(tái)

          作者 | 紅色石頭?
          來源 | AI有道?
          編輯 | 極市平臺(tái)

          極市導(dǎo)讀

          ?

          大家好,我是紅色石頭!今天我們將使用 Pytorch 來繼續(xù)實(shí)現(xiàn) LeNet-5 模型,并用它來解決 CIFAR10 數(shù)據(jù)集的識(shí)別。?>>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿

          正文開始!

          LeNet-5 網(wǎng)絡(luò)本是用來識(shí)別 MNIST 數(shù)據(jù)集的,下面我們來將 LeNet-5 應(yīng)用到一個(gè)比較復(fù)雜的例子,識(shí)別 CIFAR-10 數(shù)據(jù)集。

          CIFAR-10 是由 Hinton 的學(xué)生 Alex Krizhevsky 和 Ilya Sutskever 整理的一個(gè)用于識(shí)別普適物體的小型數(shù)據(jù)集。一共包含 10 個(gè)類別的 RGB 彩色圖 片:飛機(jī)( airlane )、汽車( automobile )、鳥類( bird )、貓( cat )、鹿( deer )、狗( dog )、蛙類( frog )、馬( horse )、船( ship )和卡車( truck )。圖片的尺寸為 32×32 ,數(shù)據(jù)集中一共有 50000 張訓(xùn)練圄片和 10000 張測(cè)試圖片。

          CIFAR-10 的圖片樣例如圖所示。

          1 下載并加載數(shù)據(jù),并做出一定的預(yù)先處理

          pipline_train?=?transforms.Compose([
          ????#隨機(jī)旋轉(zhuǎn)圖片
          ????transforms.RandomHorizontalFlip(),
          ????#將圖片尺寸resize到32x32
          ????transforms.Resize((32,32)),
          ????#將圖片轉(zhuǎn)化為Tensor格式
          ????transforms.ToTensor(),
          ????#正則化(當(dāng)模型出現(xiàn)過擬合的情況時(shí),用來降低模型的復(fù)雜度)
          ????transforms.Normalize((0.5,?0.5,?0.5),?(0.5,?0.5,?0.5))
          ])
          pipline_test?=?transforms.Compose([
          ????#將圖片尺寸resize到32x32
          ????transforms.Resize((32,32)),
          ????transforms.ToTensor(),
          ????transforms.Normalize((0.5,?0.5,?0.5),?(0.5,?0.5,?0.5))
          ])
          #下載數(shù)據(jù)集
          train_set?=?datasets.CIFAR10(root="./data/CIFAR10",?train=True,?download=True,?transform=pipline_train)
          test_set?=?datasets.CIFAR10(root="./data/CIFAR10",?train=False,?download=True,?transform=pipline_test)
          #加載數(shù)據(jù)集
          trainloader?=?torch.utils.data.DataLoader(train_set,?batch_size=64,?shuffle=True)
          testloader?=?torch.utils.data.DataLoader(test_set,?batch_size=32,?shuffle=False)
          #?類別信息也是需要我們給定的
          classes?=?('plane',?'car',?'bird',?'cat','deer',?'dog',?'frog',?'horse',?'ship',?'truck')

          2 搭建 LeNet-5 神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),并定義前向傳播的過程

          LeNet-5 網(wǎng)絡(luò)上文已經(jīng)搭建過了,由于 CIFAR10 數(shù)據(jù)集圖像是 RGB 三通道的,因此 LeNet-5 網(wǎng)絡(luò) C1 層卷積選擇的濾波器需要 3 通道,網(wǎng)絡(luò)其它結(jié)構(gòu)跟上文都是一樣的。

          class?LeNetRGB(nn.Module):
          ????def?__init__(self):
          ????????super(LeNetRGB,?self).__init__()
          ????????self.conv1?=?nn.Conv2d(3,?6,?5)???#?3表示輸入是3通道
          ????????self.relu?=?nn.ReLU()
          ????????self.maxpool1?=?nn.MaxPool2d(2,?2)
          ????????self.conv2?=?nn.Conv2d(6,?16,?5)
          ????????self.maxpool2?=?nn.MaxPool2d(2,?2)

          ????????self.fc1?=?nn.Linear(16*5*5,?120)
          ????????self.fc2?=?nn.Linear(120,?84)
          ????????self.fc3?=?nn.Linear(84,?10)

          ????def?forward(self,?x):
          ????????x?=?self.conv1(x)
          ????????x?=?self.relu(x)
          ????????x?=?self.maxpool1(x)
          ????????x?=?self.conv2(x)
          ????????x?=?self.maxpool2(x)
          ????????x?=?x.view(-1,?16*5*5)
          ????????x?=?F.relu(self.fc1(x))
          ????????x?=?F.relu(self.fc2(x))
          ????????x?=?self.fc3(x)
          ????????output?=?F.log_softmax(x,?dim=1)
          ????????return?output

          3 將定義好的網(wǎng)絡(luò)結(jié)構(gòu)搭載到 GPU/CPU,并定義優(yōu)化器

          使用 SGD(隨機(jī)梯度下降)優(yōu)化,學(xué)習(xí)率為 0.001,動(dòng)量為 0.9。

          #創(chuàng)建模型,部署gpu
          device?=?torch.device("cuda"?if?torch.cuda.is_available()?else?"cpu")
          model?=?LeNetRGB().to(device)
          #定義優(yōu)化器
          optimizer?=?optim.SGD(model.parameters(),?lr=0.01,?momentum=0.9)

          4 定義訓(xùn)練過程

          def?train_runner(model,?device,?trainloader,?optimizer,?epoch):
          ????#訓(xùn)練模型,?啟用?BatchNormalization?和?Dropout,?將BatchNormalization和Dropout置為True
          ????model.train()
          ????total?=?0
          ????correct?=0.0

          ????#enumerate迭代已加載的數(shù)據(jù)集,同時(shí)獲取數(shù)據(jù)和數(shù)據(jù)下標(biāo)
          ????for?i,?data?in?enumerate(trainloader,?0):
          ????????inputs,?labels?=?data
          ????????#把模型部署到device上
          ????????inputs,?labels?=?inputs.to(device),?labels.to(device)
          ????????#初始化梯度
          ????????optimizer.zero_grad()
          ????????#保存訓(xùn)練結(jié)果
          ????????outputs?=?model(inputs)
          ????????#計(jì)算損失和
          ????????#多分類情況通常使用cross_entropy(交叉熵?fù)p失函數(shù)),?而對(duì)于二分類問題,?通常使用sigmod
          ????????loss?=?F.cross_entropy(outputs,?labels)
          ????????#獲取最大概率的預(yù)測(cè)結(jié)果
          ????????#dim=1表示返回每一行的最大值對(duì)應(yīng)的列下標(biāo)
          ????????predict?=?outputs.argmax(dim=1)
          ????????total?+=?labels.size(0)
          ????????correct?+=?(predict?==?labels).sum().item()
          ????????#反向傳播
          ????????loss.backward()
          ????????#更新參數(shù)
          ????????optimizer.step()
          ????????if?i?%?1000?==?0:
          ????????????#loss.item()表示當(dāng)前l(fā)oss的數(shù)值
          ????????????print("Train?Epoch{}?\t?Loss:?{:.6f},?accuracy:?{:.6f}%".format(epoch,?loss.item(),?100*(correct/total)))
          ????????????Loss.append(loss.item())
          ????????????Accuracy.append(correct/total)
          ????return?loss.item(),?correct/total

          5 定義測(cè)試過程

          def?test_runner(model,?device,?testloader):
          ????#模型驗(yàn)證,?必須要寫,?否則只要有輸入數(shù)據(jù),?即使不訓(xùn)練,?它也會(huì)改變權(quán)值
          ????#因?yàn)檎{(diào)用eval()將不啟用?BatchNormalization?和?Dropout,?BatchNormalization和Dropout置為False
          ????model.eval()
          ????#統(tǒng)計(jì)模型正確率,?設(shè)置初始值
          ????correct?=?0.0
          ????test_loss?=?0.0
          ????total?=?0
          ????#torch.no_grad將不會(huì)計(jì)算梯度,?也不會(huì)進(jìn)行反向傳播
          ????with?torch.no_grad():
          ????????for?data,?label?in?testloader:
          ????????????data,?label?=?data.to(device),?label.to(device)
          ????????????output?=?model(data)
          ????????????test_loss?+=?F.cross_entropy(output,?label).item()
          ????????????predict?=?output.argmax(dim=1)
          ????????????#計(jì)算正確數(shù)量
          ????????????total?+=?label.size(0)
          ????????????correct?+=?(predict?==?label).sum().item()
          ????????#計(jì)算損失值
          ????????print("test_avarage_loss:?{:.6f},?accuracy:?{:.6f}%".format(test_loss/total,?100*(correct/total)))

          6 運(yùn)行

          #調(diào)用
          epoch?=?20
          Loss?=?[]
          Accuracy?=?[]
          for?epoch?in?range(1,?epoch+1):
          ????print("start_time",time.strftime('%Y-%m-%d?%H:%M:%S',time.localtime(time.time())))
          ????loss,?acc?=?train_runner(model,?device,?trainloader,?optimizer,?epoch)
          ????Loss.append(loss)
          ????Accuracy.append(acc)
          ????test_runner(model,?device,?testloader)
          ????print("end_time:?",time.strftime('%Y-%m-%d?%H:%M:%S',time.localtime(time.time())),'\n')

          print('Finished?Training')
          plt.subplot(2,1,1)
          plt.plot(Loss)
          plt.title('Loss')
          plt.show()
          plt.subplot(2,1,2)
          plt.plot(Accuracy)
          plt.title('Accuracy')
          plt.show()

          經(jīng)歷 20 次 epoch 迭代訓(xùn)練之后:

          start_time 2021-11-27 22:29:09
          Train Epoch20 Loss: 0.659028, accuracy: 68.750000%
          test_avarage_loss: 0.030969, accuracy: 67.760000%
          end_time: ?2021-11-27 22:29:44

          訓(xùn)練集的 loss 曲線和 Accuracy 曲線變化如下:

          7 保存模型

          print(model)
          torch.save(model,?'./models/model-cifar10.pth')?#保存模型

          LeNet-5 的模型會(huì) print 出來,并將模型模型命令為 model-cifar10.pth 保存在固定目錄下。

          LeNetRGB(
          (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
          (relu): ReLU()
          (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
          (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (fc1): Linear(in_features=400, out_features=120, bias=True)
          (fc2): Linear(in_features=120, out_features=84, bias=True)
          (fc3): Linear(in_features=84, out_features=10, bias=True)
          )

          8 模型測(cè)試

          利用剛剛訓(xùn)練的模型進(jìn)行 CIFAR10 類型圖片的測(cè)試。

          from?PIL?import?Image
          import?numpy?as?np

          if?__name__?==?'__main__':
          ????device?=?torch.device('cuda'?if?torch.cuda.is_available()?else?'cpu')
          ????model?=?torch.load('./models/model-cifar10.pth')?#加載模型
          ????model?=?model.to(device)
          ????model.eval()????#把模型轉(zhuǎn)為test模式
          ????
          ????#讀取要預(yù)測(cè)的圖片
          ????#?讀取要預(yù)測(cè)的圖片
          ????img?=?Image.open("./images/test_cifar10.png").convert('RGB')?#?讀取圖像
          ????#img.show()
          ????plt.imshow(img)?#?顯示圖片
          ????plt.axis('off')?#?不顯示坐標(biāo)軸
          ????plt.show()
          ????
          ????#?導(dǎo)入圖片,圖片擴(kuò)展后為[1,1,32,32]
          ????trans?=?transforms.Compose(
          ????????[
          ????????????#將圖片尺寸resize到32x32
          ????????????transforms.Resize((32,32)),
          ????????????transforms.ToTensor(),
          ????????????transforms.Normalize((0.5,?0.5,?0.5),?(0.5,?0.5,?0.5))
          ????????])
          ????img?=?trans(img)
          ????img?=?img.to(device)
          ????img?=?img.unsqueeze(0)??#圖片擴(kuò)展多一維,因?yàn)檩斎氲奖4娴哪P椭惺?維的[batch_size,通道,長(zhǎng),寬],而普通圖片只有三維,[通道,長(zhǎng),寬]
          ????
          ????#?預(yù)測(cè)?
          ????classes?=?('plane',?'car',?'bird',?'cat',?'deer',?'dog',?'frog',?'horse',?'ship',?'truck')
          ????output?=?model(img)
          ????prob?=?F.softmax(output,dim=1)?#prob是10個(gè)分類的概率
          ????print("概率:",prob)
          ????print(predict.item())
          ????value,?predicted?=?torch.max(output.data,?1)
          ????predict?=?output.argmax(dim=1)
          ????pred_class?=?classes[predicted.item()]
          ????print("預(yù)測(cè)類別:",pred_class)

          輸出:

          概率:tensor([[7.6907e-01, 3.3997e-03, 4.8003e-03, 4.2978e-05, 1.2168e-02, 6.8751e-06, 3.2019e-06, 1.6024e-04, 1.2705e-01, 8.3300e-02]],
          grad_fn=)
          5

          預(yù)測(cè)類別:plane

          模型預(yù)測(cè)結(jié)果正確!

          以上就是 PyTorch 構(gòu)建 LeNet-5 卷積神經(jīng)網(wǎng)!絡(luò)并用它來識(shí)別 CIFAR10 數(shù)據(jù)集的例子。全文的代碼都是可以順利運(yùn)行的,建議大家自己跑一邊。

          值得一提的是,針對(duì) MNIST 數(shù)據(jù)集和 CIFAR10 數(shù)據(jù)集,最大的不同就是 MNIST 是單通道的,CIFAR10 是三通道的,因此在構(gòu)建 LeNet-5 網(wǎng)絡(luò)的時(shí)候,C1層需要做不同的設(shè)置。至于輸入圖片尺寸不一樣,我們可以使用 transforms.Resize 方法統(tǒng)一縮放到 32x32 的尺寸大小。

          所有完整的代碼我都放在 GitHub 上,GitHub地址為:https://github.com/RedstoneWill/ObjectDetectionLearner/tree/main/LeNet-5

          如果覺得有用,就請(qǐng)分享到朋友圈吧!

          △點(diǎn)擊卡片關(guān)注極市平臺(tái),獲取最新CV干貨

          公眾號(hào)后臺(tái)回復(fù)“transformer”獲取最新Transformer綜述論文下載~


          極市干貨
          課程/比賽:珠港澳人工智能算法大賽保姆級(jí)零基礎(chǔ)人工智能教程
          算法trick目標(biāo)檢測(cè)比賽中的tricks集錦從39個(gè)kaggle競(jìng)賽中總結(jié)出來的圖像分割的Tips和Tricks
          技術(shù)綜述:一文弄懂各種loss function工業(yè)圖像異常檢測(cè)最新研究總結(jié)(2019-2020)


          #?CV技術(shù)社群邀請(qǐng)函?#

          △長(zhǎng)按添加極市小助手
          添加極市小助手微信(ID : cvmart4)

          備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)


          即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群


          每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與?10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動(dòng)交流~



          覺得有用麻煩給個(gè)在看啦~??
          瀏覽 43
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  人人操人人干人人射 | 自拍偷拍网 | 久久色婷婷 | 久久婷婷免费视频 | 欧美成人娱乐视频免费 |