實(shí)時(shí)實(shí)例分割模型YOLACT
點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師”
設(shè)為星標(biāo),干貨直達(dá)!
AI編輯:我是小將
本文作者:OpenMMLab?@00007
https://zhuanlan.zhihu.com/p/376347955
本文已由原作者授權(quán)轉(zhuǎn)載
0 前言
YOLACT 含義是 You Only Look At CoefficienTs,是一篇非常有創(chuàng)新性的實(shí)時(shí)實(shí)例分割算法。Mask R-CNN 一般被認(rèn)為是實(shí)例分割的 baseline,分割性能是非常不錯(cuò)的,但是其存在的問(wèn)題是速度較慢,且包括 RoIAlign 等層不容易部署,而 YOLACT 的貢獻(xiàn)是沒(méi)有在 Mask R-CNN 基礎(chǔ)上小修小補(bǔ),而是基于 one-stage 全卷積算法重新設(shè)計(jì),雖然在精度上稍低于 Mask R-CNN,但是也滿足大部分需求了,并且速度達(dá)到了實(shí)時(shí),容易部署,廣泛應(yīng)用于各類(lèi)落地場(chǎng)景。
YOLACT 的核心思想是并行預(yù)測(cè)當(dāng)前圖片的原型掩碼(prototype mask) 和每個(gè) bbox 實(shí)例的掩碼系數(shù)(mask coefficients),然后通過(guò)將原型與掩模系數(shù)線性組合來(lái)生成實(shí)例掩碼(instance masks)。由于并行預(yù)測(cè),不需要 two-stage 的 roipool 等操作,可以保持高的輸出分辨率,故分割精度比較高。如下圖所示,假設(shè)所有待檢測(cè)物體的每個(gè)像素點(diǎn)都可以采用長(zhǎng)度為 4 (超參,論文中是 32 )的原型向量表征,則原型掩碼 shape 是 (h, w, 4),相應(yīng)的每個(gè)實(shí)例的掩碼系數(shù) shape 是 (n, 4),n 是檢測(cè)物體個(gè)數(shù),然后利用每個(gè) bbox 預(yù)測(cè)的掩碼系數(shù)向量去加權(quán)原型掩碼即 (h, w, 4) @ (4, ), 從而得到 (h,w) 的 mask 輸出,遍歷每個(gè) bbox 預(yù)測(cè)的掩碼系數(shù)去加權(quán)當(dāng)前圖片的全局原型掩碼,就可以得到每個(gè) bbox 所對(duì)應(yīng)的 mask。

1 算法實(shí)現(xiàn)
和前系列解讀一樣,按照模塊方式進(jìn)行分析。
1.1 Backbone
作者選擇的是標(biāo)準(zhǔn)的 ResNet 系列網(wǎng)絡(luò)
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1, # do not freeze stem
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=False, # update the statistics of bn
zero_init_residual=False,
style='pytorch'), 需要注意的是:由于訓(xùn)練時(shí)長(zhǎng)比較長(zhǎng),作者并沒(méi)有采用常規(guī)的固定某些 stage 權(quán)重的做法,而且 backbone 層全部參與訓(xùn)練。
1.2 Neck
為了加強(qiáng)多層特征圖之間的信息融合和引入多尺度預(yù)測(cè),和 RetinaNet 一樣也采用了 FPN 層
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_input',
num_outs=5,
upsample_cfg=dict(mode='bilinear')), # 上采樣算子為雙線性上采樣1.3 Head
Head 網(wǎng)絡(luò)實(shí)際上包括 3 個(gè) Head
(1) 和 RetinaNet 一致的 bbox 分支,該分支包括 bbox 預(yù)測(cè)和類(lèi)別預(yù)測(cè)分支,以及每個(gè)實(shí)例的掩碼系數(shù)分支

輸出實(shí)際上包括 3個(gè)預(yù)測(cè)分支,為了加速,對(duì) RetinaNet 輸出 Head 進(jìn)行了適當(dāng)修改,主要是
使用了更少的 anchor,每個(gè)位置都是 3 個(gè) anchor
Bbox 分支 和 cls 分支共享卷積
額外多預(yù)測(cè)實(shí)例級(jí)別的掩碼系數(shù)
(2) 原型掩碼預(yù)測(cè)分支

為了似的預(yù)測(cè) mask 具備更多細(xì)節(jié),作者采用比較大的輸出圖,對(duì) FPN 輸出的 P3 特征圖還進(jìn)行了額外的上采樣操作,輸出 mask 大小是 138x138
(3) 語(yǔ)義分割輔助訓(xùn)練分支
為了加速收斂和提高性能,作者還額外引入了全圖的不區(qū)分實(shí)例的語(yǔ)義分割輔助訓(xùn)練分支,該分支在推理階段可以直接刪除,并且足夠簡(jiǎn)單,只有一層卷積層而已。
1.4 訓(xùn)練流程
(1) bbox 分支
bbox_head=dict(
type='YOLACTHead',
num_classes=80, # 類(lèi)別
in_channels=256,
feat_channels=256,
# 和 RetinaNet 一致,只不過(guò) anchor 更少,參數(shù)也重新設(shè)計(jì)了
anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=3,
scales_per_octave=1,
base_sizes=[8, 16, 32, 64, 128],
ratios=[0.5, 1.0, 2.0],
strides=[550.0 / x for x in [69, 35, 18, 9, 5]],
centers=[(550 * 0.5 / x, 550 * 0.5 / x)
for x in [69, 35, 18, 9, 5]]),
# 編解碼過(guò)程和 RetinaNet 一致
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
# ce loss,而沒(méi)有采用 focal loss
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
reduction='none',
loss_weight=1.0),
# bbox loss
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.5),
num_head_convs=1,
num_protos=32,
use_ohem=True), # 默認(rèn) cls loss 還采用了 ohem 策略,克服不平衡問(wèn)題如果不考慮每個(gè)實(shí)例的掩碼系數(shù),那么這個(gè)分支的推理和訓(xùn)練流程和 RetinaNet 完全相同。如果不熟悉,請(qǐng)參考 RetinaNet 算法解讀。
由于每個(gè) bbox 的 mask 系數(shù)沒(méi)有標(biāo)簽,故 bbox head 分支僅僅對(duì) bbox 和 cls 分支計(jì)算 loss,mask 系數(shù)監(jiān)督信息來(lái)自后面的原型掩碼預(yù)測(cè)分支。
(2) 原型掩碼預(yù)測(cè)分支
mask_head=dict(
type='YOLACTProtonet',
in_channels=256,
num_protos=32, # 核心超參
num_classes=80,
# 考慮到特征圖很大,通道很多,為了防止實(shí)例過(guò)多而OOM,強(qiáng)制訓(xùn)練最大 100 個(gè)實(shí)例
max_masks_to_train=100,
loss_mask_weight=6.125), 為了方便大家理解該分支的訓(xùn)練流程,首先需要看下 mmdet/models/detectors/yolact.py 中的整個(gè)訓(xùn)練流
# 實(shí)例級(jí) mask 轉(zhuǎn)為 tensor
gt_masks = [
gt_mask.to_tensor(dtype=torch.uint8, device=img.device)
for gt_mask in gt_masks
]
# 特征提取,包括 FPN,輸出是 5 個(gè)不同尺度的特征圖
x = self.extract_feat(img)
# bbox 分支進(jìn)行 forward
cls_score, bbox_pred, coeff_pred = self.bbox_head(x)
bbox_head_loss_inputs = (cls_score, bbox_pred) + (gt_bboxes, gt_labels,
img_metas)
#bbox 分支計(jì)算 loss,可以看出沒(méi)有傳入 coeff_pred
losses, sampling_results = self.bbox_head.loss(
*bbox_head_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
# 額外的語(yǔ)義分割監(jiān)督層
segm_head_outs = self.segm_head(x[0])
loss_segm = self.segm_head.loss(segm_head_outs, gt_masks, gt_labels)
losses.update(loss_segm)
# 原型掩碼預(yù)測(cè)分支 forward,注意因?yàn)?mask 需要大的輸出圖,故作者直接采用了 P3 層而已,其余層沒(méi)有使用。coeff_pred 和 gt_bboxes 傳入,用于提取實(shí)例級(jí)別信息
mask_pred = self.mask_head(x[0], coeff_pred, gt_bboxes, img_metas,
sampling_results)
# 計(jì)算 loss
loss_mask = self.mask_head.loss(mask_pred, gt_masks, gt_bboxes,
img_metas, sampling_results)
losses.update(loss_mask) 需要特意注意原型掩碼預(yù)測(cè)分支 forward 的輸入?yún)?shù)。需要明白:既然認(rèn)為是實(shí)例分割,由于 target 也是每個(gè) bbox 實(shí)例的 mask,那么該分支就需要想辦法通過(guò) forward 得到實(shí)例的 mask,而原型預(yù)測(cè)分支僅僅輸出全圖的原型掩碼,需要利用預(yù)測(cè)的每個(gè)實(shí)例掩碼系數(shù)來(lái)提取。
對(duì)于原型掩碼 protonet 分支,由于直接的標(biāo)簽,且無(wú)法用預(yù)測(cè) bbox 來(lái)裁剪(前期不穩(wěn)定),但是我們有 實(shí)例 mask 標(biāo)注,故訓(xùn)練時(shí)候輸入的 bbox 是 gt bbox,然后利用 anchor 匹配時(shí)候的匹配規(guī)則即特征圖點(diǎn)上哪些 anchor 負(fù)責(zé)預(yù)測(cè)該 gt bbox (如果不做處理,直接采用所有 gt bbox 去 crop,可能會(huì)存在 bbox 分支和 protonet 分支監(jiān)督點(diǎn)不一致問(wèn)題),基于這些正樣本 anchor,然后采用 gt bbox 去裁剪對(duì)應(yīng)的預(yù)測(cè) mask 圖,就可以得到實(shí)例級(jí)別 mask,后續(xù)算 loss 就是自然的事情了。
原型掩碼 protonet 分支的訓(xùn)練過(guò)程有點(diǎn)點(diǎn)繞,不知道有沒(méi)有說(shuō)清楚,舉個(gè)例子,假設(shè) 5個(gè)輸出層所有特征圖 h,w,k 進(jìn)行拉伸,可以得到 (N,4) 個(gè)預(yù)測(cè)框,也可以得到 (N,32) mask 系數(shù),其中只有部分預(yù)測(cè)框?qū)儆谡龢颖?,因?yàn)闇y(cè)試時(shí)候是基于預(yù)測(cè)框和 prototype 相乘,為了保持訓(xùn)練和推理一致,在訓(xùn)練時(shí)候需要采用 gt bbox 代替,利用前面 anchor 匹配規(guī)則計(jì)算出的正樣本索引來(lái)提取 bbox 相關(guān)的 prototype 信息,此時(shí)就可以得到正樣本索引所對(duì)應(yīng)的 gt bbox 和 預(yù)測(cè) mask 系數(shù)得到 mask 預(yù)測(cè)圖,然后就可以采用 bce 進(jìn)行 mask 訓(xùn)練了。由于 mask 系數(shù)和 protonet 聯(lián)合得到的最終mask,故 mask 系數(shù)分支也得到了監(jiān)督。
prototypes = self.protonet(x)
prototypes = prototypes.permute(0, 2, 3, 1).contiguous()
# idx 表示圖片索引
cur_sampling_results = sampling_results[idx]
# 找出正樣本索引
pos_assigned_gt_inds = \
cur_sampling_results.pos_assigned_gt_inds
# cur_bboxes 是 gt bbox,(M,4)
bboxes_for_cropping = cur_bboxes[pos_assigned_gt_inds].clone()
pos_inds = cur_sampling_results.pos_inds
# cur_coeff_pred 是 正樣本索引所對(duì)應(yīng)的 mask 系數(shù)預(yù)測(cè)值 (M,32)
cur_coeff_pred = cur_coeff_pred[pos_inds]
# 每個(gè)實(shí)例對(duì)應(yīng)的 mask 預(yù)測(cè)圖 (138,138,M)
mask_pred = cur_prototypes @ cur_coeff_pred.t()
mask_pred = torch.sigmoid(mask_pred)
# 基于 bbox 裁剪出 mask 圖
mask_pred = self.crop(mask_pred, bboxes_for_cropping)
# 后面就可以計(jì)算 bce loss 了如果覺(jué)得上述過(guò)程還是難以理解,請(qǐng)先閱讀后續(xù)的推理流程,再反過(guò)來(lái)閱讀訓(xùn)練流程,就會(huì)輕松很多。
(3) 語(yǔ)義分割輔助訓(xùn)練分支
為了進(jìn)一步提高性能,作者還額外引入了一個(gè)簡(jiǎn)單的卷積層,命名為語(yǔ)義分割層,因?yàn)橛姓Z(yǔ)義分割標(biāo)注,故可以采用 bce 進(jìn)行監(jiān)督。語(yǔ)義分割 head 分支,輸入 shape 是 69,即 FPN 輸出的最大 size 層,然后將對(duì)應(yīng)分割 標(biāo)注雙線性插值到 69,coco 類(lèi)別是80,故分割圖通道是 80,然后采用 bce 進(jìn)行監(jiān)督即可。推理時(shí)候直接刪除本分支即可。
segm_head=dict(
type='YOLACTSegmHead',
num_classes=80,
in_channels=256,
loss_segm=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), 1.5 推理流程
輸入圖片固定是 550x550,ResNet 輸出 3 個(gè)輸出特征圖,strdie 分別是 8/16/32,輸出特征圖大小是 69/35/18,然后經(jīng)過(guò) FPN 特征融合并擴(kuò)展了兩個(gè)高語(yǔ)義層輸出,一共輸出 5 個(gè)特征圖,特征圖大小是69/35/18/9/5,即一共 5 個(gè)輸出預(yù)測(cè)層。
和 RetinaNet 流程一致,先利用 5個(gè)輸出層進(jìn)行 bbox 和類(lèi)別檢測(cè),同時(shí)輸出每個(gè) bbox 相關(guān)的 mask 系數(shù),由于 anchor 個(gè)數(shù)為3,故分類(lèi)預(yù)測(cè)圖通道為 81x3(softmax模式),bbox 預(yù)測(cè)圖通道為 4x3,mask 系數(shù)通道為 32x3,即每個(gè) bbox 需要采用 32 長(zhǎng)度的向量來(lái)表征。
Bbox 后處理流程和 RetinaNet 思想相同,大概流程是遍歷每個(gè)預(yù)測(cè)輸出層,對(duì)每個(gè)層先利用 nms_pre 參數(shù)過(guò)濾到指定數(shù)目的框;對(duì)這些框進(jìn)行解碼操作;最后對(duì)所有結(jié)果采用 nms 進(jìn)行抑制得到指定數(shù)目的 bbox 和對(duì)應(yīng)類(lèi)別、mask 系數(shù)值,默認(rèn)最多是100,即經(jīng)過(guò)本步驟,輸出bbox 的 shape 是 (N,4),類(lèi)別的 shape 是 (N,),mask 系數(shù)的 shape 是 (N,32)。
需要注意的是:為了達(dá)到實(shí)時(shí),作者對(duì)常規(guī) nms 進(jìn)行了修改,在盡量不降低太多性能情況下提出加速 nms 版本 fast nms。其核心思想是在一次抑制過(guò)程中,運(yùn)行已經(jīng)被移除的 bbox 去抑制其余 bbox,從而迅速移除大量 bbox,從而加速 nms。
對(duì)于原型分支 Protonet,為了使得 mask 更加精確,作者只選擇了 FPN 后輸出的最大尺度特征圖 size 為 (69,69) 預(yù)測(cè)全局原型,輸出 shape 為 (138,138,32),特征圖上面每個(gè)位置都采用長(zhǎng)度為 32 的 prototypes來(lái)表征,然后將 N 個(gè)預(yù)測(cè)框和 prototypes 矩陣進(jìn)行乘加操作即 (138,138,32) @ (32,N) 輸出 shape 為 (138,138,100),即可得到每個(gè) bbox 對(duì)應(yīng)的 mask,然后利用 bbox 坐標(biāo)去 mask 圖上進(jìn)行切割即可得到對(duì)應(yīng)的mask圖,最后利用二值化,插值函數(shù)將 bbox 和 mask 都還原到原始圖尺度,最終得到實(shí)例分割結(jié)果。
1.6 可視化分析
為了更加容易理解掩碼系數(shù)和原型掩碼,我特意挑選一種簡(jiǎn)單背景的圖片(meinvtu),該圖片來(lái)自 COCO 驗(yàn)證集,如下所示。

原型掩碼 shape 是 (138, 138, 32),可以直接將這 32 個(gè) tensor 進(jìn)行展開(kāi)按照索引順序顯示,如下所示,可以發(fā)現(xiàn)還是存在很多冗余信息的,由于沒(méi)有直接的監(jiān)督信號(hào),輸出的 tensor 并沒(méi)有特定順序。

然后提取 bbox 坐標(biāo)和對(duì)應(yīng)的實(shí)例掩碼系數(shù),可視化如下所示

將上述兩個(gè) tensor 進(jìn)行加權(quán)求和,可以得到最終的 mask,如下所示

看起來(lái)效果還是蠻好的,掩碼系數(shù)學(xué)的還是蠻準(zhǔn)確的,所有冗余背景位置的系數(shù)都是負(fù)數(shù)。注意看頭頂,可以發(fā)現(xiàn)其實(shí) mask 預(yù)測(cè)很準(zhǔn)確,但是 bbox 不是很準(zhǔn)確,導(dǎo)致最終的 mask 頭部有缺失。下圖是最終的可視化效果。

2 總結(jié)
YOLACT 算法實(shí)現(xiàn)是非常有創(chuàng)新性的,拋棄了 Mask R-CNN 那套繁雜的設(shè)計(jì),整個(gè)流程非常簡(jiǎn)潔優(yōu)雅,并行預(yù)測(cè)當(dāng)前圖片的原型掩碼(prototype mask) 和每個(gè) bbox 實(shí)例的掩碼系數(shù)(mask coefficients),然后通過(guò)將原型與掩模系數(shù)線性組合來(lái)生成實(shí)例掩碼(instance masks),思想獨(dú)特,是一篇不可多得的問(wèn)題,值得學(xué)習(xí)學(xué)習(xí)。
推薦閱讀
谷歌AI用30億數(shù)據(jù)訓(xùn)練了一個(gè)20億參數(shù)Vision Transformer模型,在ImageNet上達(dá)到新的SOTA!
Transformer在語(yǔ)義分割上的應(yīng)用
"未來(lái)"的經(jīng)典之作ViT:transformer is all you need!
PVT:可用于密集任務(wù)backbone的金字塔視覺(jué)transformer!
漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA
不妨試試MoCo,來(lái)替換ImageNet上pretrain模型!
機(jī)器學(xué)習(xí)算法工程師
? ??? ? ? ? ? ? ? ? ? ? ????????? ??一個(gè)用心的公眾號(hào)

