<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          Pytorch圖像檢索實(shí)踐

          共 14680字,需瀏覽 30分鐘

           ·

          2023-08-17 04:58

          點(diǎn)擊上方小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂

          重磅干貨,第一時(shí)間送達(dá)

          隨著電子商務(wù)和在線網(wǎng)站的出現(xiàn),圖像檢索在我們的日常生活中的應(yīng)用一直在增加。

          亞馬遜、阿里巴巴、Myntra等公司一直在大量利用圖像檢索技術(shù)。當(dāng)然,只有當(dāng)通常的信息檢索技術(shù)失敗時(shí),圖像檢索才會(huì)開始工作。

          背景

          圖像檢索的基本本質(zhì)是根據(jù)查詢圖像的特征從集合或數(shù)據(jù)庫中查找圖像。

          大多數(shù)情況下,這種特征是圖像之間簡(jiǎn)單的視覺相似性。在一個(gè)復(fù)雜的問題中,這種特征可能是兩幅圖像在風(fēng)格上的相似性,甚至是互補(bǔ)性。

          由于原始形式的圖像不會(huì)在基于像素的數(shù)據(jù)中反映這些特征,因此我們需要將這些像素?cái)?shù)據(jù)轉(zhuǎn)換為一個(gè)潛空間,在該空間中,圖像的表示將反映這些特征。

          一般來說,在潛空間中,任何兩個(gè)相似的圖像都會(huì)相互靠近,而不同的圖像則會(huì)相隔很遠(yuǎn)。這是我們用來訓(xùn)練我們的模型的基本管理規(guī)則。一旦我們這樣做,檢索部分只需搜索潛在空間,在給定查詢圖像表示的潛在空間中拾取最近的圖像。大多數(shù)情況下,它是在最近鄰搜索的幫助下完成的。

          因此,我們可以將我們的方法分為兩部分:

          1. 圖像表現(xiàn)

          2. 搜索

          我們將在Oxford 102 Flowers數(shù)據(jù)集上解決這兩個(gè)部分。

          你可以在這里下載并閱讀有關(guān)數(shù)據(jù)集的信息:

          https://www.tensorflow.org/datasets/catalog/oxford_flowers102

          圖像表現(xiàn)

          我們將使用一種叫做暹羅模型的東西,它本身并不是一種全新的模型,而是一種訓(xùn)練模型的技術(shù)。大多數(shù)情況下,這是與triplet loss一起使用的。這個(gè)技術(shù)的基本組成部分是三元組。

          三元組是3個(gè)獨(dú)立的數(shù)據(jù)樣本,比如A(錨點(diǎn)),B(陽性)和C(陰性);其中A和B相似或具有相似的特征(可能是同一類),而C與A和B都不相似。這三個(gè)樣本共同構(gòu)成了訓(xùn)練數(shù)據(jù)的一個(gè)單元——三元組。

          注:任何圖像檢索任務(wù)的90%都體現(xiàn)在暹羅網(wǎng)絡(luò)、triplet loss和三元組的創(chuàng)建中。如果你成功地完成了這些,那么整個(gè)努力的成功或多或少是有保證的。

          首先,我們將創(chuàng)建管道的這個(gè)組件——數(shù)據(jù)。下面我們將在PyTorch中創(chuàng)建一個(gè)自定義數(shù)據(jù)集和數(shù)據(jù)加載器,它將從數(shù)據(jù)集中生成三元組。

          class TripletData(Dataset):
              def __init__(self, path, transforms, split="train"):

                  self.path = path
                  self.split = split    # train or valid
                  self.cats = 102       # number of categories
                  self.transforms = transforms

                  
              def __getitem__(self, idx):

                  # our positive class for the triplet
                  idx = str(idx%self.cats + 1)

                  # choosing our pair of positive images (im1, im2)
                  positives = os.listdir(os.path.join(self.path, idx))
                  im1, im2 = random.sample(positives, 2)

                  # choosing a negative class and negative image (im3)
                  negative_cats = [str(x+1for x in range(self.cats)]
                  negative_cats.remove(idx)
                  negative_cat = str(random.choice(negative_cats))
                  negatives = os.listdir(os.path.join(self.path, negative_cat))

                  im3 = random.choice(negatives)

                  im1,im2,im3 = os.path.join(self.path, idx, im1), os.path.join(self.path, idx, im2), os.path.join(self.path, negative_cat, im3)

                  im1 = self.transforms(Image.open(im1))

                  im2 = self.transforms(Image.open(im2))

                  im3 = self.transforms(Image.open(im3))

                  return [im1, im2, im3]

              
              # we'll put some value that we want since there can be far too many triplets possible
              # multiples of the number of images/ number of categories is a good choice
              def __len__(self):
                  return self.cats*8


          # Transforms
          train_transforms = transforms.Compose([
              transforms.Resize((224,224)),
              transforms.RandomHorizontalFlip(),
              transforms.ToTensor(),
              transforms.Normalize((0.49140.48220.4465), (0.20230.19940.2010)),
          ])



          val_transforms = transforms.Compose([
              transforms.Resize((224224)),
              transforms.ToTensor(),
              transforms.Normalize((0.49140.48220.4465), (0.20230.19940.2010)),
          ])





          # Datasets and Dataloaders
          train_data = TripletData(PATH_TRAIN, train_transforms)
          val_data = TripletData(PATH_VALID, val_transforms)


          train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=32, shuffle=True, num_workers=4)

          val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=32, shuffle=False, num_workers=4)

          現(xiàn)在我們有了數(shù)據(jù),讓我們轉(zhuǎn)到暹羅網(wǎng)絡(luò)。

          暹羅網(wǎng)絡(luò)給人的印象是2個(gè)或3個(gè)模型,但是它本身是一個(gè)單一的模型。所有這些模型共享權(quán)重,即只有一個(gè)模型。



          如前所述,將整個(gè)體系結(jié)構(gòu)結(jié)合在一起的關(guān)鍵因素是triplet loss。triplet loss產(chǎn)生了一個(gè)目標(biāo)函數(shù),該函數(shù)迫使相似輸入對(duì)(錨點(diǎn)和正)之間的距離小于不同輸入對(duì)(錨點(diǎn)和負(fù))之間的距離,并限定一定的閾值。

          下面我們來看看triplet loss以及訓(xùn)練管道實(shí)現(xiàn)。

          class TripletLoss(nn.Module):
              def __init__(self, margin=1.0):
                  
                  super(TripletLoss, self).__init__()
                  self.margin = margin
                  
                  
              def calc_euclidean(self, x1, x2):
                  return (x1 - x2).pow(2).sum(1)
              
              
              # Distances in embedding space is calculated in euclidean
              def forward(self, anchor, positive, negative):
                  
                  distance_positive = self.calc_euclidean(anchor, positive)
                  
                  distance_negative = self.calc_euclidean(anchor, negative)
                  
                  losses = torch.relu(distance_positive - distance_negative + self.margin)
                  
                  return losses.mean()
                

          device = 'cuda'

          # Our base model
          model = models.resnet18().cuda()
          optimizer = optim.Adam(model.parameters(), lr=0.001)
          triplet_loss = TripletLoss()

          # Training
          for epoch in range(epochs):
              
              model.train()
              epoch_loss = 0.0
              
              for data in tqdm(train_loader):
                  
                  optimizer.zero_grad()
                  x1,x2,x3 = data
                  e1 = model(x1.to(device))
                  e2 = model(x2.to(device))
                  e3 = model(x3.to(device)) 
                  
                  loss = triplet_loss(e1,e2,e3)
                  epoch_loss += loss
                  loss.backward()
                  optimizer.step()
                  
              print("Train Loss: {}".format(epoch_loss.item()))

              
              
          class TripletLoss(nn.Module):
              def __init__(self, margin=1.0):
                  
                  super(TripletLoss, self).__init__()
                  self.margin = margin
                  
                  
              def calc_euclidean(self, x1, x2):
                  return (x1 - x2).pow(2).sum(1)
              
              
              # Distances in embedding space is calculated in euclidean
              def forward(self, anchor, positive, negative):
                  
                  distance_positive = self.calc_euclidean(anchor, positive)
                  
                  distance_negative = self.calc_euclidean(anchor, negative)
                  
                  losses = torch.relu(distance_positive - distance_negative + self.margin)
                  
                  return losses.mean()
                

          device = 'cuda'


          # Our base model
          model = models.resnet18().cuda()
          optimizer = optim.Adam(model.parameters(), lr=0.001)
          triplet_loss = TripletLoss()


          # Training
          for epoch in range(epochs):
              model.train()
              epoch_loss = 0.0
              for data in tqdm(train_loader):

                  optimizer.zero_grad()
                  
                  x1,x2,x3 = data
                  
                  e1 = model(x1.to(device))
                  e2 = model(x2.to(device))
                  e3 = model(x3.to(device)) 
                  
                  loss = triplet_loss(e1,e2,e3)
                  epoch_loss += loss
                  loss.backward()
                  optimizer.step()
                  
              print("Train Loss: {}".format(epoch_loss.item()))

          到目前為止,我們的模型已經(jīng)經(jīng)過訓(xùn)練,可以將圖像轉(zhuǎn)換為一個(gè)嵌入空間。接下來,我們進(jìn)入搜索部分。

          搜索

          我們可以很容易地使用Scikit Learn提供的最近鄰搜索。我們將探索新的更好的東西,而不是走簡(jiǎn)單的路線。

          我們將使用Faiss。這比最近的鄰居要快得多,如果我們有大量的圖像,這種速度上的差異會(huì)變得更加明顯。

          下面我們將演示如何在給定查詢圖像時(shí),在存儲(chǔ)的圖像表示中搜索最近的圖像。

          #!pip install faiss-gpu
          import faiss                            
          faiss_index = faiss.IndexFlatL2(1000)   # build the index

          # storing the image representations
          im_indices = []

          with torch.no_grad():
              for f in glob.glob(os.path.join(PATH_TRAIN, '*/*')):
                  
                  im = Image.open(f)
                  im = im.resize((224,224))
                  im = torch.tensor([val_transforms(im).numpy()]).cuda()
              
                  preds = model(im)
                  preds = np.array([preds[0].cpu().numpy()])
                  faiss_index.add(preds) #add the representation to index
                  im_indices.append(f)   #store the image name to find it later on

                  
          # Retrieval with a query image
          with torch.no_grad():
              for f in os.listdir(PATH_TEST):
                  
                  # query/test image
                  im = Image.open(os.path.join(PATH_TEST,f))
                  im = im.resize((224,224))
                  im = torch.tensor([val_transforms(im).numpy()]).cuda()
              
                  test_embed = model(im).cpu().numpy()
                  
                  _, I = faiss_index.search(test_embed, 5)
                  print("Retrieved Image: {}".format(im_indices[I[0][0]]))

          這涵蓋了基于現(xiàn)代深度學(xué)習(xí)的圖像檢索,但不會(huì)使其變得太復(fù)雜。大多數(shù)檢索問題都可以通過這個(gè)基本管道解決。

          相關(guān)資源

          筆記本鏈接:https://www.kaggle.com/mayukh18/oxford-flowers-image-retrieval-pytorch

          圖像檢索社區(qū)中流行的基準(zhǔn)數(shù)據(jù)集:https://paperswithcode.com/task/image-retrieva

               
          下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程
          在「小白學(xué)視覺」公眾號(hào)后臺(tái)回復(fù):擴(kuò)展模塊中文教程即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺、目標(biāo)跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。

          下載2:Python視覺實(shí)戰(zhàn)項(xiàng)目52講
          小白學(xué)視覺公眾號(hào)后臺(tái)回復(fù):Python視覺實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測(cè)、車道線檢測(cè)、車輛計(jì)數(shù)、添加眼線、車牌識(shí)別、字符識(shí)別、情緒檢測(cè)、文本內(nèi)容提取、面部識(shí)別等31個(gè)視覺實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺。

          下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講
          小白學(xué)視覺公眾號(hào)后臺(tái)回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。

          交流群


          歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測(cè)、分割、識(shí)別、醫(yī)學(xué)影像、GAN、算法競(jìng)賽等微信群(以后會(huì)逐漸細(xì)分),請(qǐng)掃描下面微信號(hào)加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請(qǐng)按照格式備注,否則不予通過。添加成功后會(huì)根據(jù)研究方向邀請(qǐng)進(jìn)入相關(guān)微信群。請(qǐng)勿在群內(nèi)發(fā)送廣告,否則會(huì)請(qǐng)出群,謝謝理解~


          瀏覽 270
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  黄色毛片操逼视频 | 奇米影视7777狠狠狠狠色 | 大香蕉操逼网欧美 | 黄色www91 | 乱伦六区 |