<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          使用Transformer來做物體檢測

          共 13755字,需瀏覽 28分鐘

           ·

          2021-08-13 00:43

          點擊左上方藍字關注我們



          一個專注于目標檢測與深度學習知識分享的公眾號

          編者薦語
          DEtection TRansformer (DETR)是Facebook研究團隊巧妙地利用了Transformer 架構開發(fā)的一個目標檢測模型。文章中,作者將通過分析DETR架構的內(nèi)部工作方式來幫助提供一些關于它的直覺。

          轉載自 | 學算法的小黑狗



          介紹

          DEtection TRansformer (DETR)是Facebook研究團隊巧妙地利用了Transformer 架構開發(fā)的一個目標檢測模型。在這篇文章中,我將通過分析DETR架構的內(nèi)部工作方式來幫助提供一些關于它的直覺。

          下面,我將解釋一些結構,但是如果你只是想了解如何使用模型,可以直接跳到代碼部分。


          結構

          DETR模型由一個預訓練的CNN骨干(如ResNet)組成,它產(chǎn)生一組低維特征集。這些特征被格式化為一個特征集合并添加位置編碼,輸入一個由Transformer組成的編碼器和解碼器中,和原始的Transformer論文中描述的Encoder-Decoder的使用方式非常的類似。解碼器的輸出然后被送入固定數(shù)量的預測頭,這些預測頭由預定義數(shù)量的前饋網(wǎng)絡組成。每個預測頭的輸出都包含一個類預測和一個預測框。損失是通過計算二分匹配損失來計算的。

          該模型做出了預定義數(shù)量的預測,并且每個預測都是并行計算的。


          CNN主干

          假設我們的輸入圖像,有三個輸入通道。CNN backbone由一個(預訓練過的)CNN(通常是ResNet)組成,我們用它來生成C個具有寬度W和高度H的低維特征(在實踐中,我們設置C=2048, W=W?/32和H=H?/32)。

          這留給我們的是C個二維特征,由于我們將把這些特征傳遞給一個transformer,每個特征必須允許編碼器將每個特征處理為一個序列的方式重新格式化。這是通過將特征矩陣扁平化為H?W向量,然后將每個向量連接起來來實現(xiàn)的。

          扁平化的卷積特征再加上空間位置編碼,位置編碼既可以學習,也可以預定義。


          The Transformer

          Transformer幾乎與原始的編碼器-解碼器架構完全相同。不同之處在于,每個解碼器層并行解碼N個(預定義的數(shù)目)目標。該模型還學習了一組N個目標的查詢,這些查詢是(類似于編碼器)學習出來的位置編碼。


          目標查詢

          下圖描述了N=20個學習出來的目標查詢(稱為prediction slots)如何聚焦于一張圖像的不同區(qū)域。

          “我們觀察到,在不同的操作模式下,每個slot 都會學習特定的區(qū)域和框大小。“ —— DETR的作者

          理解目標查詢的直觀方法是想象每個目標查詢都是一個人。每個人都可以通過注意力來查看圖像的某個區(qū)域。一個目標查詢總是會問圖像中心是什么,另一個總是會問左下角是什么,以此類推。


          使用PyTorch實現(xiàn)簡單的DETR

          import torch
          import torch.nn as nn
          from torchvision.models import resnet50

          class SimpleDETR(nn.Module):
          """
          Minimal Example of the Detection Transformer model with learned positional embedding
          """

           def __init__(self, num_classes, hidden_dim, num_heads,
                       num_enc_layers, num_dec_layers)
          :

              super(SimpleDETR, self).__init__()
              self.num_classes = num_classes
              self.hidden_dim = hidden_dim
              self.num_heads = num_heads
              self.num_enc_layers = num_enc_layers
              self.num_dec_layers = num_dec_layers
              # CNN Backbone
              self.backbone = nn.Sequential(
                   *list(resnet50(pretrained=True).children())[:-2])
              self.conv = nn.Conv2d(2048, hidden_dim, 1)
              # Transformer
              self.transformer = nn.Transformer(hidden_dim, num_heads,
                   num_enc_layers, num_dec_layers)
              # Prediction Heads
              self.to_classes = nn.Linear(hidden_dim, num_classes+1)
              self.to_bbox = nn.Linear(hidden_dim, 4)
              # Positional Encodings
              self.object_query = nn.Parameter(torch.rand(100, hidden_dim))
              self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2)
              self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
                                            
           def forward(self, X):
              X = self.backbone(X)
              h = self.conv(X)
              H, W = h.shape[-2:]
              pos_enc = torch.cat([
                    self.col_embed[:W].unsqueeze(0).repeat(H,1,1),
                    self.row_embed[:H].unsqueeze(1).repeat(1,W,1)],
                 dim=-1).flatten(0,1).unsqueeze(1)
              h = self.transformer(pos_enc + h.flatten(2).permute(2,0,1),
              self.object_query.unsqueeze(1))
              class_pred = self.to_classes(h)
              bbox_pred = self.to_bbox(h).sigmoid()
              
              return class_pred, bbox_pred


          二分匹配損失 (Optional)

          為預測的集合,其中是包括了預測類別(可以是空類別)和包圍框的二元組,其中上劃線表示框的中心點, 和表示框的寬和高。
          設y為ground truth集合。假設y和?之間的損失為L,每一個y?和??之間的損失為L?。由于我們是在集合的層次上工作,損失L必須是排列不變的,這意味著無論我們?nèi)绾闻判蝾A測,我們都將得到相同的損失。因此,我們想找到一個排列,它將預測的索引映射到ground truth目標的索引上。在數(shù)學上,我們求解:

          計算的過程稱為尋找最優(yōu)的二元匹配。這可以用匈牙利算法找到。但為了找到最優(yōu)匹配,我們需要實際定義一個損失函數(shù),計算之間的匹配成本。
          回想一下,我們的預測包含一個邊界框和一個類。現(xiàn)在讓我們假設類預測實際上是一個類集合上的概率分布。那么第i個預測的總損失將是類預測產(chǎn)生的損失和邊界框預測產(chǎn)生的損失之和。作者在http://arxiv.org/abs/1906.05909中將這種損失定義為邊界框損失和類預測概率的差異:

          其中,的argmax,是是來自包圍框的預測的損失,如果,則表示匹配損失為0。

          框損失的計算為預測值與ground truth的L?損失和的GIOU損失的線性組合。同樣,如果你想象兩個不相交的框,那么框的錯誤將不會提供任何有意義的上下文(我們可以從下面的框損失的定義中看到)。

          其中,λ???和是超參數(shù)。注意,這個和也是面積和距離產(chǎn)生的誤差的組合。為什么會這樣呢?

          可以把上面的等式看作是與預測相關聯(lián)的總損失,其中面積誤差的重要性是λ???,距離誤差的重要性是

          現(xiàn)在我們來定義GIOU損失函數(shù)。定義如下:


          由于我們從已知的已知類的數(shù)目來預測類,那么類預測就是一個分類問題,因此我們可以使用交叉熵損失來計算類預測誤差。我們將損失函數(shù)定義為每N個預測損失的總和:


          為目標檢測使用DETR

          在這里,你可以學習如何加載預訓練的DETR模型,以便使用PyTorch進行目標檢測。

          8.1 加載模型

          首先導入需要的模塊。

          # Import required modules
          import torch
          from torchvision import transforms as T import requests # for loading images from web
          from PIL import Image # for viewing images
          import matplotlib.pyplot as plt

          下面的代碼用ResNet50作為CNN骨干從torch hub加載預訓練的模型。其他主干請參見DETR github:https://github.com/facebookresearch/detr

          detr = torch.hub.load('facebookresearch/detr',
                                'detr_resnet50',
                                 pretrained=True)

          8.2 加載一張圖像

          要從web加載圖像,我們使用requests庫:

          url = 'https://www.tempetourism.com/wp-content/uploads/Postino-Downtown-Tempe-2.jpg' # Sample imageimage = Image.open(requests.get(url, stream=True).raw) plt.imshow(image)
          plt.show()

          8.3 設置目標檢測的Pipeline

          為了將圖像輸入到模型中,我們需要將PIL圖像轉換為張量,這是通過使用torchvision的transforms庫來完成的。

          transform = T.Compose([T.Resize(800),
                                 T.ToTensor(),
                                 T.Normalize([0.4850.4560.406],
                                            [0.2290.2240.225])])

          上面的變換調整了圖像的大小,將PIL圖像進行轉換,并用均值-標準差對圖像進行歸一化。其中[0.485,0.456,0.406]為各顏色通道的均值,[0.229,0.224,0.225]為各顏色通道的標準差。

          我們裝載的模型是預先在COCO Dataset上訓練的,有91個類,還有一個表示空類(沒有目標)的附加類。我們用下面的代碼手動定義每個標簽:

          CLASSES = 
          ['N/A''Person''Bicycle''Car''Motorcycle''Airplane''Bus''Train''Truck''Boat''Traffic-Light''Fire-Hydrant''N/A''Stop-Sign''Parking Meter''Bench''Bird''Cat''Dog''Horse''Sheep''Cow''Elephant''Bear''Zebra''Giraffe''N/A''Backpack''Umbrella''N/A''N/A''Handbag''Tie''Suitcase''Frisbee''Skis''Snowboard''Sports-Ball''Kite''Baseball Bat''Baseball Glove''Skateboard''Surfboard''Tennis Racket''Bottle''N/A''Wine Glass''Cup''Fork''Knife''Spoon''Bowl''Banana''Apple''Sandwich''Orange''Broccoli''Carrot''Hot-Dog''Pizza''Donut''Cake''Chair''Couch''Potted Plant''Bed''N/A''Dining Table''N/A','N/A''Toilet''N/A''TV''Laptop''Mouse''Remote''Keyboard''Cell-Phone''Microwave''Oven''Toaster''Sink''Refrigerator''N/A''Book''Clock''Vase''Scissors''Teddy-Bear''Hair-Dryer''Toothbrush']

          如果我們想輸出不同顏色的邊框,我們可以手動定義我們想要的RGB格式的顏色

          COLORS = [
              [0.0000.4470.741], 
              [0.8500.3250.098], 
              [0.9290.6940.125],
              [0.4940.1840.556],
              [0.4660.6740.188],
              [0.3010.7450.933]  
          ]

          8.4 格式化輸出

          我們還需要重新格式化模型的輸出。給定一個轉換后的圖像,模型將輸出一個字典,包含100個預測類的概率和100個預測邊框。

          每個包圍框的形式為(x, y, w, h),其中(x,y)為包圍框的中心(包圍框是單位正方形[0,1]×[0,1]), w, h為包圍框的寬度和高度。因此,我們需要將邊界框輸出轉換為初始和最終坐標,并重新縮放框以適應圖像的實際大小。

          下面的函數(shù)返回邊界框端點:

          # Get coordinates (x0, y0, x1, y0) from model output (x, y, w, h)def get_box_coords(boxes):
              x, y, w, h = boxes.unbind(1)
              x0, y0 = (x - 0.5 * w), (y - 0.5 * h)
              x1, y1 = (x + 0.5 * w), (y + 0.5 * h)
              box = [x0, y0, x1, y1]
              return torch.stack(box, dim=1)

          我們還需要縮放了框的大小。下面的函數(shù)為我們做了這些:

          # Scale box from [0,1]x[0,1] to [0, width]x[0, height]def scale_boxes(output_box, width, height):
              box_coords = get_box_coords(output_box)
              scale_tensor = torch.Tensor(
                           [width, height, width, height]).to(
                           torch.cuda.current_device())    return box_coords * scale_tensor

          現(xiàn)在我們需要一個函數(shù)來封裝我們的目標檢測pipeline。下面的detect函數(shù)為我們完成了這項工作。

          # Object Detection Pipelinedef detect(im, model, transform):
              device = torch.cuda.current_device()
              width = im.size[0]
              height = im.size[1]
             
              # mean-std normalize the input image (batch-size: 1)
              img = transform(im).unsqueeze(0)
              img = img.to(device)
              
              # demo model only support by default images with aspect ratio    between 0.5 and 2    assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600,    # propagate through the model
              outputs = model(img)    # keep only predictions with 0.7+ confidence
              probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
              keep = probas.max(-1).values > 0.85
             
              # convert boxes from [0; 1] to image scales
              bboxes_scaled = scale_boxes(outputs['pred_boxes'][0, keep], width, height)    return probas[keep], bboxes_scaled

          現(xiàn)在,我們需要做的是運行以下程序來獲得我們想要的輸出:

          probs, bboxes = detect(image, detr, transform)

          8.5 繪制結果

          現(xiàn)在我們有了檢測到的目標,我們可以使用一個簡單的函數(shù)來可視化它們。

          # Plot Predicted Bounding Boxesdef plot_results(pil_img, prob, boxes,labels=True):
              plt.figure(figsize=(16,10))
              plt.imshow(pil_img)
              ax = plt.gca()
              
              for prob, (x0, y0, x1, y1), color in zip(prob, boxes.tolist(),   COLORS * 100):        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0,  
                       fill=False, color=color, linewidth=2))
                  cl = prob.argmax()
                  text = f'{CLASSES[cl]}{prob[cl]:0.2f}'
                  if labels:
                      ax.text(x0, y0, text, fontsize=15,
                          bbox=dict(facecolor=color, alpha=0.75))
              plt.axis('off')
              plt.show()

          現(xiàn)在可以可視化結果:

          plot_results(image, probs, bboxes, labels=True)




          英文原文

          https://medium.com/swlh/object-detection-with-transformers-437217a3d62e


          END



          雙一流大學研究生團隊創(chuàng)建,專注于目標檢測與深度學習,希望可以將分享變成一種習慣!

          整理不易,點贊三連↓

          瀏覽 65
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  欧美国产A片 | 超碰碰免费 | 午夜福利精品 | 欧美日韩亚洲中文字幕 | 无码理论片 |