Share | Peach Classification with the ResNet50_vd Image Classification Network
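As technology advances rapidly, artificial intelligence has worked its way into every aspect of daily life. Chinese agriculture has benefited as well and has entered a period of rapid development. Today, satellite remote sensing tracks crop growth, intelligent sorting systems grade fruit, and automated machinery handles irrigation and fertilization.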

This exercise uses peach classification to give readers a first look at image classification, and to show how to use PaddleHub to build a classic convolutional neural network.
Solution design

Environment setup and preparation
!pip install paddlehub==2.0.4 -i https://pypi.tuna.tsinghua.edu.cn/simple
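Data processing

├─data: data directory
  ├─train_list.txt: training set list
  ├─test_list.txt: test set list
  ├─validate_list.txt: validation set list
  ├─label_list.txt: label list
  └─……
Each line of train_list.txt, test_list.txt, and validate_list.txt has the form:
path of image 1 label of image 1
path of image 2 label of image 2
...
Each line of label_list.txt has the form:
name of class 1
name of class 2
...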
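The list formats above can be read back with a few lines of standard-library Python. The following is a minimal sketch, not part of the original tutorial: the helper names `read_label_list` and `read_data_list` and the sample file contents are hypothetical, and it assumes that the integer label on each line is an index into label_list.txt.

```python
import os
import tempfile


def read_label_list(path):
    """Return the class names, one per non-empty line."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def read_data_list(path):
    """Return (image_path, label_id) pairs from a list file."""
    samples = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Split on the last space so that paths containing spaces still parse.
            img_path, label = line.rsplit(" ", 1)
            samples.append((img_path, int(label)))
    return samples


# Hypothetical sample files, written to a temporary directory for illustration.
demo_dir = tempfile.mkdtemp()
label_file = os.path.join(demo_dir, "label_list.txt")
train_file = os.path.join(demo_dir, "train_list.txt")
with open(label_file, "w", encoding="utf-8") as f:
    f.write("R0\nB1\nM2\nS3\n")
with open(train_file, "w", encoding="utf-8") as f:
    f.write("train/R0/1.png 0\ntrain/S3/7.png 3\n\n")

labels = read_label_list(label_file)
samples = read_data_list(train_file)
# Assumption: the label id is the position of the class name in label_list.txt.
named = [(path, labels[label_id]) for path, label_id in samples]
print(named)
```

Splitting on the last space is a small defensive choice: it keeps the parser working even if an image path itself contains a space.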
!unzip -q -o ./data/data67225/peach.zip -d ./work
The dataset class consists of three parts: __init__, __getitem__, and __len__. An example follows:

# coding: utf-8
import os

import paddle
import paddlehub as hub


class DemoDataset(paddle.io.Dataset):
    def __init__(self, transforms, num_classes=4, mode='train'):
        # Location of the dataset; dataset_dir must be the full path to the actual data
        self.dataset_dir = "./work/peach-classification"
        self.transforms = transforms
        self.num_classes = num_classes
        self.mode = mode

        # Pick the list file that matches the requested split
        if self.mode == 'train':
            self.file = 'train_list.txt'
        elif self.mode == 'test':
            self.file = 'test_list.txt'
        else:
            self.file = 'validate_list.txt'
        self.file = os.path.join(self.dataset_dir, self.file)

        # Keep every non-empty "path label" line
        self.data = []
        with open(self.file, 'r') as f:
            for line in f.readlines():
                line = line.strip()
                if line != '':
                    self.data.append(line)

    def __getitem__(self, idx):
        # Each entry is "relative_image_path label"
        img_path, grt = self.data[idx].split(' ')
        img_path = os.path.join(self.dataset_dir, img_path)
        im = self.transforms(img_path)
        return im, int(grt)

    def __len__(self):
        return len(self.data)
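import paddlehub.vision.transforms as T

# Resize to 256x256, crop the central 224x224 region, then normalize with the ImageNet mean and std
transforms = T.Compose(
    [T.Resize((256, 256)),
     T.CenterCrop(224),
     T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])],
    to_rgb=True)

peach_train = DemoDataset(transforms)
# Any mode other than 'train' or 'test' selects validate_list.txt
peach_validate = DemoDataset(transforms, mode='val')
peach_test = DemoDataset(transforms, mode='test')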
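To make the preprocessing numbers concrete, the arithmetic behind the center crop and the normalization can be worked through by hand. This is an illustrative sketch in plain Python, not PaddleHub's implementation: the helpers `center_crop_box` and `normalize_pixel` are hypothetical, and it assumes that pixel values are scaled to [0, 1] before the mean is subtracted, as is conventional with these ImageNet statistics.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def center_crop_box(width, height, crop_size):
    """Return (left, top, right, bottom) of a centered square crop."""
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2
    return left, top, left + crop_size, top + crop_size


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return [(c / 255.0 - m) / s for c, m, s in zip(rgb, mean, std)]


# A 256x256 image cropped to 224x224 loses a 16-pixel border on every side.
box = center_crop_box(256, 256, 224)
# An arbitrary example pixel: full red, mid green, no blue.
normalized = normalize_pixel((255, 128, 0))
print(box, normalized)
```

After normalization each channel is roughly zero-centered, which matches the input distribution the ImageNet-pretrained weights were trained on.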
Model construction
# Install the pretrained model
!hub install resnet50_vd_imagenet_ssld==1.1.0
# Load the model
import paddlehub as hub
model = hub.Module(name='resnet50_vd_imagenet_ssld', label_list=["R0", "B1", "M2", "S3"])
Model training
from paddlehub.finetune.trainer import Trainer
import paddle
optimizer = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='img_classification_ckpt', use_gpu=True)
trainer.train(peach_train, epochs=10, batch_size=16, eval_dataset=peach_validate, save_interval=1)
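Optimizer parameters:
learning_rate: global learning rate, 1e-3 by default;
parameters: the model parameters to optimize.
Run configuration
Parameters of Trainer:
model: the model to optimize;
optimizer: the optimizer to use;
use_vdl: whether to visualize the training process with VisualDL;
checkpoint_dir: directory where model parameters are saved;
compare_metrics: the metric used to select the best model.
Parameters of trainer.train:
train_dataset: the dataset used for training;
epochs: number of training epochs;
batch_size: training batch size; when using a GPU, adjust it to fit the available memory;
num_workers: number of data-loading workers, 0 by default;
eval_dataset: the validation set;
log_interval: logging interval, measured in training batches;
save_interval: checkpoint interval, measured in training epochs.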
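The interaction between batch_size, epochs, and save_interval is easy to check with a little arithmetic. The sketch below is only illustrative: `training_plan` is a hypothetical helper, the sample count of 1000 is invented (the original does not state the dataset size), and it assumes the last partial batch of each epoch is kept rather than dropped.

```python
import math


def training_plan(num_samples, batch_size, epochs, save_interval):
    """Estimate step counts and the epochs after which a checkpoint is saved."""
    # Assumption: the final partial batch is kept, hence the ceiling.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    checkpoint_epochs = [e for e in range(1, epochs + 1) if e % save_interval == 0]
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": steps_per_epoch * epochs,
        "checkpoint_epochs": checkpoint_epochs,
    }


# batch_size=16, epochs=10, save_interval=1 come from the training call above;
# num_samples=1000 is a made-up figure for illustration.
plan = training_plan(num_samples=1000, batch_size=16, epochs=10, save_interval=1)
sparse_plan = training_plan(num_samples=1000, batch_size=16, epochs=10, save_interval=3)
print(plan)
print(sparse_plan["checkpoint_epochs"])
```

A larger save_interval trades disk usage against the granularity with which training can be resumed.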
Model evaluation
# Evaluate the model on the test set with a batch size of 16
trainer.evaluate(peach_test, 16)
Model inference
import paddle
import paddlehub as hub
from PIL import Image
import matplotlib.pyplot as plt
img_path = './work/test.jpg'
img = Image.open(img_path)
plt.imshow(img)
plt.axis('off')
plt.show()
result = model.predict([img_path])
print("The peach is predicted as: {}".format(result))
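The printed result can be reduced to a single label per image. The sketch below assumes that the prediction result is a list with one dictionary per input image, mapping class names to scores; that shape is an assumption about the module's output and should be checked against the actual printout. The helper `top_prediction` and the sample values are hypothetical.

```python
def top_prediction(prediction):
    """Return the (label, score) pair with the highest score in one prediction dict."""
    label = max(prediction, key=prediction.get)
    return label, prediction[label]


# Invented output for a batch of two images, in the assumed shape.
sample_result = [{"M2": 0.93}, {"R0": 0.61, "B1": 0.30}]
best = [top_prediction(p) for p in sample_result]
for label, score in best:
    print("predicted class: {} (score {:.2f})".format(label, score))
```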
Model deployment
{
  "modules_info": {
    "resnet50_vd_imagenet_ssld": {
      "init_args": {
        "version": "1.1.0",
        "label_list": ["R0", "B1", "M2", "S3"],
        "load_checkpoint": "img_classification_ckpt/best_model/model.pdparams"
      },
      "predict_args": {
        "batch_size": 1
      }
    }
  },
  "port": 8866,
  "gpu": "0"
}
$ hub serving start --config config.json
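Writing config.json by hand invites quoting and comma mistakes, so it can help to generate it from Python. The following is a minimal sketch using only the standard library; the helper `build_serving_config` is hypothetical, the values are copied from the configuration above, and the file is written to a temporary directory purely for illustration.

```python
import json
import os
import tempfile


def build_serving_config(module_name, version, label_list, checkpoint,
                         port=8866, gpu="0"):
    """Assemble a serving configuration with the same layout as config.json above."""
    return {
        "modules_info": {
            module_name: {
                "init_args": {
                    "version": version,
                    "label_list": label_list,
                    "load_checkpoint": checkpoint,
                },
                "predict_args": {"batch_size": 1},
            }
        },
        "port": port,
        "gpu": gpu,
    }


config = build_serving_config(
    module_name="resnet50_vd_imagenet_ssld",
    version="1.1.0",
    label_list=["R0", "B1", "M2", "S3"],
    checkpoint="img_classification_ckpt/best_model/model.pdparams",
)

# Write the file, then read it back to confirm it is valid JSON.
config_path = os.path.join(tempfile.mkdtemp(), "config.json")
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)
with open(config_path, "r", encoding="utf-8") as f:
    reloaded = json.load(f)
print(reloaded["port"])
```

In practice the file would be written next to the checkpoint directory and passed to the serving command via --config.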
import base64
import json

import cv2
import numpy as np
import requests


def cv2_to_base64(image):
    # Encode the image as JPEG, then as a base64 string
    data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tobytes()).decode('utf8')


def base64_to_cv2(b64str):
    # Decode a base64 string back into a BGR image array
    data = base64.b64decode(b64str.encode('utf8'))
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data


# Send the HTTP request
org_im = cv2.imread('/PATH/TO/IMAGE')
data = {'images': [cv2_to_base64(org_im)], 'top_k': 1}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/resnet50_vd_imagenet_ssld"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
data = r.json()["results"]['data']
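The request body is just JSON carrying base64 text, so the encoding step can be checked without OpenCV or a running server. The sketch below uses only the standard library; `build_payload` and `decode_first_image` are hypothetical helpers, and the byte string stands in for real image data. Sending file bytes directly instead of re-encoding with OpenCV assumes the server only needs decodable image bytes.

```python
import base64
import json


def build_payload(image_bytes, top_k=1):
    """Serialize raw image bytes into the JSON body used by the request above."""
    encoded = base64.b64encode(image_bytes).decode("utf8")
    return json.dumps({"images": [encoded], "top_k": top_k})


def decode_first_image(payload):
    """Recover the raw bytes of the first image from a JSON body."""
    body = json.loads(payload)
    return base64.b64decode(body["images"][0].encode("utf8"))


# Placeholder bytes; a real client would read them from an image file.
fake_image = b"\xff\xd8\xff\xe0" + b"hypothetical image bytes"
payload = build_payload(fake_image)
restored = decode_first_image(payload)
print(len(payload), restored == fake_image)
```

The round trip confirms that base64 encoding is lossless, so any difference between the local image and what the server sees would come from the JPEG re-encoding step, not from the transport.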
