第12章 PyTorch圖像分割代碼框架-3:推理與部署
推理模塊
模型訓(xùn)練完成后,需要單獨(dú)再寫一個推理模塊來供用戶測試或者使用,該模塊可以命名為test.py或者inference.py,導(dǎo)入訓(xùn)練好的模型文件和待測試的圖像,輸出該圖像的分割結(jié)果。inference.py主體部分如代碼11-7所示。
代碼11-7 推理模塊部分
# 導(dǎo)入相關(guān)庫import numpy as npimport torchfrom PIL import Image# 定義推理函數(shù)def inference(model, test_img):img = Image.open(test_img)img = val_transform(img)img = img.unsqueeze(0).to('cuda')with torch.no_grad():outputs = model(img)preds = outputs.detach().max(dim=1)[1].cpu().numpy()print(preds.shape)pred = VOCSegmentation.decode_target(preds[0]).astype(np.uint8)Image.fromarray(pred).save(os.path.join('s%_pred.png' % test_img.split('.')[0]))
上述代碼僅展示推理模塊的主體部分,完整代碼可參考本書配套的對應(yīng)章節(jié)代碼文件。實(shí)際執(zhí)行時,我們可以在命令行通過傳入待測試圖像和模型文件執(zhí)行inference.py。測試示例如下:
python inference.py --data_root 2007_000676.jpg --model deeplabv3plus_resnet101
測試圖像和模型預(yù)測結(jié)果示例如圖11-5所示。

部署模塊
雖然我們可以通過推理模塊來測試模型效果,但推理畢竟不是面向用戶級的使用體驗(yàn)。為了能夠在常見的用戶端使用我們的分割模型,還需要對模型進(jìn)行工程化的部署(deployment)。根據(jù)分割模型的應(yīng)用場景,一般最常見的部署場景是web端部署或者是基于C++的軟件集成部署。web端部署一般基于Flask等后端部署框架來完成,形式上可以分為為REST API和web應(yīng)用兩種表現(xiàn)形式。
一個web服務(wù)簡單而言就是用戶從客戶端發(fā)送一個HTTP請求,然后服務(wù)器收到請求后生成HTML文檔作為響應(yīng)返回給客戶端的過程。當(dāng)返回的內(nèi)容需要在前端頁面上呈現(xiàn)時,這個服務(wù)就是一個web端的應(yīng)用;當(dāng)返回內(nèi)容不需要在前端頁面體現(xiàn),而是直接以JSON等數(shù)據(jù)結(jié)構(gòu)給用戶時,這個服務(wù)就是一個REST API。
Flask是一個基于Python的輕量級web應(yīng)用框架,非常簡潔和靈活,也便于初學(xué)者快速上手。簡單幾行代碼即可快速定義一個web服務(wù),如代碼11-8所示。
# 導(dǎo)入flask相關(guān)模塊from flask import Flask, jsonify# 創(chuàng)建應(yīng)用app = Flask(__name__)# 定義預(yù)測路由def predict():return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
本節(jié)我們將分別展示基于PASCAL VOC 2012訓(xùn)練的Deeplab v3+模型的REST API和web應(yīng)用部署方式。
REST API部署
基于REST API部署相對較為簡單,我們可以直接編寫一個api.py的文件,將推理流程融入到Flask的預(yù)測路由函數(shù)中即可。在此之前需要先導(dǎo)入訓(xùn)練好的模型以及定義跟驗(yàn)證時同樣的數(shù)據(jù)轉(zhuǎn)換方法。基于上述策略可定義api.py如下:
代碼11-9 REST API部署
# 導(dǎo)入相關(guān)庫import torchfrom torchvision import transformsfrom PIL import Imageimport ioimport numpy as npfrom utils import ext_transforms_new as etfrom datasets import VOCSegmentationfrom flask import Flask, request, jsonifyimport models# 創(chuàng)建應(yīng)用app = Flask(__name__)# 模型字典model_map = {'deeplabv3plus_resnet50': models.deeplabv3plus_resnet50,'deeplabv3plus_resnet101': models.deeplabv3plus_resnet101,}# 創(chuàng)建模型model = model_map['deeplabv3plus_resnet101'](num_classes=21,output_stride=16)# 導(dǎo)入模型model.load_state_dict(torch.load('../checkpoints/deeplabv3plus_resnet101_voc.pth')['model_state'])model.to('cuda')model.eval()print('model loaded.')# 定義數(shù)據(jù)轉(zhuǎn)換方法transform = et.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),])# 定義模型預(yù)測路由def predict():if request.method == 'POST':# 從請求中讀取輸入圖像image = request.files['image'].read()image = Image.open(io.BytesIO(image))# 圖像變換input_tensor = transform(image).unsqueeze(0).to('cuda')# 模型預(yù)測with torch.no_grad():output = model(input_tensor)preds = output.detach().max(dim=1)[1].cpu().numpy()print(preds.shape)# 對輸出進(jìn)行解碼,轉(zhuǎn)換為maskpreds = VOCSegmentation.decode_target(preds[0]).astype(np.uint8)# 轉(zhuǎn)換成listresult = preds.tolist()return jsonify(result)if __name__ == '__main__':app.run(debug=True)
定義好app.py以后,直接在命令行啟動該REST API服務(wù):
python app.py
然后再單獨(dú)啟動一個Python終端,通過requests庫發(fā)起post請求,傳入一張待分割圖像:
resp = requests.post("http://localhost:5000/predict", files={"image": open('./deployment/2007_000676.jpg', 'rb')})
這時候我們可以在服務(wù)端看到相關(guān)響應(yīng)信息,如圖11-6所示。狀態(tài)碼顯示為200,說明請求成功,返回數(shù)據(jù)可以在requests返回對象中查看。

web端部署
REST API的部署方式更多的是方便開發(fā)者使用,對于普通用戶可能不是那么友好。為了更加方便用戶使用和更直觀的展示模型效果,我們可以通過web端部署的方式,讓用戶上傳圖像作為輸入,并將輸入圖像和分割結(jié)果直接在網(wǎng)頁上顯示。所以與API部署方式不同的是需要加上一個index.html的網(wǎng)頁模板文件,將輸入和分割結(jié)果在網(wǎng)頁模板上進(jìn)行渲染。同時原先的api.py文件也需要進(jìn)行修改,修改后的文件可命名為app.py,主體部分如代碼11-10所示。
代碼11-10 web端應(yīng)用app.py
# 創(chuàng)建應(yīng)用app = Flask(__name__)# 定義上傳和預(yù)測路由@app.route('/', methods=['GET', 'POST'])def upload_predict():# POST請求后讀取圖像if request.method == 'POST':image_file = request.files['image']if image_file:image_location = os.path.join(app.config['UPLOAD_FOLDER'],image_file.filename)image_file.save(image_location)# 圖像變換image = Image.open(image_location).convert('RGB')input_tensor = transform(image).unsqueeze(0).to('cuda')# 模型預(yù)測with torch.no_grad():output = model(input_tensor)preds = output.detach().max(dim=1)[1].cpu().numpy()print(preds.shape)# mask解碼preds = VOCSegmentation.decode_target(preds[0]).astype(np.uint8)# 保存圖像到指定路徑segmented_image = Image.fromarray(preds)segmented_image_path = image_location.replace('.jpg', '_segmented.jpg')segmented_image.save(segmented_image_path)display_input_path = '../' + image_locationdisplay_segmented_path = '../' + segmented_image_path# 渲染結(jié)果到網(wǎng)頁return render_template('index.html', input_image=display_input_path, segmented_image=display_segmented_path)return render_template('index.html')
代碼11-10與api.py的主要區(qū)別在于讀取圖像部分是需要讀取用戶上傳到指定目錄下的圖像,并且對輸入圖像和分割結(jié)果渲染呈現(xiàn)到網(wǎng)頁端。index.html是網(wǎng)頁HTML的模板文件,我們可以通過編輯該文件來實(shí)現(xiàn)自己想要的網(wǎng)頁效果。
執(zhí)行app.py文件啟動web服務(wù),然后打開服務(wù)運(yùn)行地址:http:127.0.0.1:5000即可看到網(wǎng)頁端效果,在網(wǎng)頁點(diǎn)擊“選擇文件”上傳輸入圖像,然后點(diǎn)擊“Segment”執(zhí)行模型分割,圖11-7為web部署后的使用效果圖。

總結(jié)
本章以PASCAL VOC 2012數(shù)據(jù)集和Deeplab v3+分割模型為例,給出了基于PyTorch的深度學(xué)習(xí)圖像分割項(xiàng)目代碼框架。一個相對完整的圖像分割代碼框架應(yīng)包含:預(yù)處理模塊、數(shù)據(jù)導(dǎo)入模塊、模型模塊、工具函數(shù)模塊、配置模塊、主函數(shù)模塊、推理模塊和部署模塊。啟中預(yù)處理、數(shù)據(jù)導(dǎo)入、模型、工具函數(shù)、配置和主函數(shù)模塊均為模型訓(xùn)練階段的工作模塊,而推理和部署則屬于模型訓(xùn)練完后的測試和使用階段工作模塊。
需要特別說明的是,本章的代碼框架僅作為深度學(xué)習(xí)圖像分割項(xiàng)目的一般性框架,具體使用時應(yīng)根據(jù)項(xiàng)目的實(shí)際情況酌情參考。
后續(xù)全書內(nèi)容和代碼將在github上開源,請關(guān)注倉庫:
https://github.com/luwill/Deep-Learning-Image-Segmentation
(本章完結(jié),其余章節(jié)待續(xù))
