我用 PyTorch 復(fù)現(xiàn)了 LeNet-5 神經(jīng)網(wǎng)絡(luò)(CIFAR10 數(shù)據(jù)集篇)!

極市導(dǎo)讀
?大家好,我是紅色石頭!今天我們將使用 Pytorch 來繼續(xù)實(shí)現(xiàn) LeNet-5 模型,并用它來解決 CIFAR10 數(shù)據(jù)集的識(shí)別。?>>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿
正文開始!
LeNet-5 網(wǎng)絡(luò)本是用來識(shí)別 MNIST 數(shù)據(jù)集的,下面我們來將 LeNet-5 應(yīng)用到一個(gè)比較復(fù)雜的例子,識(shí)別 CIFAR-10 數(shù)據(jù)集。
CIFAR-10 是由 Hinton 的學(xué)生 Alex Krizhevsky 和 Ilya Sutskever 整理的一個(gè)用于識(shí)別普適物體的小型數(shù)據(jù)集。一共包含 10 個(gè)類別的 RGB 彩色圖 片:飛機(jī)( airlane )、汽車( automobile )、鳥類( bird )、貓( cat )、鹿( deer )、狗( dog )、蛙類( frog )、馬( horse )、船( ship )和卡車( truck )。圖片的尺寸為 32×32 ,數(shù)據(jù)集中一共有 50000 張訓(xùn)練圄片和 10000 張測(cè)試圖片。
CIFAR-10 的圖片樣例如圖所示。

1 下載并加載數(shù)據(jù),并做出一定的預(yù)先處理
pipline_train?=?transforms.Compose([
????#隨機(jī)旋轉(zhuǎn)圖片
????transforms.RandomHorizontalFlip(),
????#將圖片尺寸resize到32x32
????transforms.Resize((32,32)),
????#將圖片轉(zhuǎn)化為Tensor格式
????transforms.ToTensor(),
????#正則化(當(dāng)模型出現(xiàn)過擬合的情況時(shí),用來降低模型的復(fù)雜度)
????transforms.Normalize((0.5,?0.5,?0.5),?(0.5,?0.5,?0.5))
])
pipline_test?=?transforms.Compose([
????#將圖片尺寸resize到32x32
????transforms.Resize((32,32)),
????transforms.ToTensor(),
????transforms.Normalize((0.5,?0.5,?0.5),?(0.5,?0.5,?0.5))
])
#下載數(shù)據(jù)集
train_set?=?datasets.CIFAR10(root="./data/CIFAR10",?train=True,?download=True,?transform=pipline_train)
test_set?=?datasets.CIFAR10(root="./data/CIFAR10",?train=False,?download=True,?transform=pipline_test)
#加載數(shù)據(jù)集
trainloader?=?torch.utils.data.DataLoader(train_set,?batch_size=64,?shuffle=True)
testloader?=?torch.utils.data.DataLoader(test_set,?batch_size=32,?shuffle=False)
#?類別信息也是需要我們給定的
classes?=?('plane',?'car',?'bird',?'cat','deer',?'dog',?'frog',?'horse',?'ship',?'truck')
2 搭建 LeNet-5 神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),并定義前向傳播的過程
LeNet-5 網(wǎng)絡(luò)上文已經(jīng)搭建過了,由于 CIFAR10 數(shù)據(jù)集圖像是 RGB 三通道的,因此 LeNet-5 網(wǎng)絡(luò) C1 層卷積選擇的濾波器需要 3 通道,網(wǎng)絡(luò)其它結(jié)構(gòu)跟上文都是一樣的。
class?LeNetRGB(nn.Module):
????def?__init__(self):
????????super(LeNetRGB,?self).__init__()
????????self.conv1?=?nn.Conv2d(3,?6,?5)???#?3表示輸入是3通道
????????self.relu?=?nn.ReLU()
????????self.maxpool1?=?nn.MaxPool2d(2,?2)
????????self.conv2?=?nn.Conv2d(6,?16,?5)
????????self.maxpool2?=?nn.MaxPool2d(2,?2)
????????self.fc1?=?nn.Linear(16*5*5,?120)
????????self.fc2?=?nn.Linear(120,?84)
????????self.fc3?=?nn.Linear(84,?10)
????def?forward(self,?x):
????????x?=?self.conv1(x)
????????x?=?self.relu(x)
????????x?=?self.maxpool1(x)
????????x?=?self.conv2(x)
????????x?=?self.maxpool2(x)
????????x?=?x.view(-1,?16*5*5)
????????x?=?F.relu(self.fc1(x))
????????x?=?F.relu(self.fc2(x))
????????x?=?self.fc3(x)
????????output?=?F.log_softmax(x,?dim=1)
????????return?output
3 將定義好的網(wǎng)絡(luò)結(jié)構(gòu)搭載到 GPU/CPU,并定義優(yōu)化器
使用 SGD(隨機(jī)梯度下降)優(yōu)化,學(xué)習(xí)率為 0.001,動(dòng)量為 0.9。
#創(chuàng)建模型,部署gpu
device?=?torch.device("cuda"?if?torch.cuda.is_available()?else?"cpu")
model?=?LeNetRGB().to(device)
#定義優(yōu)化器
optimizer?=?optim.SGD(model.parameters(),?lr=0.01,?momentum=0.9)
4 定義訓(xùn)練過程
def?train_runner(model,?device,?trainloader,?optimizer,?epoch):
????#訓(xùn)練模型,?啟用?BatchNormalization?和?Dropout,?將BatchNormalization和Dropout置為True
????model.train()
????total?=?0
????correct?=0.0
????#enumerate迭代已加載的數(shù)據(jù)集,同時(shí)獲取數(shù)據(jù)和數(shù)據(jù)下標(biāo)
????for?i,?data?in?enumerate(trainloader,?0):
????????inputs,?labels?=?data
????????#把模型部署到device上
????????inputs,?labels?=?inputs.to(device),?labels.to(device)
????????#初始化梯度
????????optimizer.zero_grad()
????????#保存訓(xùn)練結(jié)果
????????outputs?=?model(inputs)
????????#計(jì)算損失和
????????#多分類情況通常使用cross_entropy(交叉熵?fù)p失函數(shù)),?而對(duì)于二分類問題,?通常使用sigmod
????????loss?=?F.cross_entropy(outputs,?labels)
????????#獲取最大概率的預(yù)測(cè)結(jié)果
????????#dim=1表示返回每一行的最大值對(duì)應(yīng)的列下標(biāo)
????????predict?=?outputs.argmax(dim=1)
????????total?+=?labels.size(0)
????????correct?+=?(predict?==?labels).sum().item()
????????#反向傳播
????????loss.backward()
????????#更新參數(shù)
????????optimizer.step()
????????if?i?%?1000?==?0:
????????????#loss.item()表示當(dāng)前l(fā)oss的數(shù)值
????????????print("Train?Epoch{}?\t?Loss:?{:.6f},?accuracy:?{:.6f}%".format(epoch,?loss.item(),?100*(correct/total)))
????????????Loss.append(loss.item())
????????????Accuracy.append(correct/total)
????return?loss.item(),?correct/total
5 定義測(cè)試過程
def?test_runner(model,?device,?testloader):
????#模型驗(yàn)證,?必須要寫,?否則只要有輸入數(shù)據(jù),?即使不訓(xùn)練,?它也會(huì)改變權(quán)值
????#因?yàn)檎{(diào)用eval()將不啟用?BatchNormalization?和?Dropout,?BatchNormalization和Dropout置為False
????model.eval()
????#統(tǒng)計(jì)模型正確率,?設(shè)置初始值
????correct?=?0.0
????test_loss?=?0.0
????total?=?0
????#torch.no_grad將不會(huì)計(jì)算梯度,?也不會(huì)進(jìn)行反向傳播
????with?torch.no_grad():
????????for?data,?label?in?testloader:
????????????data,?label?=?data.to(device),?label.to(device)
????????????output?=?model(data)
????????????test_loss?+=?F.cross_entropy(output,?label).item()
????????????predict?=?output.argmax(dim=1)
????????????#計(jì)算正確數(shù)量
????????????total?+=?label.size(0)
????????????correct?+=?(predict?==?label).sum().item()
????????#計(jì)算損失值
????????print("test_avarage_loss:?{:.6f},?accuracy:?{:.6f}%".format(test_loss/total,?100*(correct/total)))
6 運(yùn)行
#調(diào)用
epoch?=?20
Loss?=?[]
Accuracy?=?[]
for?epoch?in?range(1,?epoch+1):
????print("start_time",time.strftime('%Y-%m-%d?%H:%M:%S',time.localtime(time.time())))
????loss,?acc?=?train_runner(model,?device,?trainloader,?optimizer,?epoch)
????Loss.append(loss)
????Accuracy.append(acc)
????test_runner(model,?device,?testloader)
????print("end_time:?",time.strftime('%Y-%m-%d?%H:%M:%S',time.localtime(time.time())),'\n')
print('Finished?Training')
plt.subplot(2,1,1)
plt.plot(Loss)
plt.title('Loss')
plt.show()
plt.subplot(2,1,2)
plt.plot(Accuracy)
plt.title('Accuracy')
plt.show()
經(jīng)歷 20 次 epoch 迭代訓(xùn)練之后:
start_time 2021-11-27 22:29:09
Train Epoch20 Loss: 0.659028, accuracy: 68.750000%
test_avarage_loss: 0.030969, accuracy: 67.760000%
end_time: ?2021-11-27 22:29:44
訓(xùn)練集的 loss 曲線和 Accuracy 曲線變化如下:

7 保存模型
print(model)
torch.save(model,?'./models/model-cifar10.pth')?#保存模型
LeNet-5 的模型會(huì) print 出來,并將模型模型命令為 model-cifar10.pth 保存在固定目錄下。
LeNetRGB(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(relu): ReLU()
(maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
8 模型測(cè)試
利用剛剛訓(xùn)練的模型進(jìn)行 CIFAR10 類型圖片的測(cè)試。
from?PIL?import?Image
import?numpy?as?np
if?__name__?==?'__main__':
????device?=?torch.device('cuda'?if?torch.cuda.is_available()?else?'cpu')
????model?=?torch.load('./models/model-cifar10.pth')?#加載模型
????model?=?model.to(device)
????model.eval()????#把模型轉(zhuǎn)為test模式
????
????#讀取要預(yù)測(cè)的圖片
????#?讀取要預(yù)測(cè)的圖片
????img?=?Image.open("./images/test_cifar10.png").convert('RGB')?#?讀取圖像
????#img.show()
????plt.imshow(img)?#?顯示圖片
????plt.axis('off')?#?不顯示坐標(biāo)軸
????plt.show()
????
????#?導(dǎo)入圖片,圖片擴(kuò)展后為[1,1,32,32]
????trans?=?transforms.Compose(
????????[
????????????#將圖片尺寸resize到32x32
????????????transforms.Resize((32,32)),
????????????transforms.ToTensor(),
????????????transforms.Normalize((0.5,?0.5,?0.5),?(0.5,?0.5,?0.5))
????????])
????img?=?trans(img)
????img?=?img.to(device)
????img?=?img.unsqueeze(0)??#圖片擴(kuò)展多一維,因?yàn)檩斎氲奖4娴哪P椭惺?維的[batch_size,通道,長(zhǎng),寬],而普通圖片只有三維,[通道,長(zhǎng),寬]
????
????#?預(yù)測(cè)?
????classes?=?('plane',?'car',?'bird',?'cat',?'deer',?'dog',?'frog',?'horse',?'ship',?'truck')
????output?=?model(img)
????prob?=?F.softmax(output,dim=1)?#prob是10個(gè)分類的概率
????print("概率:",prob)
????print(predict.item())
????value,?predicted?=?torch.max(output.data,?1)
????predict?=?output.argmax(dim=1)
????pred_class?=?classes[predicted.item()]
????print("預(yù)測(cè)類別:",pred_class)

輸出:
概率:tensor([[7.6907e-01, 3.3997e-03, 4.8003e-03, 4.2978e-05, 1.2168e-02, 6.8751e-06, 3.2019e-06, 1.6024e-04, 1.2705e-01, 8.3300e-02]],
grad_fn=)
5
預(yù)測(cè)類別:plane
模型預(yù)測(cè)結(jié)果正確!
以上就是 PyTorch 構(gòu)建 LeNet-5 卷積神經(jīng)網(wǎng)!絡(luò)并用它來識(shí)別 CIFAR10 數(shù)據(jù)集的例子。全文的代碼都是可以順利運(yùn)行的,建議大家自己跑一邊。
值得一提的是,針對(duì) MNIST 數(shù)據(jù)集和 CIFAR10 數(shù)據(jù)集,最大的不同就是 MNIST 是單通道的,CIFAR10 是三通道的,因此在構(gòu)建 LeNet-5 網(wǎng)絡(luò)的時(shí)候,C1層需要做不同的設(shè)置。至于輸入圖片尺寸不一樣,我們可以使用 transforms.Resize 方法統(tǒng)一縮放到 32x32 的尺寸大小。
所有完整的代碼我都放在 GitHub 上,GitHub地址為:https://github.com/RedstoneWill/ObjectDetectionLearner/tree/main/LeNet-5
如果覺得有用,就請(qǐng)分享到朋友圈吧!
公眾號(hào)后臺(tái)回復(fù)“transformer”獲取最新Transformer綜述論文下載~

#?CV技術(shù)社群邀請(qǐng)函?#

備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)
即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群
每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與?10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動(dòng)交流~

