目標(biāo)檢測 | Anchor free之CenterNet深度解析
點(diǎn)擊上方“小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)
?
本文接著上一講對CornerNet的網(wǎng)絡(luò)結(jié)構(gòu)和損失函數(shù)的解析,鏈接如下
https://zhuanlan.zhihu.com/p/188587434https://zhuanlan.zhihu.com/p/195517472
本文來聊一聊Anchor-Free領(lǐng)域耳熟能詳?shù)腃enterNet。
原論文名為《Objects as Points》,有沒有覺得這種簡單的名字特別霸氣,比什么"基于xxxx的xxxx的xxxx論文"帥氣多了哈。
雖然這名字夠短,但是內(nèi)容卻非常充實(shí)。將物體看成點(diǎn)進(jìn)行檢測,那么應(yīng)用主要有以下三點(diǎn)
(1)物體檢測 (2)3D定位 (3)人體姿態(tài)估計(jì)
本文的代碼看的是基于keras版本的。鏈接如下
https://github.com/see--/keras-centernet
?
顧名思義,CornerNet以檢測框的兩個(gè)角點(diǎn)為基礎(chǔ)進(jìn)行物體的檢測,而CenterNet以檢測框的中心為基礎(chǔ)進(jìn)行物體位置的檢測.
CenterNet和CornerNet的網(wǎng)絡(luò)結(jié)構(gòu)類似,如下為CornerNet的網(wǎng)絡(luò)結(jié)構(gòu)。

由于CornerNet需要進(jìn)行兩個(gè)關(guān)鍵點(diǎn)檢測(左上角點(diǎn)和右下角點(diǎn))來判斷物體的位置,所以共有兩個(gè)大分支(每個(gè)大分支中又包含了三個(gè)小分支)。
而 CenterNet只需要進(jìn)行一個(gè)關(guān)鍵點(diǎn)的檢測(中心點(diǎn)的檢測)來判斷物體的位置,所以只有一個(gè)大的分支,該分支包含了三個(gè)小分支(雖然這三個(gè)小分支和CornerNet的還是有區(qū)別的)?;贖ourglass backbone的CenterNet結(jié)構(gòu)如下圖所示

該網(wǎng)絡(luò)要比CornerNet更簡單,而且細(xì)心的小伙伴們應(yīng)該也發(fā)現(xiàn)了和CornerNet分支輸出存在一定的異同之處,該網(wǎng)絡(luò)輸出分支分別為
(1)HeatMap,大小為(W/4,H/4,80),輸出不同類別(80個(gè)類別)物體中心點(diǎn)的位置 (2) Offset,大小為(W/4,H/4,2),對HeatMap的輸出進(jìn)行精煉,提高定位準(zhǔn)確度 (3) Height&Width,大小為(W/4,H/4,2),預(yù)測以關(guān)鍵點(diǎn)為中心的檢測框的寬高
顯然,(1)(2)在CornerNet中也出現(xiàn)過,但是Corner的另一個(gè)分支是輸出每個(gè)被檢測角點(diǎn)的embedding,即左上點(diǎn)的embedding和右上點(diǎn)的embedding距離足夠近,則被認(rèn)定為同一檢測框的角點(diǎn)對。
另外在CornerNet中還有一個(gè)創(chuàng)新點(diǎn),為Corner Pooling的提出,在CenterNet中被剔除了。
那么結(jié)合CenterNet的結(jié)構(gòu)圖

可以將其分為以下幾個(gè)部分
(1)pre,通過一個(gè)步長為2的7x7卷積和步長為2的殘差單元,將圖片寬高壓縮為原來的1/4 (2)Hourglass Module 1,第一個(gè)沙漏型的卷積神經(jīng)網(wǎng)絡(luò)模塊 (3)joint,連接Hourglass Module 2和Hourglass Module 2 (4)Hourglass Module 2,第二個(gè)沙漏型的卷積神經(jīng)網(wǎng)絡(luò)模塊 (5)Head,輸出三個(gè)分支輸出
具體代碼實(shí)現(xiàn)為
def HourglassNetwork(heads, num_stacks, cnv_dim=256, inres=(512, 512), weights='ctdet_coco',dims=[256, 384, 384, 384, 512]):"""Instantiates the Hourglass architecture.Optionally loads weights pre-trained on COCO.Note that the data format convention used by the model isthe one specified in your Keras config at `~/.keras/keras.json`.# Argumentsnum_stacks: number of hourglass modules.cnv_dim: number of filters after the resolution is decreased.inres: network input shape, should be a multiple of 128.weights: one of `None` (random initialization),'ctdet_coco' (pre-training on COCO for 2D object detection),'hpdet_coco' (pre-training on COCO for human pose detection),or the path to the weights file to be loaded.dims: numbers of channels in the hourglass blocks.# ReturnsA Keras model instance.# RaisesValueError: in case of invalid argument for `weights`,or invalid input shape."""if not (weights in {'ctdet_coco', 'hpdet_coco', None} or os.path.exists(weights)):raise ValueError('The `weights` argument should be either ''`None` (random initialization), `ctdet_coco` ''(pre-trained on COCO), `hpdet_coco` (pre-trained on COCO) ''or the path to the weights file to be loaded.')input_layer = Input(shape=(inres[0], inres[1], 3), name='HGInput')inter = pre(input_layer, cnv_dim)prev_inter = Noneoutputs = []for i in range(num_stacks):prev_inter = inter_heads, inter = hourglass_module(heads, inter, cnv_dim, i, dims) # return the heads that include three branchsoutputs.extend(_heads)if i < num_stacks - 1:# the joint between the first hourglass module and the second onesinter_ = Conv2D(cnv_dim, 1, use_bias=False, name='inter_.%d.0' % i)(prev_inter)inter_ = BatchNormalization(epsilon=1e-5, name='inter_.%d.1' % i)(inter_)cnv_ = Conv2D(cnv_dim, 1, use_bias=False, name='cnv_.%d.0' % i)(inter)cnv_ = BatchNormalization(epsilon=1e-5, name='cnv_.%d.1' % i)(cnv_)inter = Add(name='inters.%d.inters.add' % i)([inter_, cnv_])inter = Activation('relu', name='inters.%d.inters.relu' % i)(inter)inter = residual(inter, cnv_dim, 'inters.%d' % i)model = Model(inputs=input_layer, outputs=outputs)if weights == 'ctdet_coco':weights_path = get_file('%s_hg.hdf5' % weights,CTDET_COCO_WEIGHTS_PATH,cache_subdir='models',file_hash='ce01e92f75b533e3ff8e396c76d55d97ff3ec27e99b1bdac1d7b0d6dcf5d90eb')model.load_weights(weights_path)elif weights == 'hpdet_coco':weights_path = get_file('%s_hg.hdf5' % weights,HPDET_COCO_WEIGHTS_PATH,cache_subdir='models',file_hash='5c562ee22dc383080629dae975f269d62de3a41da6fd0c821085fbee183d555d')model.load_weights(weights_path)elif weights is not None:model.load_weights(weights)return model
有關(guān)注釋都在上面了,具體定義請結(jié)合源代碼進(jìn)行查看。
前面我們已經(jīng)知道了CenterNet網(wǎng)絡(luò)有三個(gè)輸出,分別為
(1) HeatMap,大小為(W/4,H/4,80),輸出不同類別(80個(gè)類別)物體中心點(diǎn)的位置 (2) Offset,大小為(W/4,H/4,2),對HeatMap的輸出進(jìn)行精煉,提高定位準(zhǔn)確度 (3) Height&Width,大小為(W/4,H/4,2),預(yù)測以關(guān)鍵點(diǎn)為中心的檢測框的寬高
那么如何將這些輸出轉(zhuǎn)為直觀的檢測框信息呢?
在目標(biāo)檢測領(lǐng)域,通常將這一過程稱為decode,就是根據(jù)網(wǎng)絡(luò)的輸出獲取直觀的檢測框信息。
那么encode就是將檢測框信息(通常為ground-truth bounding box的坐標(biāo)、寬高信息)轉(zhuǎn)化為形為網(wǎng)絡(luò)輸出的信息,便于網(wǎng)絡(luò)損失函數(shù)的求解。
代碼中實(shí)現(xiàn)decode這一過程的代碼如下
def _ctdet_decode(hm, reg, wh, k=100, output_stride=4):"""將網(wǎng)絡(luò)的輸出轉(zhuǎn)換為標(biāo)準(zhǔn)的檢測框信息"""hm = K.sigmoid(hm)hm = _nms(hm)hm_shape = K.shape(hm)reg_shape = K.shape(reg)wh_shape = K.shape(wh)# cat為通道數(shù)batch, width, cat = hm_shape[0], hm_shape[2], hm_shape[3]# 對輸出的特征圖進(jìn)行鋪平hm_flat = K.reshape(hm, (batch, -1))reg_flat = K.reshape(reg, (reg_shape[0], -1, reg_shape[-1]))wh_flat = K.reshape(wh, (wh_shape[0], -1, wh_shape[-1]))def _process_sample(args):_hm, _reg, _wh = args_scores, _inds = tf.math.top_k(_hm, k=k, sorted=True) # 尋找前k個(gè)heatmap的值_classes = K.cast(_inds % cat, 'float32') #獲取索引對應(yīng)的類別_inds = K.cast(_inds / cat, 'int32') #在某一類別中的位置(最大長度為 width*width),一維的# 一維位置轉(zhuǎn)二維坐標(biāo)_xs = K.cast(_inds % width, 'float32') #二維坐標(biāo)中的橫坐標(biāo)_ys = K.cast(K.cast(_inds / width, 'int32'), 'float32') #二維坐標(biāo)的縱坐標(biāo)_wh = K.gather(_wh, _inds) #根據(jù)索引獲得寬高數(shù)據(jù)_reg = K.gather(_reg, _inds) #根據(jù)坐標(biāo)獲得offset_xs = _xs + _reg[..., 0]_ys = _ys + _reg[..., 1]_x1 = _xs - _wh[..., 0] / 2_y1 = _ys - _wh[..., 1] / 2_x2 = _xs + _wh[..., 0] / 2_y2 = _ys + _wh[..., 1] / 2# rescale to image coordinates_x1 = output_stride * _x1_y1 = output_stride * _y1_x2 = output_stride * _x2_y2 = output_stride * _y2_detection = K.stack([_x1, _y1, _x2, _y2, _scores, _classes], -1)return _detectiondetections = K.map_fn(_process_sample, [hm_flat, reg_flat, wh_flat], dtype=K.floatx())return detections
主要通過非極大值抑制(NMS)后在heatmap上尋找topk個(gè)最大值,即可能為物體中心的索引。然后根據(jù)這topk個(gè)中心點(diǎn),尋找其對應(yīng)的類別、寬高和offset信息。
這里的NMS并不像Anchor-free中的NMS(即利用檢測框的IOU為距離基準(zhǔn)求解極大值,抑制非極大值)。
而CenterNet的NMS,是尋找某點(diǎn)與其周圍的八個(gè)點(diǎn)之間最大值,作為其NMS的極大值。那么該操作可以使用最簡單的3x3的MaxPooling實(shí)現(xiàn)。
實(shí)現(xiàn)代碼如下:
def _nms(heat, kernel=3):hmax = K.pool2d(heat, (kernel, kernel), padding='same', pool_mode='max')keep = K.cast(K.equal(hmax, heat), K.floatx())return heat * keep
貌似該keras代碼中,并沒有實(shí)現(xiàn)訓(xùn)練CenterNet的過程。所以我們沒辦法結(jié)合代碼進(jìn)行訓(xùn)練過程的解析,包括
(1)損失函數(shù)設(shè)定 (2)將ground-truth bounding box信息映射為類似網(wǎng)絡(luò)輸出的格式,被稱為encode。
那么下面直接結(jié)合論文進(jìn)行損失函數(shù)與encode的解析。
?
前面提到過Encode的過程是將ground-truth bounding box信息映射為類似網(wǎng)絡(luò)輸出的格式。這樣可以加速求解損失函數(shù)的計(jì)算。
我們知道在CornerNet中將檢測框的左上角點(diǎn)和右下角點(diǎn)映射到heatmap上的過程,并不是簡單的一一對應(yīng)關(guān)系的(也就是將原圖中的某關(guān)鍵點(diǎn)映射到heatmap中的某一關(guān)鍵點(diǎn)中),而是將原圖中的某關(guān)鍵點(diǎn)(在CenterNet中為檢測框的中點(diǎn))映射到heatmap中的某一高斯核區(qū)域內(nèi)。如下圖4所示,為每個(gè)檢測框中心點(diǎn)的高斯核區(qū)域顯示。


又或者借用
https://zhuanlan.zhihu.com/p/66048276
中的圖,為某一中心點(diǎn)在heatmap的映射可視化??梢灾庇^地感受其呈現(xiàn)二維高斯分布。

那么根據(jù)獲得的heatmap,我們可以將ground-truth bbox的偏移信息和寬高信息按照該映射關(guān)系,等同地映射到前面提到的Offset特征圖和Height&Width特征圖中,實(shí)現(xiàn)整個(gè)encode的過程?
實(shí)現(xiàn)了encode過程后,設(shè)定損失函數(shù)就變得非常簡單了。
4.1 focal loss
原論文中令為網(wǎng)絡(luò)輸出的heatmap,為ground_truth信息,即heatmap的標(biāo)簽/監(jiān)督信息。類似CornerNet使用focal loss進(jìn)行損失函數(shù)設(shè)定,實(shí)現(xiàn)過程如下

這里的和為focal loss的超參數(shù),N是圖片中關(guān)鍵點(diǎn)的個(gè)數(shù)。
4.2 offset loss
為了彌補(bǔ)由于stride的原因造成的偏移誤差,論文中設(shè)定了一個(gè)關(guān)于偏移的損失函數(shù),使得訓(xùn)練后的網(wǎng)絡(luò)能夠有效計(jì)算offset值,從而修正檢測框的位置。
不妨這里引用一下論文中的offset loss公式。

這里的p是檢測框中心點(diǎn)(原圖中)的真實(shí)坐標(biāo),p/R是理論上該中心點(diǎn)映射到特征圖的準(zhǔn)確位置區(qū)域(很可能是浮點(diǎn)型)。
但是我們知道在特征圖中,所有的點(diǎn)的位置都是整型的(即不存在某一個(gè)點(diǎn)的位置為(1.1,2.9)的),所以實(shí)際上,原圖中坐標(biāo)為p的點(diǎn)映射到特征圖后的位置應(yīng)該是

是p向下取整的結(jié)果,所以這里就造成了誤差了,那么這個(gè)誤差就是

公式中的是網(wǎng)絡(luò)的offset輸出特征圖。那么這個(gè)指的是關(guān)鍵點(diǎn)實(shí)際落入的區(qū)域。說明該offset loss只關(guān)注在關(guān)鍵點(diǎn)區(qū)域的offset輸出。
4.3 height&width loss
用來訓(xùn)練物體寬高大小的損失函數(shù)就非常簡單了。假設(shè)物體k的ground-truth坐標(biāo)為

那么他的寬高為

如果只考慮關(guān)鍵點(diǎn)實(shí)際落入的區(qū)域的輸出特征圖,也就是。該損失函數(shù)設(shè)定為

4.4 總損失
最后總損失函數(shù)為上面三個(gè)損失函數(shù)之和

總的來說,CenterNet要比CornerNet學(xué)起來更加簡單點(diǎn),而且比CornerNet更實(shí)用,應(yīng)用范圍也更廣!
該模型在Anchor-free目標(biāo)檢測領(lǐng)域和YOLO V3在Anchor-based目標(biāo)檢測領(lǐng)域的地位類似,非常推薦大家讀一下原文!有關(guān)其在3D location和姿態(tài)估計(jì)等任務(wù)的應(yīng)用,大家感興趣可以自行學(xué)習(xí)。
好消息!?
小白學(xué)視覺知識星球
開始面向外開放啦??????
下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程 在「小白學(xué)視覺」公眾號后臺回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺、目標(biāo)跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。 下載2:Python視覺實(shí)戰(zhàn)項(xiàng)目52講 在「小白學(xué)視覺」公眾號后臺回復(fù):Python視覺實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計(jì)數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個(gè)視覺實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺。 下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講 在「小白學(xué)視覺」公眾號后臺回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。 交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測、分割、識別、醫(yī)學(xué)影像、GAN、算法競賽等微信群(以后會逐漸細(xì)分),請掃描下面微信號加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三?+?上海交大?+?視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~

