Python Teaches You How to Classify Images

By | 潮汐
Source: Python 技術(shù) (ID: pythonall)

Everyday life is full of image-classification scenarios, such as sorting garbage or grouping pictures by scene; today's article builds a model for exactly this kind of image-recognition task. The idea is to train a model with deep learning in Python, then use that model to automatically review and compare uploaded electronic forms and return a result. The model is built with Torchvision, a companion library to the PyTorch deep-learning framework that provides image and video datasets, transforms, and pretrained models.
Model Construction
To keep things easy to follow, the model is built in a Jupyter Notebook. For how to install and use Jupyter Notebook, see the earlier article 一文吃透 Jupyter Notebook on this public account; once the notebook page is open, you can start writing the code below.

Import the Required Packages
Image recognition relies on deep-learning modules, so the corresponding packages need to be imported first:
%reload_ext autoreload
%autoreload 2
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms as tfs
from torchvision import models
from torch import nn
import matplotlib.pyplot as plt
%matplotlib inline
import os
# Work around "duplicate OpenMP runtime" errors that some environments hit
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
Whether to Use a GPU
Training can run on either a GPU or a CPU; without a GPU, it simply runs on the CPU. Training takes a while: the duration depends on the size of the training set and on the hardware, and the accuracy of the results depends on how much (and how clean) data you have, since deep learning needs plenty of samples to make precise predictions. Here we declare that training runs on the CPU:
# Whether to use the GPU
use_gpu = False
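If a CUDA-capable GPU happens to be available, the flag can also be set automatically instead of hard-coding False; this is an optional sketch, and the rest of the article keeps the flag as declared above:
# Optional: pick the GPU automatically when CUDA is available
# (the article itself keeps use_gpu = False for CPU training)
use_gpu = torch.cuda.is_available()
print('Training on GPU' if use_gpu else 'Training on CPU')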
Data Augmentation
The training data is preprocessed and augmented: each image is resized to 224 × 224, randomly flipped horizontally, and randomly cropped to 128 × 128, then converted to a tensor and normalized. The validation images are only resized and normalized, with no random cropping:
# Data augmentation
train_transform = tfs.Compose([
    # Preprocessing for the training set
    tfs.Resize([224, 224]),
    tfs.RandomHorizontalFlip(),
    tfs.RandomCrop(128),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
test_transform = tfs.Compose([
    tfs.Resize([224, 224]),
#     tfs.RandomCrop(128),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# Number of samples per batch
batch_size = 10
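As a quick sanity check, you can run the training transform on any single image and inspect the output: with the pipeline above, training samples come out as 128 × 128 tensors (because of the RandomCrop) while validation images stay at 224 × 224, which ResNet's adaptive pooling can handle. A minimal sketch, where the file path is a hypothetical placeholder:
# Sanity check: apply the training transform to one image (the path is a placeholder)
from PIL import Image
img = Image.open('./dataset1/train/class_a/sample.jpg')  # hypothetical example file
x = train_transform(img)
print(x.shape)            # expected: torch.Size([3, 128, 128]) after the 128 crop
print(x.min(), x.max())   # roughly within [-1, 1] after Normalize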
Preparing the Training and Validation Sets
Model training needs both a training set and a validation set, and only with enough images will the model reach a useful accuracy. The code for building the two sets is shown below:
# Build the training and validation sets
train_set = ImageFolder('./dataset1/train', train_transform)
train_data = DataLoader(train_set, batch_size, shuffle=True, num_workers=0)
valid_set = ImageFolder('./dataset1/valid', test_transform)
valid_data = DataLoader(valid_set, 2 * batch_size, shuffle=False, num_workers=0)
train_set.class_to_idx
len(valid_data)
# Dataset readiness check
try:
    if next(iter(train_data))[0].shape[0] == batch_size and \
       next(iter(valid_data))[0].shape[0] == 2 * batch_size:
        print('Dataset is ready!')
    else:
        print('Not success, maybe the batch size is wrong')
except Exception:
    print('not success, image transform is wrong!')
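Note that ImageFolder infers the class labels from the sub-folder names under ./dataset1/train and ./dataset1/valid, so each class needs its own folder. A small sketch for inspecting the dataset once it is loaded; the class names in the comments are hypothetical examples:
# ImageFolder expects one sub-folder per class, e.g.
#   dataset1/train/cat/001.jpg, dataset1/train/dog/002.jpg, dataset1/train/other/003.jpg
# (the class names "cat", "dog", "other" are hypothetical)
print(train_set.class_to_idx)   # e.g. {'cat': 0, 'dog': 1, 'other': 2}
print(len(train_set), 'training images,', len(valid_set), 'validation images')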
Building and Checking the Model
Load a ResNet-50 pretrained on ImageNet, replace its final fully connected layer with a 3-class output, and run one batch through the network as a sanity check:
# Build the model
def get_model():
    # Load a ResNet-50 pretrained on ImageNet and replace the final
    # fully connected layer with a 3-class output
    model = models.resnet50(pretrained=True)
    model.fc = nn.Linear(2048, 3)
    return model

try:
    model = get_model()
    with torch.no_grad():
        score = model(next(iter(train_data))[0])
        print(score.shape[0], score.shape[1])
    if score.shape[0] == batch_size and score.shape[1] == 3:
        print('Model is ready!')
    else:
        print('Model is failed!')
except Exception:
    print('model is wrong')

if use_gpu:
    model = model.cuda()
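Since get_model() starts from ImageNet-pretrained weights, an optional variant (not used in the training run below) is to freeze the convolutional backbone and train only the new fully connected layer, which noticeably shortens CPU training. A hedged sketch:
# Optional variant (not part of the original training run): freeze the pretrained
# backbone and train only the new fully connected layer
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True
# If you use this, pass only the trainable parameters to the optimizer, e.g.
# torch.optim.Adam(model.fc.parameters(), lr=1e-4)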
Building the Loss Function and Optimizer
Classification uses cross-entropy loss, optimized with Adam at a learning rate of 1e-4, for 20 epochs:
# Build the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Number of training epochs
max_epoch = 20
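Optionally, a learning-rate scheduler can sit next to the optimizer in case the loss plateaus over the 20 epochs; this is an assumption-level addition and is not wired into the train() function below:
# Optional: decay the learning rate during training (not used by train() below;
# you would call scheduler.step() once per epoch if you add it)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)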
Model Training and Result Visualization
With the training and validation sets ready, train the model and visualize the results. Part of the code is shown below:
def train(model, train_data, valid_data, max_epoch, criterion, optimizer):
    # Print progress roughly three times per epoch
    # (max(1, ...) guards against tiny datasets where the division would give 0)
    freq_print = max(1, len(train_data) // 3)

    metric_log = dict()
    metric_log['train_loss'] = list()
    metric_log['train_acc'] = list()
    if valid_data is not None:
        metric_log['valid_loss'] = list()
        metric_log['valid_acc'] = list()

    for e in range(max_epoch):
        model.train()
        running_loss = 0
        running_acc = 0
        for i, data in enumerate(train_data, 1):
            img, label = data
            if use_gpu:
                img = img.cuda()
                label = label.cuda()
            # Forward pass
            out = model(img)
            # Compute the loss
            loss = criterion(out, label.long())
            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Compute the accuracy
            _, pred = out.max(1)
            num_correct = (pred == label.long()).sum().item()
            acc = num_correct / img.shape[0]
            running_loss += loss.item()
            running_acc += acc
            if i % freq_print == 0:
                print('[{}]/[{}], train loss: {:.3f}, train acc: {:.3f}'
                      .format(i, len(train_data), running_loss / i, running_acc / i))

        metric_log['train_loss'].append(running_loss / len(train_data))
        metric_log['train_acc'].append(running_acc / len(train_data))

        if valid_data is not None:
            model.eval()
            running_loss = 0
            running_acc = 0
            with torch.no_grad():  # no gradients needed during validation
                for data in valid_data:
                    img, label = data
                    if use_gpu:
                        img = img.cuda()
                        label = label.cuda()
                    # Forward pass
                    out = model(img)
                    # Compute the loss
                    loss = criterion(out, label.long())
                    # Compute the accuracy
                    _, pred = out.max(1)
                    num_correct = (pred == label.long()).sum().item()
                    acc = num_correct / img.shape[0]
                    running_loss += loss.item()
                    running_acc += acc
            metric_log['valid_loss'].append(running_loss / len(valid_data))
            metric_log['valid_acc'].append(running_acc / len(valid_data))
            print_str = ('epoch: {}, train loss: {:.3f}, train acc: {:.3f}, '
                         'valid loss: {:.3f}, valid accuracy: {:.3f}').format(
                e + 1, metric_log['train_loss'][-1], metric_log['train_acc'][-1],
                metric_log['valid_loss'][-1], metric_log['valid_acc'][-1])
        else:
            print_str = 'epoch: {}, train loss: {:.3f}, train acc: {:.3f}'.format(
                e + 1,
                metric_log['train_loss'][-1],
                metric_log['train_acc'][-1])
        print(print_str)

    # Visualize the loss and accuracy curves
    nrows = 1
    ncols = 2
    figsize = (10, 5)
    _, figs = plt.subplots(nrows, ncols, figsize=figsize)
    if valid_data is not None:
        figs[0].plot(metric_log['train_loss'], label='train loss')
        figs[0].plot(metric_log['valid_loss'], label='valid loss')
        figs[0].set_ylabel('loss')
        figs[0].legend(loc='best')
        figs[1].plot(metric_log['train_acc'], label='train acc')
        figs[1].plot(metric_log['valid_acc'], label='valid acc')
        figs[1].set_ylabel('acc')
        figs[1].legend(loc='best')
    else:
        figs[0].plot(metric_log['train_loss'], label='train loss')
        figs[0].set_ylabel('loss')
        figs[0].legend(loc='best')
        figs[1].plot(metric_log['train_acc'], label='train acc')
        figs[1].set_ylabel('acc')
        figs[1].legend(loc='best')
Training the Model
# Train the model; rerun this call when tuning hyperparameters
train(model, train_data, valid_data, max_epoch, criterion, optimizer)
Saving the Model
# Save the trained weights (create the target directory first if it does not exist)
os.makedirs('./model', exist_ok=True)
torch.save(model.state_dict(), './model/save_model2.pth')
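To use the saved model later, for example to check a newly uploaded image as described in the introduction, the weights can be loaded back into a fresh network for single-image prediction. A minimal sketch, where the image path is a placeholder:
# Load the saved weights back into a fresh model and classify one image
# (the image path below is a placeholder)
model = get_model()
model.load_state_dict(torch.load('./model/save_model2.pth', map_location='cpu'))
model.eval()

from PIL import Image
img = Image.open('./some_image.jpg').convert('RGB')
x = test_transform(img).unsqueeze(0)       # add the batch dimension
with torch.no_grad():
    pred_idx = model(x).argmax(dim=1).item()
print('predicted class index:', pred_idx)  # map back to a name via train_set.class_to_idx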
Summary
Today's article walked through how to build an image-classification model. I hope you find it helpful.