圖片分類賽官方baseline解讀!

01?賽題背景
為進一步加快“6+5+6+1”西安現(xiàn)代產(chǎn)業(yè)以及養(yǎng)老服務(wù)等行業(yè)領(lǐng)域急需緊缺高技能人才培養(yǎng),動員廣大職工在迎十四運創(chuàng)文明城、建設(shè)國家中心城市、助力西安新時代追趕超越高質(zhì)量發(fā)展中展現(xiàn)新作為,市委組織部、市人社局、市總工會決定舉辦西安市2021年“迎全運、強技能、促提升”高技能人才技能大賽(全市計算機程序設(shè)計員技能大賽)。
02 數(shù)據(jù)分揀
本次比賽將提供10,000張垃圾圖片,其中8000張用于訓(xùn)練集,1,000張用于測試集。其中,每張圖片中的垃圾都屬于紙類、塑料、金屬、玻璃、廚余、電池這六類垃圾中的一類。
數(shù)據(jù)文件:
train.zip,訓(xùn)練集,包括7831張垃圾圖片。
validation.zip,測試集,包括2014張垃圾圖片。
train.csv,訓(xùn)練集圖片標簽,標簽為A-F,分別代表廚余、塑料、金屬、紙類、織物、玻璃。
03 數(shù)據(jù)分析
首先我們可以對賽題數(shù)據(jù)進行可視化,這里使用opencv讀取圖片并進行操作:
def show_image(paths):
plt.figure(figsize=(10, 8)) for idx, path in enumerate(paths):
plt.subplot(1, len(paths), idx+1)
img = cv2.imread(path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)
plt.xticks([]); plt.yticks([])
從圖中可以看出圖片主體尺寸較少,但背景所包含的像素較多。接下來可以對類別數(shù)量進行統(tǒng)計,在數(shù)據(jù)集中廚余最多,玻璃垃圾最少。數(shù)據(jù)集類別整體還是比較均衡,樣本比例沒有相差很大。

參考上面的操作,可以對數(shù)據(jù)集每類圖片進行可視化:

04 賽題建模
由于賽題任務(wù)是一個非常典型的圖像分類任務(wù),所以可以直接使用CNN模型訓(xùn)練的過程來完成。在本地比賽中如果使用得到的預(yù)訓(xùn)練模型越強,則最終的精度越好。
在構(gòu)建模型并進行訓(xùn)練之前,非常建議將訓(xùn)練集圖片提前進行縮放,這樣加快圖片的讀取速度,也可以加快模型的訓(xùn)練速度。具體的縮放代碼如下:
import cv2, glob, os import numpy as np
os.mkdir('train_512')
os.mkdir('validation_512')for path in glob.glob('./train/*'): if os.path.exists('train_512/' + path.split('/')[-1]): continue
img = cv2.imread(path) try:
img = cv2.resize(img, (512, 512))
cv2.imwrite('train_512/' + path.split('/')[-1], img) except: passfor path in glob.glob('./validation/*'): if os.path.exists('validation_512/' + path.split('/')[-1]): continue
img = cv2.imread(path) try:
img = cv2.resize(img, (512, 512))
cv2.imwrite('validation_512/' + path.split('/')[-1], img) except:
img = np.zeros((512, 512, 3))
cv2.imwrite('validation_512/' + path.split('/')[-1], img)Pytorch版本baseline
如果使用Pytorch,則需要按照如下步驟進行:
定義數(shù)據(jù)集
定義模型
模型訓(xùn)練和預(yù)測
class BiendataDataset(Dataset): def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path
self.img_label = img_label
self.transform = transform def __getitem__(self, index): try:
img = Image.open(self.img_path[index]).convert('RGB') except:
index = 0
img = Image.open(self.img_path[index]).convert('RGB') if self.transform is not None:
img = self.transform(img)
label = torch.from_numpy(np.array([self.img_label[index]])) return img, label def __len__(self): return len(self.img_path)預(yù)訓(xùn)練模型推薦使用efficientnet,模型精度會更好。
import timm
model = timm.create_model('efficientnet_b4', num_classes=6,
pretrained=True, in_chans=3)具體的數(shù)據(jù)擴增方法為:
transforms.Compose([
transforms.Resize((300, 300)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomAffine(5, scale=[0.95, 1.05]),
transforms.RandomCrop((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])TF2.0版本baseline
如果使用TF2.0,則更加簡單:
定義ImageDataGenerator
定義模型
模型訓(xùn)練和預(yù)測
模型加載代碼為:
from efficientnet.tfkeras import EfficientNetB4
models = EfficientNetB4(weights='imagenet', include_top=False)具體的數(shù)據(jù)擴增方法為:
train_datagen = ImageDataGenerator(
rescale=1. / 255, # 歸一化
rotation_range=45, # 旋轉(zhuǎn)角度
width_shift_range=0.1, # 水平偏移
height_shift_range=0.1, # 垂直偏移
shear_range=0.1, # 隨機錯切變換的角度
zoom_range=0.25, # 隨機縮放的范圍
horizontal_flip=True, # 隨機將一半圖像水平翻轉(zhuǎn)
fill_mode='nearest' # 填充像素的方法
)05 賽題上分思路
如果使用baseline的思路,則可以取得線上0.85的成績。如果還想取得更優(yōu)的成績,可以考慮如下操作:
對數(shù)據(jù)集圖片的主體物體進行定位&檢測。
通過五折交叉驗證,訓(xùn)練得到5個模型然后對測試集進行投票。
對測試集結(jié)果進行數(shù)據(jù)擴增,然后進行投票。

?baseline地址:
https://www.biendata.xyz/media/download_file/21771129e38ed3f5b565af858fcd80b1.zip
參賽辦法
掃描下方二維碼或點擊“閱讀原文“
