A Quick Way to Boost Object Detection Accuracy: Weighted Boxes Fusion (WBF), with Source Code
01
NMS&WBF
1.1 NMS
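NMS removes redundant bounding boxes as follows:
Sort the bounding boxes by confidence score.
Pick the box with the highest confidence, add it to the final output list, and remove it from the candidate list.
Compute the areas of all bounding boxes (needed for the IoU calculation).
Compute the IoU between the highest-confidence box and every remaining candidate.
Remove the candidates whose IoU exceeds the threshold.
Repeat until the candidate list is empty.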
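The NMS steps above can be sketched in a few lines of pure Python. This is a minimal illustration rather than the implementation used by any particular detector; the function names `iou` and `nms` and the sample boxes are assumptions made for the example.

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, iou_thr=0.5):
    """Return the indices of the boxes kept by greedy NMS."""
    # Sort candidate indices by confidence, highest first.
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        # Move the most confident remaining box to the output list.
        best = order.pop(0)
        keep.append(best)
        # Drop every remaining box that overlaps the kept box too much.
        order = [i for i in order if iou(boxes[best], boxes[i]) <= iou_thr]
    return keep


# Two heavily overlapping boxes and one separate box (made-up values).
boxes = [(10, 10, 110, 110), (12, 12, 112, 112), (200, 200, 300, 300)]
scores = [0.9, 0.8, 0.7]
kept = nms(boxes, scores, iou_thr=0.5)
print(kept)
```

Here the second box overlaps the first with an IoU of about 0.92, so it is suppressed, while the third box does not overlap anything and is kept.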
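NMS is simple and effective, but it has the following drawbacks when detection requirements are more demanding:
False suppression in dense, overlapping scenes: NMS forcibly discards lower-scoring boxes. When objects are densely packed, two boxes that belong to two different objects can overlap heavily, and the lower-scoring one may be suppressed, which lowers the model's recall.
Speed: NMS implementations involve many sequential loop steps, so they are not easy to parallelize on a GPU, and they become slow when there are many predicted boxes.
Misaligned confidence: the confidence of the box location and the confidence of the classification are not fully aligned. NMS simply treats the score as the confidence of the box, but in some cases a higher-scoring box is not the better-localized one.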
1.2 WBF
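The WBF algorithm proceeds as follows:
1. Add every predicted box from every model to a list B, and sort B by confidence score C in descending order.
2. Create two empty lists, L and F (F holds the fused boxes).
3. Loop over B and, for each box, look for a matching box in F (same class and IoU > 0.55).
4. If step 3 finds no match, append the box to the end of both L and F.
5. If step 3 finds a match, add the box to L at the index of the matching box in F.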
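6. Each position in L may hold several boxes, and F[index] must be updated from them. Using all T boxes in L[index], recompute the coordinates and confidence score of the fused box in F[index], as defined in the WBF paper:

$$
C = \frac{\sum_{i=1}^{T} C_i}{T}, \qquad
X_{1,2} = \frac{\sum_{i=1}^{T} C_i \cdot X_{1,2}^{i}}{\sum_{i=1}^{T} C_i}, \qquad
Y_{1,2} = \frac{\sum_{i=1}^{T} C_i \cdot Y_{1,2}^{i}}{\sum_{i=1}^{T} C_i}
\tag{1}
$$

Here C_i is the confidence score of the i-th box in L[index], and X_{1,2}^{i}, Y_{1,2}^{i} are its corner coordinates. The fused confidence is the mean of the T scores, and each fused coordinate is the confidence-weighted mean of the corresponding coordinates.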
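The WBF steps above can be sketched in pure Python as follows. This is a simplified illustration, not the `ensemble-boxes` implementation: the names `iou`, `fuse` and `wbf_sketch`, the `(box, score, label)` tuple layout, and the sample predictions are all assumptions made for the example. It covers only the steps listed above; the final confidence rescaling by the number of models described in the WBF paper is left out.

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def fuse(cluster):
    """Confidence-weighted average of a cluster of (box, score, label)."""
    total = sum(score for _, score, _ in cluster)
    box = tuple(
        sum(b[k] * score for b, score, _ in cluster) / total
        for k in range(4)
    )
    conf = total / len(cluster)  # mean confidence of the T boxes
    return box, conf, cluster[0][2]


def wbf_sketch(model_preds, iou_thr=0.55):
    """model_preds: one list per model of (box, score, label) tuples."""
    # Step 1: pool all predictions into B, sorted by confidence (descending).
    B = sorted((p for preds in model_preds for p in preds),
               key=lambda p: p[1], reverse=True)
    L, F = [], []  # Step 2: clusters and their fused boxes
    for pred in B:
        box, _, label = pred
        # Step 3: find the best-matching fused box of the same class.
        best_idx, best_iou = -1, iou_thr
        for idx, (fbox, _, flabel) in enumerate(F):
            if flabel != label:
                continue
            overlap = iou(box, fbox)
            if overlap > best_iou:
                best_idx, best_iou = idx, overlap
        if best_idx < 0:
            # Step 4: no match, so start a new cluster.
            L.append([pred])
            F.append(pred)
        else:
            # Steps 5 and 6: join the cluster and refresh its fused box.
            L[best_idx].append(pred)
            F[best_idx] = fuse(L[best_idx])
    return F


# Made-up predictions from two models.
model_a = [((10, 10, 110, 110), 0.9, 0)]
model_b = [((20, 20, 120, 120), 0.6, 0), ((200, 200, 260, 260), 0.8, 1)]
fused = wbf_sketch([model_a, model_b])
for box, conf, label in fused:
    print(label, round(conf, 3), [round(v, 1) for v in box])
```

In this example the two class-0 boxes overlap with an IoU of about 0.68, so they are merged into one box that sits closer to the more confident prediction, while the class-1 box stays on its own.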
02
Boosting Object Detection Accuracy Quickly with WBF
WBF removes redundant bounding boxes by fusing them with confidence weights. This makes it a quick way to improve detection accuracy, mainly in three ways:
1. Replace the NMS step in the detection network with WBF
2. Use WBF to fuse test-time augmentation (TTA) results
3. Use WBF to ensemble multiple detection models
The examples below use YOLOv5 code. The complete code is available by replying "WBF" to the official account.
2.1 Replacing NMS in the detection network with WBF
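1. Install the WBF library:
pip install ensemble-boxes
Alternatively, upgrade pip and setuptools and install from a local copy of the package (the path below is the one used in the original code):
pip install --upgrade pip
pip install --upgrade setuptools
pip install --no-deps '../input/weightedboxesfusion/'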
2. Call the WBF library:
from ensemble_boxes import *
boxes, scores, labels = weighted_boxes_fusion(boxes_list, scores_list, labels_list, weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
Original YOLOv5 inference code:
pred = model(img, augment=opt.augment)[0]
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)

Modified code, which passes the detections through WBF:
import numpy as np

pred = model(img, augment=False)[0]
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)

boxes = []
scores = []
clses = []
for i, det in enumerate(pred):  # detections per image
    if det is not None and len(det):
        # Rescale boxes from img_size to im0 size
        det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
        for *xyxy, conf, cls in det:
            boxes.append([int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])])
            scores.append(float(conf))  # convert tensors to plain Python numbers
            clses.append(int(cls))
boxes = np.array(boxes)
scores = np.array(scores)
clses = np.array(clses)

def run_wbf(boxes, scores, clses, image_size=1024, iou_thr=0.5, skip_box_thr=0.7, weights=None):
    # weighted_boxes_fusion expects one list entry per model, with box
    # coordinates normalized to [0, 1]; image_size assumes a square image.
    boxes_list = [boxes / image_size]
    boxes, scores, labels = weighted_boxes_fusion(boxes_list, [scores], [clses], weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
    boxes = boxes * image_size  # back to pixel coordinates
    return boxes, scores, labels

pred_boxes, scores, labels = run_wbf(boxes, scores, clses, image_size, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
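The `run_wbf` helper above divides every coordinate by a single `image_size`, which only maps boxes into the [0, 1] range that `weighted_boxes_fusion` expects when the image is square. For non-square images, one possible approach is to normalize x by the width and y by the height. The sketch below is an assumption-based illustration; `normalize_boxes` and `denormalize_boxes` are hypothetical helper names that are not part of `ensemble-boxes` or YOLOv5.

```python
def normalize_boxes(boxes, width, height):
    """Scale pixel (x1, y1, x2, y2) boxes into the [0, 1] range."""
    return [[x1 / width, y1 / height, x2 / width, y2 / height]
            for x1, y1, x2, y2 in boxes]


def denormalize_boxes(boxes, width, height):
    """Map normalized (x1, y1, x2, y2) boxes back to pixel coordinates."""
    return [[x1 * width, y1 * height, x2 * width, y2 * height]
            for x1, y1, x2, y2 in boxes]


# A made-up box on a 640x480 image.
pixel_boxes = [[64, 48, 320, 240]]
norm = normalize_boxes(pixel_boxes, width=640, height=480)
restored = denormalize_boxes(norm, width=640, height=480)
print(norm)
print(restored)
```

The normalized boxes can then be passed to the fusion step, and the fused output mapped back to pixels with the same width and height.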

2.2 Fusing test-time augmentation (TTA) results with WBF
import cv2
import numpy as np

def TTAImage(image, index):
    # Rotate the image clockwise by 90 * (index + 1) degrees;
    # index 3 returns an unrotated copy.
    image1 = image.copy()
    if index == 0:
        return cv2.rotate(image1, cv2.ROTATE_90_CLOCKWISE)
    elif index == 1:
        return cv2.rotate(image1, cv2.ROTATE_180)
    elif index == 2:
        return cv2.rotate(image1, cv2.ROTATE_90_COUNTERCLOCKWISE)
    elif index == 3:
        return image1

def rotBoxes90(boxes, im_w, im_h):
    # Rotate (x1, y1, x2, y2) boxes by 90 degrees clockwise about the image center.
    ret_boxes = []
    for box in boxes:
        x1, y1, x2, y2 = box
        # Move the origin to the image center, with the y axis pointing up
        x1, y1, x2, y2 = x1 - im_w // 2, im_h // 2 - y1, x2 - im_w // 2, im_h // 2 - y2
        # Rotate by 90 degrees
        x1, y1, x2, y2 = y1, -x1, y2, -x2
        # Move back to image coordinates
        x1, y1, x2, y2 = int(x1 + im_w // 2), int(im_h // 2 - y1), int(x2 + im_w // 2), int(im_h // 2 - y2)
        # Restore the (min corner, max corner) ordering
        x1a, y1a, x2a, y2a = min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)
        ret_boxes.append([x1a, y1a, x2a, y2a])
    return np.array(ret_boxes)
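# detect1Image is a helper from the full code (not shown here)
im_h, im_w = im01.shape[:2]  # shape is (height, width, channels); the rotation logic assumes a square image
enboxes = []
enscores = []
for i in range(4):
    im0 = TTAImage(im01, i)
    boxes, scores = detect1Image(im0, imgsz, model, device, conf_thres, iou_thres)
    # Rotate the boxes the remaining quarter turns to map them back onto the original image
    for _ in range(3 - i):
        boxes = rotBoxes90(boxes, im_w, im_h)
    enboxes.append(boxes)
    enscores.append(scores)

# Detections on the original image are added once more
# (i == 3 above already covers the unrotated image)
boxes, scores = detect1Image(im01, imgsz, model, device, conf_thres, iou_thres)
enboxes.append(boxes)
enscores.append(scores)

def run_wbf(boxes, scores, image_size=1023, iou_thr=0.5, skip_box_thr=0.7, weights=None):
    # Single-class setup: every box gets label 0
    labels = [np.zeros(score.shape[0]) for score in scores]
    # WBF expects coordinates normalized to [0, 1]
    boxes = [box / image_size for box in boxes]
    boxes, scores, labels = weighted_boxes_fusion(boxes, scores, labels, weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
    boxes = boxes * image_size  # back to pixel coordinates
    return boxes, scores, labels

boxes, scores, labels = run_wbf(enboxes, enscores, image_size=im_w, iou_thr=0.6, skip_box_thr=0.5)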
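The TTA code above relies on `rotBoxes90` to map boxes detected on a rotated image back onto the original image. A quick sanity check of that coordinate math is sketched below. `rot_box_90` is a hypothetical pure-Python rewrite of the same arithmetic for a single box, and the 100x100 image size and the sample box are made-up values.

```python
def rot_box_90(box, im_w, im_h):
    """Rotate one (x1, y1, x2, y2) box 90 degrees clockwise about the center."""
    x1, y1, x2, y2 = box
    # Move the origin to the image center, with the y axis pointing up.
    x1, y1, x2, y2 = x1 - im_w // 2, im_h // 2 - y1, x2 - im_w // 2, im_h // 2 - y2
    # Rotate by 90 degrees.
    x1, y1, x2, y2 = y1, -x1, y2, -x2
    # Move back to image coordinates.
    x1, y1 = x1 + im_w // 2, im_h // 2 - y1
    x2, y2 = x2 + im_w // 2, im_h // 2 - y2
    # Restore the (min corner, max corner) ordering.
    return (min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))


box = (10, 20, 30, 40)
once = rot_box_90(box, 100, 100)
print(once)

# Four quarter turns should bring the box back to where it started.
full_turn = box
for _ in range(4):
    full_turn = rot_box_90(full_turn, 100, 100)
print(full_turn)
```

On a square image, one call moves the box from the top-left region to the top-right region, and four calls return the original box. This is why the TTA loop applies `3 - i` rotations to detections from an image that was rotated `i + 1` times.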
2.3 Ensembling multiple detection models with WBF
import torch

weights = 'weights/best.pt'
weights1 = '../input/otherweight/best_yolov5x_fold0.pt'

# Load model 0
model = torch.load(weights, map_location=device)['model'].float()  # load to FP32
model.to(device).eval()
# Load model 1
model1 = torch.load(weights1, map_location=device)['model'].float()  # load to FP32
model1.to(device).eval()

# Run both models on the same image and keep each model's output separate
boxes0, scores0 = detect1Image(im0, imgsz, model, device, conf_thres, iou_thres)
boxes1, scores1 = detect1Image(im0, imgsz, model1, device, conf_thres, iou_thres)

# Fuse the two sets of predictions with the run_wbf helper from section 2.2
# (image_size assumes a square image)
boxes, scores, labels = run_wbf([boxes0, boxes1], [scores0, scores1], image_size=im0.shape[0], iou_thr=iou_thr, skip_box_thr=skip_box_thr)

03
Summary

