使用pytorch mask-rcnn進(jìn)行目標(biāo)檢測/分割訓(xùn)練
點(diǎn)擊上方“小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時間送達(dá)
現(xiàn)在github上面有3個版本的mask-rcnn, keras, caffe(Detectron), pytorch,這幾個版本中,據(jù)說pytorch是性能最佳的一個,于是就開始使用它進(jìn)行訓(xùn)練,然而實(shí)際跑通的過程中也遇到了不少問題,記錄一下。
官方源代碼: https://github.com/facebookresearch/maskrcnn-benchmark
安裝
參照 https://github.com/facebookresearch/maskrcnn-benchmark作者給的說明進(jìn)行安裝。需要注意兩個點(diǎn):
gcc >= 4.9,否則會出現(xiàn)吐核的錯誤。具體安裝方法寫在下面吐核的內(nèi)容里了。
pytorch==1.0, 安裝0.4.0等版本均會報錯
作者說是因?yàn)間cc版本過低引起的,嘗試了很多更新gcc的方法,都有各種問題,最后通過這位小哥的方法成功更新:
https://link.zhihu.com/?target=https%3A//gist.github.com/craigminihan/b23c06afd9073ec32e0c
升級完gcc(>=4.9.0)之后呢, 可能會出現(xiàn)類似 /usr/lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found 的報錯,主要是升級gcc生成的動態(tài)庫沒有替換老版本gcc的動態(tài)庫。參考方法可見:
https://blog.csdn.net/xg123321123/article/details/78117162
在自己的數(shù)據(jù)上訓(xùn)練
數(shù)據(jù)集組織:參見COCO的數(shù)據(jù)集格式,你可以使用COCO數(shù)據(jù)集或者將自己的數(shù)據(jù)集轉(zhuǎn)為COCO進(jìn)行訓(xùn)練。當(dāng)然也可以自己改寫Dataset類來加載數(shù)據(jù)。
我是通過
Pascal
提供的https://github.com/pascal1129/kaggle_airbus_ship_detection/tree/master/0_rle_to_coco將數(shù)據(jù)集轉(zhuǎn)換為COCO格式的json annotation格式的。
在分配好你的訓(xùn)練集、驗(yàn)證集和測試集后,并獲取了對應(yīng)的annotation文件后,通過修改/maskrcnn-benchmark/maskrcnn-benchmark/config/paths_catalog.py這個文件的DatasetCatalog類來修改目錄。
class DatasetCatalog(object):
DATA_DIR = "datasets"
DATASETS = {
"coco_2017_train": {
"img_dir": "coco/train2017",
"ann_file": "coco/annotations/instances_train2017.json"
},
"coco_2017_val": {
"img_dir": "coco/val2017",
"ann_file": "coco/annotations/instances_val2017.json"
}
}同時在/maskrcnn-benchmark/configs/下的你選用的配置文件yaml修改DATASETS參數(shù),注意這里不是直接目錄的地方,而是使用前面的DatasetCatalog類中的DATASETS的鍵值作為索引:
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_val")
TEST: ("coco_2014_val",)準(zhǔn)備好數(shù)據(jù)集之后,官方提供的默認(rèn)類別是81,而你的數(shù)據(jù)集可能只有1個類別,所以需要在/maskrcnn-benchmark/maskrcnn_benchmark/config/defaults.py中修改C.MODEL.ROI_BOX_HEAD.NUM_CLASSES參數(shù)。注意,這個參數(shù)應(yīng)該是類別+1(即background),所以只有一類時應(yīng)該設(shè)置為2
接下來就可以按照官方的traning代碼進(jìn)行單GPU/多GPU訓(xùn)練啦
python /path_to_maskrcnn_benchmark/tools/train_net.py --config-file "/path/to/config/file.yaml"開始訓(xùn)練之后過不了幾個iter就會出現(xiàn)所有的Loss為nan的現(xiàn)象,這是由于學(xué)習(xí)率過大引起的,自己調(diào)小就可以了。另外默認(rèn)的版本是用的是warm up lr,所以開始的幾個epoch可能和你設(shè)定的不一樣,沒關(guān)系~另外,配置參數(shù)有兩個地點(diǎn),一個是yaml文件,另外一個是defaults.py, 有一些相同的參數(shù),yaml的會覆蓋defaults.py的,大家配置的參數(shù)在這兩個文件里找就好了。
可視化
該版本的master分支上還沒有可視化的實(shí)現(xiàn),實(shí)際上可以通過繼承MetricLogger來實(shí)現(xiàn),相關(guān)的內(nèi)容在merge request https://github.com/achalddave/maskrcnn-benchmark/commit/4210b77d4aef69c411200b13c93d7e2fe628164d已經(jīng)實(shí)現(xiàn)了,根據(jù)文中描述修改完代碼后直接運(yùn)行tensorboard命令即可
tensorboard --logdir=path/to/log-directory
Fine-tune on Pre-trained Model
如果你引用了NUM_CLASS與你的數(shù)據(jù)不一致的預(yù)訓(xùn)練模型,就會出現(xiàn)類似
size mismatch for roi_heads.mask.predictor.mask_fcn_logits.weight: copying a param with shape torch.Size([81, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([2, 256, 1, 1])
的報錯。這是因?yàn)閘ogitis層的class類別不一致導(dǎo)致的。可以通過刪除預(yù)訓(xùn)練中包含logits層的參數(shù)來解決沖突。使用gist.github.com/wangg12 中提供的腳本對下載的比如說Detectron的預(yù)訓(xùn)練模型進(jìn)行轉(zhuǎn)化,再在yaml文件中將WEIGHT參數(shù)改為預(yù)訓(xùn)練模型pkl路徑即可。
重設(shè)學(xué)習(xí)率
我開始訓(xùn)練的時候遇到一個問題就是改變學(xué)習(xí)率的參數(shù)重新開始訓(xùn)練時,加載的還是上次訓(xùn)練設(shè)置的參數(shù)。這個問題是由于pytorch在加載checkpoint的時候會把之前訓(xùn)練的optimizer和scheduler一起加載進(jìn)來。所以如果要重新設(shè)置學(xué)習(xí)率的話,需要在加載state_dict的時候不啟用上次訓(xùn)練保存的optimizer和scheduler參數(shù)。把maskrcnn_benchmark/utils/checkpoint.py文件中用于load optimizer和scheduler的兩行代碼注掉就可以了:
if "optimizer" in checkpoint and self.optimizer:
self.logger.info("Loading optimizer from {}".format(f))
# self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
if "scheduler" in checkpoint and self.scheduler:
self.logger.info("Loading scheduler from {}".format(f))
# self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
好消息!
小白學(xué)視覺知識星球
開始面向外開放啦??????
下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程 在「小白學(xué)視覺」公眾號后臺回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺、目標(biāo)跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。 下載2:Python視覺實(shí)戰(zhàn)項(xiàng)目52講 在「小白學(xué)視覺」公眾號后臺回復(fù):Python視覺實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個視覺實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計算機(jī)視覺。 下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講 在「小白學(xué)視覺」公眾號后臺回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個基于OpenCV實(shí)現(xiàn)20個實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。 交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學(xué)影像、GAN、算法競賽等微信群(以后會逐漸細(xì)分),請掃描下面微信號加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~

