實(shí)戰(zhàn)教學(xué)!Pytorch圖像檢索實(shí)踐
點(diǎn)擊下方“AI算法與圖像處理”,一起進(jìn)步!
重磅干貨,第一時(shí)間送達(dá)

隨著電子商務(wù)和在線網(wǎng)站的出現(xiàn),圖像檢索在我們的日常生活中的應(yīng)用一直在增加。
亞馬遜、阿里巴巴、Myntra等公司一直在大量利用圖像檢索技術(shù)。當(dāng)然,只有當(dāng)通常的信息檢索技術(shù)失敗時(shí),圖像檢索才會(huì)開始工作。
背景
圖像檢索的基本本質(zhì)是根據(jù)查詢圖像的特征從集合或數(shù)據(jù)庫(kù)中查找圖像。
大多數(shù)情況下,這種特征是圖像之間簡(jiǎn)單的視覺相似性。在一個(gè)復(fù)雜的問題中,這種特征可能是兩幅圖像在風(fēng)格上的相似性,甚至是互補(bǔ)性。
由于原始形式的圖像不會(huì)在基于像素的數(shù)據(jù)中反映這些特征,因此我們需要將這些像素?cái)?shù)據(jù)轉(zhuǎn)換為一個(gè)潛空間,在該空間中,圖像的表示將反映這些特征。
一般來(lái)說,在潛空間中,任何兩個(gè)相似的圖像都會(huì)相互靠近,而不同的圖像則會(huì)相隔很遠(yuǎn)。這是我們用來(lái)訓(xùn)練我們的模型的基本管理規(guī)則。一旦我們這樣做,檢索部分只需搜索潛在空間,在給定查詢圖像表示的潛在空間中拾取最近的圖像。大多數(shù)情況下,它是在最近鄰搜索的幫助下完成的。
因此,我們可以將我們的方法分為兩部分:
圖像表現(xiàn)
搜索
我們將在Oxford 102 Flowers數(shù)據(jù)集上解決這兩個(gè)部分。
你可以在這里下載并閱讀有關(guān)數(shù)據(jù)集的信息:
https://www.tensorflow.org/datasets/catalog/oxford_flowers102
圖像表現(xiàn)
我們將使用一種叫做暹羅模型的東西,它本身并不是一種全新的模型,而是一種訓(xùn)練模型的技術(shù)。大多數(shù)情況下,這是與triplet loss一起使用的。這個(gè)技術(shù)的基本組成部分是三元組。
三元組是3個(gè)獨(dú)立的數(shù)據(jù)樣本,比如A(錨點(diǎn)),B(陽(yáng)性)和C(陰性);其中A和B相似或具有相似的特征(可能是同一類),而C與A和B都不相似。這三個(gè)樣本共同構(gòu)成了訓(xùn)練數(shù)據(jù)的一個(gè)單元——三元組。
注:任何圖像檢索任務(wù)的90%都體現(xiàn)在暹羅網(wǎng)絡(luò)、triplet loss和三元組的創(chuàng)建中。如果你成功地完成了這些,那么整個(gè)努力的成功或多或少是有保證的。
首先,我們將創(chuàng)建管道的這個(gè)組件——數(shù)據(jù)。下面我們將在PyTorch中創(chuàng)建一個(gè)自定義數(shù)據(jù)集和數(shù)據(jù)加載器,它將從數(shù)據(jù)集中生成三元組。
class?TripletData(Dataset):
????def?__init__(self,?path,?transforms,?split="train"):
????????self.path?=?path
????????self.split?=?split????#?train?or?valid
????????self.cats?=?102???????#?number?of?categories
????????self.transforms?=?transforms
????????
????def?__getitem__(self,?idx):
????????#?our?positive?class?for?the?triplet
????????idx?=?str(idx%self.cats?+?1)
????????#?choosing?our?pair?of?positive?images?(im1,?im2)
????????positives?=?os.listdir(os.path.join(self.path,?idx))
????????im1,?im2?=?random.sample(positives,?2)
????????#?choosing?a?negative?class?and?negative?image?(im3)
????????negative_cats?=?[str(x+1)?for?x?in?range(self.cats)]
????????negative_cats.remove(idx)
????????negative_cat?=?str(random.choice(negative_cats))
????????negatives?=?os.listdir(os.path.join(self.path,?negative_cat))
????????im3?=?random.choice(negatives)
????????im1,im2,im3?=?os.path.join(self.path,?idx,?im1),?os.path.join(self.path,?idx,?im2),?os.path.join(self.path,?negative_cat,?im3)
????????im1?=?self.transforms(Image.open(im1))
????????im2?=?self.transforms(Image.open(im2))
????????im3?=?self.transforms(Image.open(im3))
????????return?[im1,?im2,?im3]
????
????#?we'll?put?some?value?that?we?want?since?there?can?be?far?too?many?triplets?possible
????#?multiples?of?the?number?of?images/?number?of?categories?is?a?good?choice
????def?__len__(self):
????????return?self.cats*8
#?Transforms
train_transforms?=?transforms.Compose([
????transforms.Resize((224,224)),
????transforms.RandomHorizontalFlip(),
????transforms.ToTensor(),
????transforms.Normalize((0.4914,?0.4822,?0.4465),?(0.2023,?0.1994,?0.2010)),
])
val_transforms?=?transforms.Compose([
????transforms.Resize((224,?224)),
????transforms.ToTensor(),
????transforms.Normalize((0.4914,?0.4822,?0.4465),?(0.2023,?0.1994,?0.2010)),
])
#?Datasets?and?Dataloaders
train_data?=?TripletData(PATH_TRAIN,?train_transforms)
val_data?=?TripletData(PATH_VALID,?val_transforms)
train_loader?=?torch.utils.data.DataLoader(dataset?=?train_data,?batch_size=32,?shuffle=True,?num_workers=4)
val_loader?=?torch.utils.data.DataLoader(dataset?=?val_data,?batch_size=32,?shuffle=False,?num_workers=4)
現(xiàn)在我們有了數(shù)據(jù),讓我們轉(zhuǎn)到暹羅網(wǎng)絡(luò)。
暹羅網(wǎng)絡(luò)給人的印象是2個(gè)或3個(gè)模型,但是它本身是一個(gè)單一的模型。所有這些模型共享權(quán)重,即只有一個(gè)模型。

如前所述,將整個(gè)體系結(jié)構(gòu)結(jié)合在一起的關(guān)鍵因素是triplet loss。triplet loss產(chǎn)生了一個(gè)目標(biāo)函數(shù),該函數(shù)迫使相似輸入對(duì)(錨點(diǎn)和正)之間的距離小于不同輸入對(duì)(錨點(diǎn)和負(fù))之間的距離,并限定一定的閾值。
下面我們來(lái)看看triplet loss以及訓(xùn)練管道實(shí)現(xiàn)。
class?TripletLoss(nn.Module):
????def?__init__(self,?margin=1.0):
????????
????????super(TripletLoss,?self).__init__()
????????self.margin?=?margin
????????
????????
????def?calc_euclidean(self,?x1,?x2):
????????return?(x1?-?x2).pow(2).sum(1)
????
????
????#?Distances?in?embedding?space?is?calculated?in?euclidean
????def?forward(self,?anchor,?positive,?negative):
????????
????????distance_positive?=?self.calc_euclidean(anchor,?positive)
????????
????????distance_negative?=?self.calc_euclidean(anchor,?negative)
????????
????????losses?=?torch.relu(distance_positive?-?distance_negative?+?self.margin)
????????
????????return?losses.mean()
??????
device?=?'cuda'
#?Our?base?model
model?=?models.resnet18().cuda()
optimizer?=?optim.Adam(model.parameters(),?lr=0.001)
triplet_loss?=?TripletLoss()
#?Training
for?epoch?in?range(epochs):
????
????model.train()
????epoch_loss?=?0.0
????
????for?data?in?tqdm(train_loader):
????????
????????optimizer.zero_grad()
????????x1,x2,x3?=?data
????????e1?=?model(x1.to(device))
????????e2?=?model(x2.to(device))
????????e3?=?model(x3.to(device))?
????????
????????loss?=?triplet_loss(e1,e2,e3)
????????epoch_loss?+=?loss
????????loss.backward()
????????optimizer.step()
????????
????print("Train?Loss:?{}".format(epoch_loss.item()))
????
????
class?TripletLoss(nn.Module):
????def?__init__(self,?margin=1.0):
????????
????????super(TripletLoss,?self).__init__()
????????self.margin?=?margin
????????
????????
????def?calc_euclidean(self,?x1,?x2):
????????return?(x1?-?x2).pow(2).sum(1)
????
????
????#?Distances?in?embedding?space?is?calculated?in?euclidean
????def?forward(self,?anchor,?positive,?negative):
????????
????????distance_positive?=?self.calc_euclidean(anchor,?positive)
????????
????????distance_negative?=?self.calc_euclidean(anchor,?negative)
????????
????????losses?=?torch.relu(distance_positive?-?distance_negative?+?self.margin)
????????
????????return?losses.mean()
??????
device?=?'cuda'
#?Our?base?model
model?=?models.resnet18().cuda()
optimizer?=?optim.Adam(model.parameters(),?lr=0.001)
triplet_loss?=?TripletLoss()
#?Training
for?epoch?in?range(epochs):
????model.train()
????epoch_loss?=?0.0
????for?data?in?tqdm(train_loader):
????????optimizer.zero_grad()
????????
????????x1,x2,x3?=?data
????????
????????e1?=?model(x1.to(device))
????????e2?=?model(x2.to(device))
????????e3?=?model(x3.to(device))?
????????
????????loss?=?triplet_loss(e1,e2,e3)
????????epoch_loss?+=?loss
????????loss.backward()
????????optimizer.step()
????????
????print("Train?Loss:?{}".format(epoch_loss.item()))
到目前為止,我們的模型已經(jīng)經(jīng)過訓(xùn)練,可以將圖像轉(zhuǎn)換為一個(gè)嵌入空間。接下來(lái),我們進(jìn)入搜索部分。
搜索
我們可以很容易地使用Scikit Learn提供的最近鄰搜索。我們將探索新的更好的東西,而不是走簡(jiǎn)單的路線。
我們將使用Faiss。這比最近的鄰居要快得多,如果我們有大量的圖像,這種速度上的差異會(huì)變得更加明顯。
下面我們將演示如何在給定查詢圖像時(shí),在存儲(chǔ)的圖像表示中搜索最近的圖像。
#!pip?install?faiss-gpu
import?faiss????????????????????????????
faiss_index?=?faiss.IndexFlatL2(1000)???#?build?the?index
#?storing?the?image?representations
im_indices?=?[]
with?torch.no_grad():
????for?f?in?glob.glob(os.path.join(PATH_TRAIN,?'*/*')):
????????
????????im?=?Image.open(f)
????????im?=?im.resize((224,224))
????????im?=?torch.tensor([val_transforms(im).numpy()]).cuda()
????
????????preds?=?model(im)
????????preds?=?np.array([preds[0].cpu().numpy()])
????????faiss_index.add(preds)?#add?the?representation?to?index
????????im_indices.append(f)???#store?the?image?name?to?find?it?later?on
????????
#?Retrieval?with?a?query?image
with?torch.no_grad():
????for?f?in?os.listdir(PATH_TEST):
????????
????????#?query/test?image
????????im?=?Image.open(os.path.join(PATH_TEST,f))
????????im?=?im.resize((224,224))
????????im?=?torch.tensor([val_transforms(im).numpy()]).cuda()
????
????????test_embed?=?model(im).cpu().numpy()
????????
????????_,?I?=?faiss_index.search(test_embed,?5)
????????print("Retrieved?Image:?{}".format(im_indices[I[0][0]]))
這涵蓋了基于現(xiàn)代深度學(xué)習(xí)的圖像檢索,但不會(huì)使其變得太復(fù)雜。大多數(shù)檢索問題都可以通過這個(gè)基本管道解決。
相關(guān)資源:
筆記本鏈接:https://www.kaggle.com/mayukh18/oxford-flowers-image-retrieval-pytorch
圖像檢索社區(qū)中流行的基準(zhǔn)數(shù)據(jù)集:https://paperswithcode.com/task/image-retrieva
交流群
歡迎加入公眾號(hào)讀者群一起和同行交流,目前有美顏、三維視覺、計(jì)算攝影、檢測(cè)、分割、識(shí)別、醫(yī)學(xué)影像、GAN、算法競(jìng)賽等微信群
個(gè)人微信(如果沒有備注不拉群!) 請(qǐng)注明:地區(qū)+學(xué)校/企業(yè)+研究方向+昵稱
下載1:何愷明頂會(huì)分享
在「AI算法與圖像處理」公眾號(hào)后臺(tái)回復(fù):何愷明,即可下載。總共有6份PDF,涉及 ResNet、Mask RCNN等經(jīng)典工作的總結(jié)分析
下載2:終身受益的編程指南:Google編程風(fēng)格指南
在「AI算法與圖像處理」公眾號(hào)后臺(tái)回復(fù):c++,即可下載。歷經(jīng)十年考驗(yàn),最權(quán)威的編程規(guī)范!
下載3 CVPR2021 在「AI算法與圖像處理」公眾號(hào)后臺(tái)回復(fù):CVPR,即可下載1467篇CVPR?2020論文 和 CVPR 2021 最新論文

