基于 PyTorch 的人臉關(guān)鍵點(diǎn)檢測
共 12432字,需瀏覽 25分鐘
·
2024-07-13 10:05
點(diǎn)擊上方“小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)
重磅干貨,第一時(shí)間送達(dá)
計(jì)算機(jī)真的能理解人臉嗎?你是否想過Instagram是如何給你的臉上應(yīng)用驚人的濾鏡的?該軟件檢測你臉上的關(guān)鍵點(diǎn)并在其上投影一個(gè)遮罩。本教程將文章你如何使用PyTorch構(gòu)建一個(gè)類似的軟件。
數(shù)據(jù)集
在本教程中,我們將使用官方的DLib數(shù)據(jù)集,其中包含6666張尺寸不同的圖像。此外,labels_ibug_300W_train.xml(隨數(shù)據(jù)集提供)包含每張人臉的68個(gè)關(guān)鍵點(diǎn)的坐標(biāo)。下面的腳本將在Colab筆記本中下載數(shù)據(jù)集并解壓縮。
if not os.path.exists('/content/ibug_300W_large_face_landmark_dataset'):!wget http://dlib.net/files/data/ibug_300W_large_face_landmark_dataset.tar.gz!tar -xvzf 'ibug_300W_large_face_landmark_dataset.tar.gz'!rm -r 'ibug_300W_large_face_landmark_dataset.tar.gz'
這是數(shù)據(jù)集中的一張樣本圖像。我們可以看到,人臉只占整個(gè)圖像的一小部分。如果我們將完整圖像輸入神經(jīng)網(wǎng)絡(luò),它也會處理背景(無關(guān)信息),這會使模型難以學(xué)習(xí)。因此,我們需要裁剪圖像,僅輸入人臉部分。
數(shù)據(jù)集中的樣本圖像和關(guān)鍵點(diǎn)
數(shù)據(jù)預(yù)處理
為了防止神經(jīng)網(wǎng)絡(luò)過擬合訓(xùn)練數(shù)據(jù)集,我們需要隨機(jī)變換數(shù)據(jù)集。我們將對訓(xùn)練和驗(yàn)證數(shù)據(jù)集應(yīng)用以下操作:
由于人臉只占整個(gè)圖像的一小部分,所以裁剪圖像并僅使用人臉進(jìn)行訓(xùn)練。
將裁剪后的人臉調(diào)整為(224x224)的圖像。
隨機(jī)改變調(diào)整后的人臉的亮度和飽和度。
在上述三個(gè)轉(zhuǎn)換之后,隨機(jī)旋轉(zhuǎn)人臉。
將圖像和關(guān)鍵點(diǎn)轉(zhuǎn)換為torch張量,并在[-1, 1]之間進(jìn)行歸一化。
class Transforms():def __init__(self):passdef rotate(self, image, landmarks, angle):angle = random.uniform(-angle, +angle)transformation_matrix = torch.tensor([[+cos(radians(angle)), -sin(radians(angle))],[+sin(radians(angle)), +cos(radians(angle))]])image = imutils.rotate(np.array(image), angle)landmarks = landmarks - 0.5new_landmarks = np.matmul(landmarks, transformation_matrix)new_landmarks = new_landmarks + 0.5return Image.fromarray(image), new_landmarksdef resize(self, image, landmarks, img_size):image = TF.resize(image, img_size)return image, landmarksdef color_jitter(self, image, landmarks):color_jitter = transforms.ColorJitter(brightness=0.3,contrast=0.3,saturation=0.3,hue=0.1)image = color_jitter(image)return image, landmarksdef crop_face(self, image, landmarks, crops):left = int(crops['left'])top = int(crops['top'])width = int(crops['width'])height = int(crops['height'])image = TF.crop(image, top, left, height, width)img_shape = np.array(image).shapelandmarks = torch.tensor(landmarks) - torch.tensor([[left, top]])landmarks = landmarks / torch.tensor([img_shape[1], img_shape[0]])return image, landmarksdef __call__(self, image, landmarks, crops):image = Image.fromarray(image)image, landmarks = self.crop_face(image, landmarks, crops)image, landmarks = self.resize(image, landmarks, (224, 224))image, landmarks = self.color_jitter(image, landmarks)image, landmarks = self.rotate(image, landmarks, angle=10)image = TF.to_tensor(image)image = TF.normalize(image, [0.5], [0.5])return image, landmarks
數(shù)據(jù)集類
現(xiàn)在我們已經(jīng)準(zhǔn)備好了轉(zhuǎn)換,讓我們編寫我們的數(shù)據(jù)集類。labels_ibug_300W_train.xml包含圖像路徑、關(guān)鍵點(diǎn)和邊界框的坐標(biāo)(用于裁剪人臉)。我們將這些值存儲在列表中,以便在訓(xùn)練期間輕松訪問。在本文章中,神經(jīng)網(wǎng)絡(luò)將在灰度圖像上進(jìn)行訓(xùn)練。
class FaceLandmarksDataset(Dataset):def __init__(self, transform=None):tree = ET.parse('ibug_300W_large_face_landmark_dataset/labels_ibug_300W_train.xml')root = tree.getroot()self.image_filenames = []self.landmarks = []self.crops = []self.transform = transformself.root_dir = 'ibug_300W_large_face_landmark_dataset'for filename in root[2]:self.image_filenames.append(os.path.join(self.root_dir, filename.attrib['file']))self.crops.append(filename[0].attrib)landmark = []for num in range(68):x_coordinate = int(filename[0][num].attrib['x'])y_coordinate = int(filename[0][num].attrib['y'])landmark.append([x_coordinate, y_coordinate])self.landmarks.append(landmark)self.landmarks = np.array(self.landmarks).astype('float32')assert len(self.image_filenames) == len(self.landmarks)def __len__(self):return len(self.image_filenames)def __getitem__(self, index):image = cv2.imread(self.image_filenames[index], 0)landmarks = self.landmarks[index]if self.transform:image, landmarks = self.transform(image, landmarks, self.crops[index])landmarks = landmarks - 0.5return image, landmarksdataset = FaceLandmarksDataset(Transforms())
注意:landmarks = landmarks - 0.5 是為了將關(guān)鍵點(diǎn)居中,因?yàn)橹行幕妮敵鰧ι窠?jīng)網(wǎng)絡(luò)學(xué)習(xí)更容易。經(jīng)過預(yù)處理后的數(shù)據(jù)集輸出如下所示(關(guān)鍵點(diǎn)已經(jīng)繪制在圖像中):
預(yù)處理后的數(shù)據(jù)樣本
神經(jīng)網(wǎng)絡(luò)
我們將使用ResNet18作為基本框架。我們需要修改第一層和最后一層以適應(yīng)我們的目的。在第一層中,我們將輸入通道數(shù)設(shè)為1,以便神經(jīng)網(wǎng)絡(luò)接受灰度圖像。同樣,在最后一層中,輸出通道數(shù)應(yīng)為68 * 2 = 136,以便模型預(yù)測每張人臉的68個(gè)關(guān)鍵點(diǎn)的(x,y)坐標(biāo)。
class Network(nn.Module):def __init__(self,num_classes=136):super().__init__()self.model_name='resnet18'self.model=models.resnet18()self.model.conv1=nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)self.model.fc=nn.Linear(self.model.fc.in_features, num_classes)def forward(self, x):x=self.model(x)return x
訓(xùn)練神經(jīng)網(wǎng)絡(luò)
我們將使用預(yù)測關(guān)鍵點(diǎn)和真實(shí)關(guān)鍵點(diǎn)之間的均方誤差作為損失函數(shù)。請記住,要避免梯度爆炸,學(xué)習(xí)率應(yīng)保持低。每當(dāng)驗(yàn)證損失達(dá)到新的最小值時(shí),網(wǎng)絡(luò)權(quán)重將被保存。至少訓(xùn)練20個(gè)epochs以獲得最佳性能。
network = Network()criterion = nn.MSELoss()optimizer = optim.Adam(network.parameters(), lr=0.0001)loss_min = np.infnum_epochs = 10start_time = time.time()for epoch in range(1,num_epochs+1):loss_train = 0loss_valid = 0running_loss = 0network.train()for step in range(1,len(train_loader)+1):landmarks = next(iter(train_loader))images = images.cuda()landmarks = landmarks.view(landmarks.size(0),-1).cuda()predictions = network(images)# clear all the gradients before calculating themoptimizer.zero_grad()# find the loss for the current steploss_train_step = criterion(predictions, landmarks)# calculate the gradientsloss_train_step.backward()# update the parametersoptimizer.step()loss_train += loss_train_step.item()running_loss = loss_train/steplen(train_loader), running_loss, 'train')with torch.no_grad():for step in range(1,len(valid_loader)+1):landmarks = next(iter(valid_loader))images = images.cuda()landmarks = landmarks.view(landmarks.size(0),-1).cuda()predictions = network(images)# find the loss for the current steploss_valid_step = criterion(predictions, landmarks)loss_valid += loss_valid_step.item()running_loss = loss_valid/steplen(valid_loader), running_loss, 'valid')loss_train /= len(train_loader)loss_valid /= len(valid_loader)print('\n--------------------------------------------------'): {} Train Loss: {:.4f} Valid Loss: {:.4f}'.format(epoch, loss_train, loss_valid))print('--------------------------------------------------')if loss_valid < loss_min:loss_min = loss_valid'/content/face_landmarks.pth')Validation Loss of {:.4f} at epoch {}/{}".format(loss_min, epoch, num_epochs))Saved\n')Complete')Elapsed Time : {} s".format(time.time()-start_time))
在未知數(shù)據(jù)上進(jìn)行預(yù)測
使用以下代碼段在未知圖像中預(yù)測關(guān)鍵點(diǎn)。
import timeimport cv2import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Imageimport imutilsimport torchimport torch.nn as nnfrom torchvision import modelsimport torchvision.transforms.functional as TF#######################################################################image_path = 'pic.jpg'weights_path = 'face_landmarks.pth'frontal_face_cascade_path = 'haarcascade_frontalface_default.xml'#######################################################################class Network(nn.Module):def __init__(self,num_classes=136):super().__init__()self.model_name='resnet18'self.model=models.resnet18(pretrained=False)self.model.conv1=nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)self.model.fc=nn.Linear(self.model.fc.in_features,num_classes)def forward(self, x):x=self.model(x)return x#######################################################################face_cascade = cv2.CascadeClassifier(frontal_face_cascade_path)best_network = Network()best_network.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))best_network.eval()image = cv2.imread(image_path)grayscale_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)display_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)height, width,_ = image.shapefaces = face_cascade.detectMultiScale(grayscale_image, 1.1, 4)all_landmarks = []for (x, y, w, h) in faces:image = grayscale_image[y:y+h, x:x+w]image = TF.resize(Image.fromarray(image), size=(224, 224))image = TF.to_tensor(image)image = TF.normalize(image, [0.5], [0.5])with torch.no_grad():landmarks = best_network(image.unsqueeze(0))landmarks = (landmarks.view(68,2).detach().numpy() + 0.5) * np.array([[w, h]]) + np.array([[x, y]])all_landmarks.append(landmarks)plt.figure()plt.imshow(display_image)for landmarks in all_landmarks:plt.scatter(landmarks[:,0], landmarks[:,1], c = 'c', s = 5)plt.show()
OpenCV Haar級聯(lián)分類器用于檢測圖像中的人臉。使用Haar級聯(lián)進(jìn)行對象檢測是一種基于機(jī)器學(xué)習(xí)的方法,其中使用一組輸入數(shù)據(jù)對級聯(lián)函數(shù)進(jìn)行訓(xùn)練。OpenCV已經(jīng)包含了許多預(yù)訓(xùn)練的分類器,用于人臉、眼睛、行人等等。在我們的案例中,我們將使用人臉分類器,你需要下載預(yù)訓(xùn)練的分類器XML文件并將其保存到你的工作目錄中。
人臉檢測
在輸入圖像中檢測到的人臉將被裁剪、調(diào)整大小為(224,224)并輸入到我們訓(xùn)練好的神經(jīng)網(wǎng)絡(luò)中以預(yù)測其中的關(guān)鍵點(diǎn)。
裁剪人臉上的關(guān)鍵點(diǎn)
在裁剪的人臉上疊加預(yù)測的關(guān)鍵點(diǎn)。結(jié)果如下圖所示。相當(dāng)令人印象深刻,不是嗎?
最終結(jié)果
同樣,在多個(gè)人臉上進(jìn)行關(guān)鍵點(diǎn)檢測:
在這里,你可以看到OpenCV Haar級聯(lián)分類器已經(jīng)檢測到了多個(gè)人臉,包括一個(gè)誤報(bào)(一個(gè)拳頭被預(yù)測為人臉)。
下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程
在「小白學(xué)視覺」公眾號后臺回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺、目標(biāo)跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。
下載2:Python視覺實(shí)戰(zhàn)項(xiàng)目52講
在「小白學(xué)視覺」公眾號后臺回復(fù):Python視覺實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計(jì)數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個(gè)視覺實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺。
下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講
在「小白學(xué)視覺」公眾號后臺回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。
交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測、分割、識別、醫(yī)學(xué)影像、GAN、算法競賽等微信群(以后會逐漸細(xì)分),請掃描下面微信號加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~
