Handwritten Digit Recognition Based on a Convolutional Neural Network (with Dataset, Complete Code, and Usage Instructions)
Environment setup
Environment: Python 3.8  Platform: Windows 10  IDE: PyCharm
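1. Introduction
Handwritten digit recognition is an entry-level machine vision project. Both traditional OpenCV-based methods and the currently popular deep learning (neural network) methods achieve good results on it, and it is often used as a course experiment at the undergraduate or graduate level. Unfortunately, although a lot of project code for handwritten digit recognition can be found online, most of it is incomplete, which is a real obstacle for beginners. That is why I wrote this article: whether you want to finish a course experiment or learn machine vision, it may be useful to you.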
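2. Problem description
The problem addressed here is: write an arbitrary digit on a blackboard, then use the computer's camera to detect in real time which digit from 0 to 9 it is.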
3. Solution
A deep learning method based on Python.
The detection pipeline is as follows:
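4. Implementation steps
4.1 Dataset selection
Classic handwritten-digit dataset: this article uses the MNIST dataset in idx-ubyte format. Its t10k files contain 10,000 handwritten images of 28*28 pixels (grayscale, with pixel values from 0 to 255). The training script in Section 6 additionally reads the train-images and train-labels files.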
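The idx-ubyte files start with a short big-endian header followed by raw bytes. The sketch below shows that layout using only the standard `struct` module; `parse_idx3`, `parse_idx1` and the tiny synthetic byte strings are hypothetical illustrations, not part of the original project, and the real loader is `decode_idx3_ubyte` in Section 6.

```python
import struct


def parse_idx3(data: bytes):
    """Parse an idx3 image file held in memory.

    Returns (rows, cols, images), where each image is a bytes object of
    rows * cols pixel values in row-major order.
    """
    magic, count, rows, cols = struct.unpack_from(">iiii", data, 0)
    if magic != 2051:  # 0x00000803 identifies an idx3 image file
        raise ValueError(f"unexpected magic number {magic}")
    header = struct.calcsize(">iiii")  # 16 bytes
    size = rows * cols
    images = [data[header + i * size: header + (i + 1) * size] for i in range(count)]
    return rows, cols, images


def parse_idx1(data: bytes):
    """Parse an idx1 label file held in memory and return a list of ints."""
    magic, count = struct.unpack_from(">ii", data, 0)
    if magic != 2049:  # 0x00000801 identifies an idx1 label file
        raise ValueError(f"unexpected magic number {magic}")
    header = struct.calcsize(">ii")  # 8 bytes
    return list(data[header:header + count])


# Synthetic example (assumption: two 2*2 images instead of real 28*28 data)
demo_images = struct.pack(">iiii", 2051, 2, 2, 2) + bytes([0, 255, 255, 0, 1, 2, 3, 4])
demo_labels = struct.pack(">ii", 2049, 2) + bytes([7, 3])

rows, cols, images = parse_idx3(demo_images)
print(rows, cols, [list(img) for img in images], parse_idx1(demo_labels))
```

For the real files, the same functions could be fed the result of `open(path, 'rb').read()`.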
For the dataset download address, see:
4.2 Building the network
The network is a ResNet (residual network). The advantages of a residual network are:
- It captures subtle fluctuations in the model more easily
- It converges faster
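To get a feel for how small the feature maps become in the ResNet variant listed in Section 6, the side length after each stage can be computed with the usual convolution size formula, floor((n + 2p - k) / s) + 1. This is only a sketch: `conv_out` and `trace_spatial_size` are hypothetical helper names, and the layer parameters are copied from the `net` definition in `hand_wrtten_train.py`.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_spatial_size(size=28):
    """Follow the feature-map side length through the stages of the network."""
    stages = []
    size = conv_out(size, kernel=7, stride=2, padding=3)
    stages.append(("conv 7x7, stride 2", size))
    size = conv_out(size, kernel=3, stride=2, padding=1)
    stages.append(("max pool 3x3, stride 2", size))
    # resnet_block1 uses stride 1; blocks 2 and 3 start with a stride-2 convolution.
    for name, stride in (("resnet_block1", 1), ("resnet_block2", 2), ("resnet_block3", 2)):
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        stages.append((name, size))
    # Global average pooling always reduces the map to 1x1.
    stages.append(("global_avg_pool", 1))
    return stages


for name, size in trace_spatial_size(28):
    print(f"{name}: {size}x{size}")
```

For a 28*28 input the side lengths are 14, 7, 7, 4, 2 and finally 1, so the fully connected layer only sees the 256 channel averages.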
The network structure used in this article is shown in the figure below; the code is given in Section 5:
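4.3 Training the network
The number of training epochs is set to 100. The training process works like this:
- Data (images + labels) is "fed" to the network model
- The network keeps correcting its own weights based on the data it is fed
- In this article, the 10,000 images are fed 100 times in total
- This takes 2 h on an RTX 2070
The training results are as follows: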
4.4 Testing the network
- 37 images are randomly selected from the dataset for testing
- The accuracy is 36/37
- 6 of them are shown below
4.5 Image preprocessing

- All steps use traditional machine vision methods
- It is very fast: with only the operations above, the processing speed reaches 200 fps
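One of the preprocessing steps, implemented by `get_roi` in `Pre_treatment.py` (Section 6), crops the digit by counting white pixels per row and per column and then pads the crop to a square. The sketch below reproduces that idea on a plain nested list so it runs without OpenCV; `crop_and_pad`, its parameter names and the inclusive cropping are assumptions made for illustration.

```python
def crop_and_pad(binary, min_count=5, margin=10):
    """Crop a 0/1 image to the rows and columns that hold the digit, then pad to a square.

    binary: list of equal-length rows containing 0 (background) or 1 (stroke).
    min_count: a row or column is kept only if it has at least this many 1s.
    margin: background pixels added around the longer side.
    """
    row_counts = [sum(row) for row in binary]
    col_counts = [sum(col) for col in zip(*binary)]
    rows = [i for i, c in enumerate(row_counts) if c >= min_count]
    cols = [j for j, c in enumerate(col_counts) if c >= min_count]
    if not rows or not cols:
        return binary  # nothing that looks like a digit: return the input unchanged
    crop = [row[cols[0]:cols[-1] + 1] for row in binary[rows[0]:rows[-1] + 1]]
    height, width = len(crop), len(crop[0])
    side = max(height, width) + 2 * margin
    top = (side - height) // 2
    left = (side - width) // 2
    out = [[0] * side for _ in range(side)]
    for i, row in enumerate(crop):
        out[top + i][left:left + width] = row
    return out


# Small demo: a 3-row by 2-column stroke inside a 6*6 image
image = [[0] * 6 for _ in range(6)]
for r in (1, 2, 3):
    image[r][2] = image[r][3] = 1
for row in crop_and_pad(image, min_count=1, margin=1):
    print(row)
```

Padding to a square before resizing to 28*28 keeps the aspect ratio of the digit, which is closer to how the training images look.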
4.6 Feeding the image into the network
- Of the handwritten digits 0-9, every digit is recognized except 3
- The detection speed reaches 60 fps
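In `main.py` (Section 6) the raw network output goes through a softmax, and the prediction is discarded when the highest probability is not above 0.5. A minimal pure-Python sketch of that decision rule follows; `classify` is a hypothetical name, and subtracting the maximum before exponentiating is a numerical-stability tweak that the original `softmax` does not have.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    highest = max(logits)
    exps = [math.exp(v - highest) for v in logits]  # shift for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def classify(logits, threshold=0.5):
    """Return (digit, probability); digit is None when the network is not confident."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    if probs[best] <= threshold:
        return None, probs[best]
    return best, probs[best]


print(classify([0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]))  # confident case
print(classify([0.0] * 10))  # all classes equally likely, so no digit is reported
```

The threshold trades missed detections for fewer wrong labels on frames that contain no digit at all.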

5. Code implementation
All the code for this article has been uploaded to GitHub:
https://github.com/Hurri-cane/Hand_wrtten/tree/master

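5.1 File description
- The dataset folder stores the training dataset
- The logs folder holds the weight files produced by training
- real_img, real_img_resize and test_imgs are folders of images used for testing
- The .py files below them are the code for this article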
5.2 Usage
Set up your Python environment to match the author's.
The main packages are: numpy, struct (part of the Python standard library), matplotlib, OpenCV, PyTorch, torchvision and tqdm.
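Before running anything, it can help to check which of these packages are importable. The sketch below uses only the standard library; `missing_packages` is a hypothetical helper, and the pip names on the right-hand side of the mapping (for example `opencv-python` for the `cv2` module) are assumptions about how the packages were installed.

```python
import importlib.util

# module name used in "import ..." -> package name typically given to pip (assumed)
REQUIRED = {
    "numpy": "numpy",
    "struct": "struct (standard library)",
    "matplotlib": "matplotlib",
    "cv2": "opencv-python",
    "torch": "torch",
    "torchvision": "torchvision",
    "tqdm": "tqdm",
}


def missing_packages(required=REQUIRED):
    """Return the pip names of all modules that cannot be found."""
    return [pip_name for module, pip_name in required.items()
            if importlib.util.find_spec(module) is None]


print("Missing packages:", missing_packages() or "none")
```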
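5.3 Training the model
A trained model is provided and has been uploaded to GitHub. You can load it directly and skip this step if you do not want to train.
The training procedure is as follows:
Open hand_wrtten_train.py and click Run (the author uses PyCharm; use whichever IDE you prefer).
Note that the dataset paths must be changed to your own paths, i.e. this part: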
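The paths in question are the four `*_ubyte_file` variables at the top of `hand_wrtten_train.py` (Section 6). As a hedged alternative to editing four absolute paths by hand, they could be derived from a single dataset folder; `dataset_paths` and `missing_files` below are hypothetical helpers, and only the four file names are taken from the original script.

```python
from pathlib import Path

# File names as used in hand_wrtten_train.py
FILES = {
    "train_images": "train-images.idx3-ubyte",
    "train_labels": "train-labels.idx1-ubyte",
    "test_images": "t10k-images.idx3-ubyte",
    "test_labels": "t10k-labels.idx1-ubyte",
}


def dataset_paths(dataset_dir):
    """Build the four dataset paths from one folder."""
    root = Path(dataset_dir)
    return {key: root / name for key, name in FILES.items()}


def missing_files(paths):
    """List the paths that do not exist, so a typo is caught before training starts."""
    return [str(p) for p in paths.values() if not p.is_file()]


paths = dataset_paths("dataset")  # assumption: the folder sits next to the script
print(paths["train_images"])
print("Missing:", missing_files(paths) or "none")
```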
If training runs without errors, the following output appears:
The trained weights are saved in the logs folder.
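The training script writes one file per epoch, named after the pattern `Epoch%d-Loss%.4f-train_acc%.4f-test_acc%.4f.pth`. Since the metrics are encoded in the file name, a checkpoint can be chosen without loading any weights. The sketch below is one way to do that; `parse_checkpoint` and `best_checkpoint` are hypothetical helpers that are not part of the repository.

```python
import re

PATTERN = re.compile(
    r"Epoch(\d+)-Loss(\d+\.\d+)-train_acc(\d+\.\d+)-test_acc(\d+\.\d+)\.pth$"
)


def parse_checkpoint(name):
    """Extract epoch and metrics from a checkpoint file name, or return None."""
    match = PATTERN.search(name)
    if match is None:
        return None
    return {
        "epoch": int(match.group(1)),
        "loss": float(match.group(2)),
        "train_acc": float(match.group(3)),
        "test_acc": float(match.group(4)),
    }


def best_checkpoint(names):
    """Pick the file with the highest test accuracy (later epoch wins ties)."""
    parsed = [(parse_checkpoint(n), n) for n in names]
    parsed = [(info, n) for info, n in parsed if info is not None]
    if not parsed:
        return None
    return max(parsed, key=lambda item: (item[0]["test_acc"], item[0]["epoch"]))[1]


example = "Epoch100-Loss0.0000-train_acc1.0000-test_acc0.9930.pth"
print(parse_checkpoint(example))
```

In practice the list of names could come from `os.listdir("logs")`.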
Training takes time; just wait until it finishes (about 1 h on an RTX 2070).
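While the script runs, each epoch walks through the training images in mini-batches (the script in Section 6 uses a batch size of 1,000). The following sketch shows the index arithmetic behind such a loop; `batch_slices` is a hypothetical helper, and unlike the original loop it also yields a shorter final batch when the sample count is not a multiple of the batch size.

```python
def batch_slices(num_samples, batch_size):
    """Yield (start, stop) index pairs that cover every sample exactly once."""
    for start in range(0, num_samples, batch_size):
        yield start, min(start + batch_size, num_samples)


# With 60,000 training images and batches of 1,000 there are 60 steps per epoch.
slices = list(batch_slices(60000, 1000))
print(len(slices), slices[0], slices[-1])

# A tiny case that does not divide evenly
print(list(batch_slices(10, 4)))
```

Each `(start, stop)` pair can be used directly as `train_images[start:stop]` and `train_labels[start:stop]`.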
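5.4 Testing the network with the trained model
Testing is done on images; the code is in main_pthoto.py. It is used in the same way as the training code above: open it and run it.
Likewise, the image path in main_pthoto.py must be changed to your own path, i.e. this part:
The weight file path in predict.py must also be changed to the path of the .pth file obtained from training in step 5.3, as shown in the figure:
The results are as follows: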
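5.5 Real-time detection with the camera
The code is in main.py and is used in the same way as the image test in Section 5.4: change the weight file path in predict.py to the path of the .pth file obtained from training in step 5.3, as shown in the figure:
Then run main.py. Once the network model has been loaded, the camera is opened and detection begins.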
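The FPS value drawn on each frame is simply the inverse of the time spent on one frame. A small sketch of that measurement is given below; `fps_from_elapsed` and `timed` are hypothetical helpers, `time.perf_counter` is used here instead of the `time.time` call in the original scripts because it is better suited to short intervals, and a non-positive duration is reported as 0.0 rather than raising a division error.

```python
import time


def fps_from_elapsed(elapsed_seconds):
    """Frames per second for one frame that took elapsed_seconds to process."""
    if elapsed_seconds <= 0:
        return 0.0  # too fast to measure; avoid dividing by zero
    return 1.0 / elapsed_seconds


def timed(fn, *args):
    """Run fn(*args) and return (result, fps) for that single call."""
    start = time.perf_counter()
    result = fn(*args)
    return result, fps_from_elapsed(time.perf_counter() - start)


result, fps = timed(sum, range(100000))  # stand-in for preprocessing + inference
print(result, round(fps))
```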
6. Appendix
The core code of this article is attached here: hand_wrtten_train.py
# author:Hurricane
# date:  2020/11/4
# E-mail:[email protected]
import numpy as np
import struct
import matplotlib.pyplot as plt
import cv2 as cv
import random
import torch
from torch import nn, optim
import torch.nn.functional as F
# import d2lzh_pytorch as d2l
import time
from tqdm import tqdm

# Training set images
train_images_idx3_ubyte_file = 'F:/PyCharm/Practice/hand_wrtten/dataset/train-images.idx3-ubyte'
# Training set labels
train_labels_idx1_ubyte_file = 'F:/PyCharm/Practice/hand_wrtten/dataset/train-labels.idx1-ubyte'
# Test set images
test_images_idx3_ubyte_file = 'F:/PyCharm/Practice/hand_wrtten/dataset/t10k-images.idx3-ubyte'
# Test set labels
test_labels_idx1_ubyte_file = 'F:/PyCharm/Practice/hand_wrtten/dataset/t10k-labels.idx1-ubyte'


# Data loading
def decode_idx3_ubyte(idx3_ubyte_file):
    with open(idx3_ubyte_file, 'rb') as f:
        bin_data = f.read()
    offset = 0
    # The first four header fields are 32-bit big-endian integers, hence four 'i'.
    # The label file further down only has two of them ('>ii').
    fmt_header = '>iiii'
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
    print('Number of images: %d, image size: %d*%d' % (num_images, num_rows, num_cols))
    # Parse the dataset
    image_size = num_rows * num_cols
    # Skip the header: after the four header fields the offset is 16
    offset += struct.calcsize(fmt_header)
    print(offset)
    # Pixels are unsigned chars (format 'B'). Prefixing the image size (784) reads
    # 784 values at once; without it only a single pixel would be read.
    fmt_image = '>' + str(image_size) + 'B'
    print(fmt_image, offset, struct.calcsize(fmt_image))
    # Use the size from the header instead of a hard-coded 28*28
    images = np.empty((num_images, num_rows, num_cols))
    # plt.figure()
    for i in tqdm(range(num_images)):
        image = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((num_rows, num_cols)).astype(np.uint8)
        # images[i] = cv.resize(image, (96, 96))
        images[i] = image
        # print(images[i])
        offset += struct.calcsize(fmt_image)
    return images


def decode_idx1_ubyte(idx1_ubyte_file):
    with open(idx1_ubyte_file, 'rb') as f:
        bin_data = f.read()
    offset = 0
    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
    print('Number of images: %d' % (num_images))
    # Parse the dataset
    offset += struct.calcsize(fmt_header)
    fmt_image = '>B'
    labels = np.empty(num_images)
    for i in tqdm(range(num_images)):
        labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
        offset += struct.calcsize(fmt_image)
    return labels


def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    return decode_idx3_ubyte(idx_ubyte_file)


def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    return decode_idx1_ubyte(idx_ubyte_file)


def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    return decode_idx3_ubyte(idx_ubyte_file)


def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    return decode_idx1_ubyte(idx_ubyte_file)


# Network construction
class Residual(nn.Module):  # This class is also available in the d2lzh_pytorch package
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            # 1x1 convolution on the shortcut so that its shape matches the main path
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3 is not None:
            X = self.conv3(X)
        return F.relu(Y + X)


class GlobalAvgPool2d(nn.Module):
    # Global average pooling: the pooling window is set to the input height and width
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])


def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    # num_residuals: number of residual units in this block
    if first_block:
        assert in_channels == out_channels  # the first block keeps the number of channels
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)


def evaluate_accuracy(img, label, net):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        X = torch.unsqueeze(img, 1)
        if isinstance(net, torch.nn.Module):
            net.eval()  # evaluation mode: disables dropout and uses running batch-norm statistics
            acc_sum += (net(X.to(device)).argmax(dim=1) == label.to(device)).float().sum().cpu().item()
            net.train()  # back to training mode
        else:  # custom model (not used after Section 3.13); GPU not considered
            if ('is_training' in net.__code__.co_varnames):  # if the function has an is_training parameter
                # set is_training to False
                acc_sum += (net(X, is_training=False).argmax(dim=1) == label).float().sum().item()
            else:
                acc_sum += (net(X).argmax(dim=1) == label).float().sum().item()
        n += label.shape[0]
    return acc_sum / n


class FlattenLayer(torch.nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()

    def forward(self, x):  # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)


if __name__ == '__main__':
    print("train:")
    train_images_org = load_train_images().astype(np.float32)
    train_labels_org = load_train_labels().astype(np.int64)
    print("test")
    test_images = load_test_images().astype(np.float32)[0:1000]
    test_labels = load_test_labels().astype(np.int64)[0:1000]
    # Convert the data to tensors
    train_images = torch.from_numpy(train_images_org)
    train_labels = torch.from_numpy(train_labels_org)
    test_images = torch.from_numpy(test_images)
    test_labels = torch.from_numpy(test_labels)
    # test_images = load_test_images()
    # test_labels = load_test_labels()
    # Show five random samples with their labels to check that the data was read correctly
    for i in range(5):
        # randrange excludes the upper bound, so the index is always valid
        j = random.randrange(train_images_org.shape[0])
        print("now, show the number of image[{}]:".format(j), int(train_labels_org[j]))
        # imshow expects uint8 in [0, 255]; float images are interpreted as [0, 1]
        img = train_images_org[j].astype(np.uint8)
        img = cv.resize(img, (600, 600))
        cv.imshow("image", img)
        cv.waitKey(0)
    cv.destroyAllWindows()
    print('all done!')
    print("*" * 50)
    # ResNet model
    net = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))
    net.add_module("global_avg_pool", GlobalAvgPool2d())  # output of GlobalAvgPool2d: (Batch, 256, 1, 1)
    net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(256, 10)))
    # Sanity-check the network with a random input
    X = torch.rand((1, 1, 28, 28))
    for name, layer in net.named_children():
        X = layer(X)
        print(name, ' output shape:\t', X.shape)
    # Training
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lr, num_epochs = 0.001, 100
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    batch_size = 1000
    net = net.to(device)
    print("training on ", device)
    loss = torch.nn.CrossEntropyLoss()
    loop_times = train_images.shape[0] // batch_size
    train_acc_plot = []
    test_acc_plot = []
    loss_plot = []
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        # loop_times + 1 so that the last batch is not skipped
        for i in tqdm(range(1, loop_times + 1)):
            x = train_images[(i - 1) * batch_size:i * batch_size]
            y = train_labels[(i - 1) * batch_size:i * batch_size]
            x = torch.unsqueeze(x, 1)  # add the channel dimension
            X = x.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_images, test_labels, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
        # The logs folder must already exist
        torch.save(net.state_dict(), 'logs/Epoch%d-Loss%.4f-train_acc%.4f-test_acc%.4f.pth' % (
            (epoch + 1), train_l_sum / batch_count, train_acc_sum / n, test_acc))
        print("save successfully")
        test_acc_plot.append(test_acc)
        train_acc_plot.append(train_acc_sum / n)
        loss_plot.append(train_l_sum / batch_count)
    # Plot test accuracy (red), training accuracy (green) and loss (blue) per epoch
    x = range(num_epochs)
    plt.plot(x, test_acc_plot, 'r')
    plt.plot(x, train_acc_plot, 'g')
    plt.plot(x, loss_plot, 'b')
    plt.show()
    print("*" * 50)
main_pthoto.py
# author:Hurricane
# date:  2020/11/6
# E-mail:[email protected]
import cv2 as cv
import numpy as np
import os
from Pre_treatment import get_number as g_n
import predict as pt
from time import time
from Pre_treatment import softmax

net = pt.get_net()
orig_path = r"F:\PyCharm\Practice\hand_wrtten\real_img_resize"
img_list = os.listdir(orig_path)
# img_path = r'F:\PyCharm\Practice\hand_wrtten\real_img\7.jpg'
for img_name in img_list:
    since = time()
    img_path = os.path.join(orig_path, img_name)
    img = cv.imread(img_path)
    img_bw = g_n(img)
    img_bw_c = img_bw.sum(axis=1) / 255  # white-pixel count of each row
    img_bw_r = img_bw.sum(axis=0) / 255  # white-pixel count of each column
    # Keep the rows and columns that contain at least 5 white pixels
    r_ind, c_ind = [], []
    for k, r in enumerate(img_bw_r):
        if r >= 5:
            r_ind.append(k)
    for k, c in enumerate(img_bw_c):
        if c >= 5:
            c_ind.append(k)
    img_bw_sg = img_bw[c_ind[0]:c_ind[-1], r_ind[0]:r_ind[-1]]
    leng_c = len(c_ind)
    leng_r = len(r_ind)
    side_len = leng_c + 20
    # Never pass a negative border width when the digit is wider than it is tall
    add_r = max(int((side_len - leng_r) / 2), 0)
    img_bw_sg_bord = cv.copyMakeBorder(img_bw_sg, 10, 10, add_r, add_r, cv.BORDER_CONSTANT, value=[0, 0, 0])
    # Show the intermediate images
    cv.imshow("img", img_bw)
    cv.imshow("img_sg", img_bw_sg_bord)
    c = cv.waitKey(1) & 0xff
    img_in = cv.resize(img_bw_sg_bord, (28, 28))
    result_org = pt.predict(img_in, net)
    result = softmax(result_org)
    best_result = result.argmax(dim=1).item()
    best_result_num = result.max().item()
    # Reject predictions whose probability is not above 0.5
    if best_result_num <= 0.5:
        best_result = None
    # Show the result
    img_show = cv.resize(img, (600, 600))
    end_predict = time()
    fps = np.ceil(1 / (end_predict - since))
    font = cv.FONT_HERSHEY_SIMPLEX
    cv.putText(img_show, "The number is:" + str(best_result), (1, 30), font, 1, (0, 0, 255), 2)
    cv.putText(img_show, "Probability is:" + str(best_result_num), (1, 60), font, 1, (0, 255, 0), 2)
    cv.putText(img_show, "FPS:" + str(fps), (1, 90), font, 1, (255, 0, 0), 2)
    cv.imshow("result", img_show)
    cv.waitKey(1)
    print(result)
    print("*" * 50)
    print("The number is:", best_result)
main.py
# author:Hurricane
# date:  2020/11/6
# E-mail:[email protected]
import cv2 as cv
import numpy as np
import os
from Pre_treatment import get_number as g_n
from Pre_treatment import get_roi
import predict as pt
from time import time
from Pre_treatment import softmax

# Real-time detection on the camera stream
capture = cv.VideoCapture(0, cv.CAP_DSHOW)
capture.set(cv.CAP_PROP_FRAME_WIDTH, 1920)
capture.set(cv.CAP_PROP_FRAME_HEIGHT, 1080)
net = pt.get_net()
# img_path = r'F:\PyCharm\Practice\hand_wrtten\real_img\7.jpg'
while True:
    ret, frame = capture.read()
    since = time()
    if ret:
        # frame = cv.imread(img_path)
        img_bw = g_n(frame)
        img_bw_sg = get_roi(img_bw)
        # Show the cropped binary image
        cv.imshow("img", img_bw_sg)
        c = cv.waitKey(1) & 0xff
        if c == 27:  # Esc quits
            break
        img_in = cv.resize(img_bw_sg, (28, 28))
        result_org = pt.predict(img_in, net)
        result = softmax(result_org)
        best_result = result.argmax(dim=1).item()
        best_result_num = result.max().item()
        # Reject predictions whose probability is not above 0.5
        if best_result_num <= 0.5:
            best_result = None
        # Show the result
        img_show = cv.resize(frame, (600, 600))
        end_predict = time()
        fps = round(1 / (end_predict - since))
        font = cv.FONT_HERSHEY_SIMPLEX
        cv.putText(img_show, "The number is:" + str(best_result), (1, 30), font, 1, (0, 0, 255), 2)
        cv.putText(img_show, "Probability is:" + str(best_result_num), (1, 60), font, 1, (0, 255, 0), 2)
        cv.putText(img_show, "FPS:" + str(fps), (1, 90), font, 1, (255, 0, 0), 2)
        cv.imshow("result", img_show)
        cv.waitKey(1)
        print(result)
        print("*" * 50)
        print("The number is:", best_result)
    else:
        print("please check camera!")
        break
# Free the camera and close the windows on every exit path
capture.release()
cv.destroyAllWindows()
Pre_treatment.py
# author:Hurricane
# date:  2020/11/6
# E-mail:[email protected]
import cv2 as cv
import numpy as np
import os


def get_number(img):
    # OpenCV images (cv.imread, VideoCapture) are in BGR order
    img_gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
    img_gray_resize = cv.resize(img_gray, (600, 600))
    # Keep only bright pixels (the writing)
    ret, img_bw = cv.threshold(img_gray_resize, 200, 255, cv.THRESH_BINARY)
    kernel = np.ones((3, 3), np.uint8)
    # img_open = cv.morphologyEx(img_bw,cv.MORPH_CLOSE,kernel)
    # Thicken the strokes
    img_open = cv.dilate(img_bw, kernel, iterations=2)
    num_labels, labels, stats, centroids = \
        cv.connectedComponentsWithStats(img_open, connectivity=8, ltype=None)
    # Paint small connected components (area < 1000 pixels) black to remove noise
    for x, y, w, h, area in stats:
        if area < 1000:
            cv.rectangle(img_open, (int(x), int(y)), (int(x + w), int(y + h)), 0, thickness=-1)
    return img_open


def get_roi(img_bw):
    img_bw_c = img_bw.sum(axis=1) / 255  # white-pixel count of each row
    img_bw_r = img_bw.sum(axis=0) / 255  # white-pixel count of each column
    # Keep the rows and columns that contain at least 5 white pixels
    r_ind = [k for k, r in enumerate(img_bw_r) if r >= 5]
    c_ind = [k for k, c in enumerate(img_bw_c) if c >= 5]
    if not r_ind or not c_ind:
        # No digit found: return the image unchanged
        return img_bw
    # Crop to the digit (end indices inclusive)
    img_bw_sg = img_bw[c_ind[0]:c_ind[-1] + 1, r_ind[0]:r_ind[-1] + 1]
    leng_c = len(c_ind)
    leng_r = len(r_ind)
    side_len = max(leng_c, leng_r) + 20
    # Pad both directions up to side_len: the longer side gets a 10-pixel margin,
    # the shorter side gets more so that the result is roughly square
    add_c = int((side_len - leng_c) / 2)
    add_r = int((side_len - leng_r) / 2)
    img_bw_sg_bord = cv.copyMakeBorder(img_bw_sg, add_c, add_c, add_r, add_r, cv.BORDER_CONSTANT, value=[0, 0, 0])
    return img_bw_sg_bord


def softmax(X):
    # Row-wise softmax for a tensor of shape (batch, classes)
    X_exp = X.exp()
    partition = X_exp.sum(dim=1, keepdim=True)
    return X_exp / partition
predict.py
# author:Hurricane
# date:  2020/11/5
# E-mail:[email protected]
# -------------------------------------#
#       Predict a single image
# -------------------------------------#
import numpy as np
import struct
import matplotlib.pyplot as plt
import cv2 as cv
import random
import torch
from torch import nn, optim
import torch.nn.functional as F


class Residual(nn.Module):  # This class is also available in the d2lzh_pytorch package
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3 is not None:
            X = self.conv3(X)
        return F.relu(Y + X)


class GlobalAvgPool2d(nn.Module):
    # Global average pooling: the pooling window is set to the input height and width
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])


def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    # num_residuals: number of residual units in this block
    if first_block:
        assert in_channels == out_channels  # the first block keeps the number of channels
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)


class FlattenLayer(torch.nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()

    def forward(self, x):  # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)


def get_net():
    # Build the network
    # ResNet model (must match the architecture used for training)
    model_path = r"F:\PyCharm\Practice\hand_wrtten\logs\Epoch100-Loss0.0000-train_acc1.0000-test_acc0.9930.pth"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))
    net.add_module("global_avg_pool", GlobalAvgPool2d())  # output of GlobalAvgPool2d: (Batch, 256, 1, 1)
    net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(256, 10)))
    # Sanity-check the network
    # X = torch.rand((1, 1, 28, 28))
    # for name, layer in net.named_children():
    #     X = layer(X)
    #     print(name, ' output shape:\t', X.shape)
    # Load the trained weights
    print("Load weight into state dict...")
    stat_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(stat_dict)
    net.to(device)
    net.eval()
    print("Load finish!")
    return net


def predict(img, net):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    img_in = torch.from_numpy(img)
    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W)
    img_in = torch.unsqueeze(img_in, 0)
    img_in = torch.unsqueeze(img_in, 0).to(device)
    img_in = img_in.float()
    # Inference only, so no gradients are needed
    with torch.no_grad():
        result_org = net(img_in)
    return result_org
