Deploying a DeepLabv3+ Model with OpenCV4
本文轉(zhuǎn)自:opencv學(xué)堂
As discussed in earlier posts, OpenCV DNN supports more than image classification and object detection models; it can also run various custom models. DeepLabv3 is one of the most commonly used models for image semantic segmentation. This article demonstrates how to call a DeepLabv3 model from OpenCV DNN to perform semantic segmentation, with MobileNet and Xception as the supported backbone networks. The pretrained models can be downloaded here:
https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md
After downloading a pretrained model you will find a pb file and ckpt files; the pb file is the one that can be loaded directly.

Download the MobileNet version of the DeepLabv3 model, then convert the MobileNetV2 ckpt to a pb file with the following script:
# Optional flags (disabled for the mobilenet_v2 export):
#   --atrous_rates=6 --atrous_rates=12 --atrous_rates=18 --output_stride=16
python deeplab/export_model.py \
  --logtostderr \
  --checkpoint_path="/home/lw/data/cityscapes/train/model.ckpt-2000" \
  --export_path="/home/lw/data/pb/frozen_inference_graph.pb" \
  --model_variant="mobilenet_v2" \
  --decoder_output_stride=4 \
  --num_classes=6 \
  --crop_size=513 \
  --crop_size=513 \
  --inference_scales=1.0

Loading the converted MobileNetV2 pb model directly with OpenCV, however, fails with an error.
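For reference, the failing call is just a direct load of the exported pb (a minimal sketch; the exact error message depends on your OpenCV build):

import cv2

# Naive attempt: load the exported (unoptimized) frozen graph directly.
# With this graph, OpenCV's TensorFlow importer reports an error.
net = cv2.dnn.readNetFromTensorflow("frozen_inference_graph.pb")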

The workaround for MobileNetV2 is to optimize the frozen graph first:
import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph
from tensorflow.python.tools import optimize_for_inference_lib

graph = 'frozen_inference_graph.pb'
with tf.gfile.FastGFile(graph, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Optional: dump the graph so it can be inspected with TensorBoard.
tf.summary.FileWriter('logs', graph_def)

inp_node = 'MobilenetV2/MobilenetV2/input'   # input node
out_node = 'logits/semantic/BiasAdd'         # output node
graph_def = optimize_for_inference_lib.optimize_for_inference(
    graph_def, [inp_node], [out_node], tf.float32.as_datatype_enum)
graph_def = TransformGraph(graph_def, [inp_node], [out_node], ["sort_by_execution_order"])

with tf.gfile.FastGFile('frozen_inference_graph_opt.pb', 'wb') as f:
    f.write(graph_def.SerializeToString())
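If your own export ends up with different node names, the names used above can be confirmed by listing the nodes of the frozen graph (a small TensorFlow 1.x sketch consistent with the scripts above), or by opening the 'logs' directory written above in TensorBoard:

import tensorflow as tf

# Print node names from the frozen graph to verify inp_node / out_node.
graph_def = tf.GraphDef()
with tf.gfile.FastGFile('frozen_inference_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

print('first nodes:')
for node in graph_def.node[:10]:    # the input placeholder is usually near the top
    print(' ', node.name, node.op)
print('last nodes:')
for node in graph_def.node[-10:]:   # the semantic logits / BiasAdd are near the end
    print(' ', node.name, node.op)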
The workaround for Xception is the same except for the input node name:
import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph
from tensorflow.python.tools import optimize_for_inference_lib

graph = 'frozen_inference_graph.pb'
with tf.gfile.FastGFile(graph, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Optional: dump the graph so it can be inspected with TensorBoard.
tf.summary.FileWriter('logs', graph_def)

inp_node = 'sub_2'                    # starting (input) node
out_node = 'logits/semantic/BiasAdd'  # ending (output) node
graph_def = optimize_for_inference_lib.optimize_for_inference(
    graph_def, [inp_node], [out_node], tf.float32.as_datatype_enum)
graph_def = TransformGraph(graph_def, [inp_node], [out_node], ["sort_by_execution_order"])

with tf.gfile.FastGFile('frozen_inference_graph_opt.pb', 'wb') as f:
    f.write(graph_def.SerializeToString())

Load the optimized pb file with OpenCV DNN and run inference on an image:

import cv2
import numpy as np
np.random.seed(0)
color = np.random.randint(0, 255, size=[150, 3], dtype=np.uint8)  # uint8 so the color map can be resized by cv2 later
print(color)
# Load names of classes
#classes = None
#with open("labels.names", 'rt') as f:
#    classes = f.read().rstrip('\n').split('\n')
#legend = None
#def showLegend(classes):
#    global legend
#    if not classes is None and legend is None:
#        blockHeight = 30
#        print(len(classes), len(colors))
#        assert(len(classes) == len(colors))
#        legend = np.zeros((blockHeight * len(colors), 200, 3), np.uint8)
#        for i in range(len(classes)):
#            block = legend[i * blockHeight:(i + 1) * blockHeight]
#            block[:, :] = colors[i]
#            cv2.putText(block, classes[i], (0, blockHeight//2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255))
#        cv2.namedWindow('Legend', cv2.WINDOW_NORMAL)
#        cv2.imshow('Legend', legend)
#        cv2.waitKey()
# read the input image
frame = cv2.imread("1.jpg")
frameHeight = frame.shape[0]
frameWidth = frame.shape[1]
# load the optimized model
net = cv2.dnn.readNet("frozen_inference_graph_opt.pb")
# subtract mean 127.5 and scale by 1/127.5 (≈ 0.007843), mapping pixels to [-1, 1]
blob = cv2.dnn.blobFromImage(frame, 0.007843, (513, 513), (127.5, 127.5, 127.5), swapRB=True)
net.setInput(blob)
score = net.forward()            # output shape: [1, numClasses, height, width]
numClasses = score.shape[1]
height = score.shape[2]
width = score.shape[3]
classIds = np.argmax(score[0], axis=0)   # per-pixel argmax over the class dimension
segm = np.stack([color[idx] for idx in classIds.flatten()])
segm = segm.reshape(height, width, 3)
segm = cv2.resize(segm, (frameWidth, frameHeight), interpolation=cv2.INTER_NEAREST)
frame = (0.3 * frame + 0.7 * segm).astype(np.uint8)  # blend; weights sum to 1 so values stay in uint8 range
#showLegend(classes)
cv2.imshow("img", frame)
cv2.waitKey()
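As an optional sanity check (a small sketch using the classIds computed above), print how many pixels were assigned to each class and save the blended result:

# Count how many pixels were assigned to each class id.
ids, counts = np.unique(classIds, return_counts=True)
for cls_id, cnt in zip(ids, counts):
    print("class %d: %d pixels" % (cls_id, cnt))

# Save the blended visualization to disk.
cv2.imwrite("result.png", frame)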