Tensorflow + OpenCV4 安全帽檢測模型訓(xùn)練與推理
點(diǎn)擊上方“小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時間送達(dá)
軟件版本信息:
Windows10 64位Tensorflow1.15Tensorflow object detection API 1.xPython3.6.5VS2015 VC++CUDA10.0
硬件:
CPUi7GPU 1050ti
如何安裝tensorflow object detection API框架,看這里:
Tensorflow Object Detection API 終于支持tensorflow1.x與tensorflow2.x了
首先需要下載數(shù)據(jù)集,下載地址為:
https://pan.baidu.com/s/1UbFkGm4EppdAU660Vu7SdQ總計(jì)7581張圖像,基于Pascal VOC2012完成標(biāo)注。分為兩個類別,分別是安全帽與人(hat與person),json格式如下:
item {id: 1name: 'hat'}item {id: 2name: 'person'}

數(shù)據(jù)集下載之后,并不能被tensorflow object detection API框架中的腳本轉(zhuǎn)換為tfrecord,主要是有幾個XML跟JPEG圖像格式錯誤,本人經(jīng)過一番磨難之后把它們?nèi)啃拚恕P拚蟮臄?shù)據(jù)運(yùn)行下面兩個腳本即可生成訓(xùn)練集與驗(yàn)證集的tfrecord數(shù)據(jù),命令行如下:


這里需要注意的是create_pascal_tf_record.py 腳本的165行把
'aeroplane_' + FLAGS.set + '.txt')修改為:
FLAGS.set + '.txt')原因是這里的數(shù)據(jù)集沒有做分類train/val。所以需要修改一下,修改完成之后保存。運(yùn)行上述的命令行,就可以正確生成tfrecord,否則會遇到錯誤。
基于faster_rcnn_inception_v2_coco對象檢測模型實(shí)現(xiàn)遷移學(xué)習(xí),首先需要配置遷移學(xué)習(xí)的config文件,對應(yīng)的配置文件可以從:
research\object_detection\samples\configs中發(fā)現(xiàn),發(fā)現(xiàn)文件:
faster_rcnn_inception_v2_coco.config之后,修改配置文件的中相關(guān)部分,關(guān)于如何修改,修改什么,可以看這里:



修完完成之后,在D盤下新建好幾個目錄之后,執(zhí)行下面的命令行參數(shù):

就會開始訓(xùn)練,總計(jì)訓(xùn)練40000 step。訓(xùn)練過程中可以通過tensorboard查看訓(xùn)練結(jié)果:

模型導(dǎo)出
完成了40000 step訓(xùn)練之后,就可以看到對應(yīng)的檢查點(diǎn)文件,借助tensorflow object detection API框架提供的模型導(dǎo)出腳本,可以把檢查點(diǎn)文件導(dǎo)出為凍結(jié)圖格式的PB文件。相關(guān)的命令行參數(shù)如下:

得到pb文件之后,使用OpenCV4.x中的tf_text_graph_faster_rcnn.py腳本,轉(zhuǎn)換生成graph.pbtxt配置文件。最終得到:
- frozen_inference_graph.pb- frozen_inference_graph.pbtxt
如何導(dǎo)出PB模型到OpenCV DNN支持看這里:
干貨 | tensorflow模型導(dǎo)出與OpenCV DNN中使用
在OpenCV DNN中直接調(diào)用訓(xùn)練出來的模型完成自定義對象檢測,這里需要特別說明一下的,因?yàn)樵谟?xùn)練階段我們選擇了模型支持600~1024保持比率的圖像輸入。所以在推理預(yù)測階段,我們可以直接使用輸入圖像的真實(shí)大小,模型的輸出格式依然是1x1xNx7,按照格式解析即可得到預(yù)測框與對應(yīng)的類別。最終的代碼實(shí)現(xiàn)如下:
1import?cv2?as?cv
2
3labels?=?['hat',?'person']
4model?=?"D:/safehat_train/models/train/frozen_inference_graph.pb"
5config?=?"D:/safehat_train/models/train/frozen_inference_graph.pbtxt"
6
7#?讀取測試圖像
8image?=?cv.imread("D:/123.jpg")
9h,?w?=?image.shape[:2]
10cv.imshow("input",?image)
11
12#?加載模型,執(zhí)行推理
13net?=?cv.dnn.readNetFromTensorflow(model,?config)
14blob?=?cv.dnn.blobFromImage(cv.resize(image,?(w,?h)),?swapRB=True,?crop=False)
15net.setInput(blob)
16detectOut?=?net.forward()
17
18#?解析輸出
19classIds?=?[]
20confidences?=?[]
21boxes?=?[]
22for?detection?in?detectOut[0,0,:,:]:
23????score?=?detection[2]
24????if?score?>?0.4:
25????????left?=?detection[3]*w
26????????top?=?detection[4]*h
27????????right?=?detection[5]*w
28????????bottom?=?detection[6]*h
29????????classId?=?int(detection[1])?+?1
30????????classIds.append(classId)
31????????boxes.append([int(left),?int(top),?int(right),?int(bottom)])
32????????confidences.append(float(score))
33
34#?非最大抑制
35nms_indices?=?cv.dnn.NMSBoxes(boxes,?confidences,?0.4,?0.4)
36for?i?in?range(len(nms_indices)):
37????index?=?nms_indices[i][0]
38????box?=?boxes[index]
39????cid?=?classIds[index]
40????if?cid?==?1:
41????????cv.rectangle(image,?(box[0],?box[1]),?(box[2],?box[3]),?(140,?199,?0),?4,?8,?0)
42????else:
43????????cv.rectangle(image,?(box[0],?box[1]),?(box[2],?box[3]),?(255,?0,?255),?4,?8,?0)
44????cv.putText(image,?labels[cid-1],?(box[0],?box[1]),?cv.FONT_HERSHEY_SIMPLEX,?0.75,?(255,?0,?0),?2)
45
46#?顯示輸出
47cv.imshow("safetyhat-detection-demo",?image)
48cv.imwrite("D:/result123.png",?image)
49cv.waitKey(0)
50cv.destroyAllWindows()一些測試圖像的運(yùn)行結(jié)果如下:





可以看到第二張途中有誤識別情況發(fā)生!可見模型還可以繼續(xù)訓(xùn)練!
避坑指南:
1. 下載的公開數(shù)據(jù)集,記得用opencv重新讀取一遍,然后resave為jpg格式,這個會避免在生成tfrecord時候的圖像格式數(shù)據(jù)錯誤。
ValueError: Image format not JPEG
2. 公開數(shù)據(jù)集中xml文件的filename有跟真實(shí)圖像文件名稱不一致的情況,要程序處理一下。不然會遇到
Windows fatal exception: access violation error?
3. 使用非最大抑制之后,
SystemError:
參考資料:
使用OpenCV 4.1.2的DNN模塊部署深度學(xué)習(xí)模型
https://github.com/njvisionpower/Safety-Helmet-Wearing-Dataset
https://github.com/opencv/opencv/wiki/Deep-Learning-in-OpenCV
https://github.com/tensorflow/models/tree/master/research/object_detection
交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計(jì)算攝影、檢測、分割、識別、醫(yī)學(xué)影像、GAN、算法競賽等微信群(以后會逐漸細(xì)分),請掃描下面微信號加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三?+?上海交大?+?視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~

