Hands-On ViT Fine-Tuning
2024-08-01 10:05
Heavyweight content, delivered to you first.
Exploring CIFAR-10 Image Classification
Introduction
You have surely heard of "Attention Is All You Need". Transformers started out in text and are now everywhere, including images, in the form of the Vision Transformer (ViT), first introduced in the paper "An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale". This is not just another flashy trend; ViTs have proven to be strong contenders, rivaling traditional models such as convolutional neural networks (CNNs).
At a high level, ViT works as follows:
- Split the image into patches, and pass those patches through a fully connected (FC) network, or FC+CNN, to obtain the input embeddings.
- Add positional information.
- Feed the result into a standard Transformer encoder, with an FC layer attached at the end.
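The patch-splitting step above can be sketched in plain NumPy. This is only an illustration of the geometry, not the model's actual implementation; the `image_to_patches` helper and the shapes are assumptions based on ViT-base/16 with 224x224 inputs (14x14 = 196 patches, each flattened to 3*16*16 = 768 values):

```python
import numpy as np

def image_to_patches(img, patch=16):
    # img: (C, H, W) array; split into non-overlapping patch x patch squares
    c, h, w = img.shape
    n_h, n_w = h // patch, w // patch
    x = img.reshape(c, n_h, patch, n_w, patch)
    # reorder to (patch_row, patch_col, C, patch, patch), then flatten each patch
    x = x.transpose(1, 3, 0, 2, 4).reshape(n_h * n_w, c * patch * patch)
    return x  # (num_patches, patch_dim)

rng = np.random.default_rng(0)
img = rng.standard_normal((3, 224, 224))
patches = image_to_patches(img)
print(patches.shape)  # (196, 768)
```

Each 768-dimensional row is what the embedding layer then projects and augments with positional information.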
Problem Description
We fine-tune a pretrained ViT on a subset of CIFAR-10 to classify 32x32 images into 10 categories.
Setting Up the Environment
!pip install torch torchvision
!pip install transformers datasets
!pip install transformers[torch]
# PyTorch
import torch
import torchvision
from torchvision.transforms import Normalize, Resize, ToTensor, Compose

# For displaying images
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage

# Loading the dataset
from datasets import load_dataset

# Transformers
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers import TrainingArguments, Trainer

# Matrix operations
import numpy as np

# Evaluation
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
Data Preprocessing
trainds, testds = load_dataset("cifar10", split=["train[:5000]", "test[:1000]"])
splits = trainds.train_test_split(test_size=0.1)
trainds = splits['train']
valds = splits['test']
trainds, valds, testds
# Output
(Dataset({features: ['img', 'label'], num_rows: 4500}),
 Dataset({features: ['img', 'label'], num_rows: 500}),
 Dataset({features: ['img', 'label'], num_rows: 1000}))
trainds.features, trainds.num_rows, trainds[0]
# Output
({'img': Image(decode=True, id=None),
  'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], id=None)},
 4500,
 {'img': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,
  'label': 0})
itos = dict((k, v) for k, v in enumerate(trainds.features['label'].names))
stoi = dict((v, k) for k, v in enumerate(trainds.features['label'].names))
itos
# Output
{0: 'airplane',
 1: 'automobile',
 2: 'bird',
 3: 'cat',
 4: 'deer',
 5: 'dog',
 6: 'frog',
 7: 'horse',
 8: 'ship',
 9: 'truck'}
index = 0
img, lab = trainds[index]['img'], itos[trainds[index]['label']]
print(lab)
img
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
mu, sigma = processor.image_mean, processor.image_std  # get default mu, sigma
size = processor.size
norm = Normalize(mean=mu, std=sigma)  # normalize pixel values to the range [-1, 1]

# resize 3x32x32 to 3x224x224 -> convert to PyTorch tensor -> normalize
_transf = Compose([
    Resize(size['height']),
    ToTensor(),
    norm
])

# apply the transforms to each PIL image and store the result under the 'pixels' key
def transf(arg):
    arg['pixels'] = [_transf(image.convert('RGB')) for image in arg['img']]
    return arg
trainds.set_transform(transf)
valds.set_transform(transf)
testds.set_transform(transf)
idx = 0
ex = trainds[idx]['pixels']
ex = (ex + 1) / 2  # imshow requires pixel values in the range [0, 1]
exi = ToPILImage()(ex)
plt.imshow(exi)
plt.show()
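As a sanity check on the normalization: with this checkpoint's defaults (mean = std = 0.5 per channel), `(x - 0.5) / 0.5` maps [0, 1] to [-1, 1], and `(x + 1) / 2` inverts it for display. A minimal NumPy sketch of that round trip:

```python
import numpy as np

x = np.array([0.0, 0.25, 0.5, 1.0])  # pixel values after ToTensor(), in [0, 1]
mu, sigma = 0.5, 0.5                  # this checkpoint's default image_mean / image_std
normed = (x - mu) / sigma             # -> [-1.0, -0.5, 0.0, 1.0]
restored = (normed + 1) / 2           # the inverse mapping used before imshow
print(normed, restored)
```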
Fine-Tuning the Model
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name)
print(model.classifier)
# Output
Linear(in_features=768, out_features=1000, bias=True)
model = ViTForImageClassification.from_pretrained(model_name,
                                                  num_labels=10,
                                                  ignore_mismatched_sizes=True,
                                                  id2label=itos,
                                                  label2id=stoi)
print(model.classifier)
# Output
Linear(in_features=768, out_features=10, bias=True)
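The swapped-in head is just a linear map from the 768-dimensional pooled representation to 10 logits. A rough NumPy sketch of that final layer (the weights here are random placeholders, not the trained values):

```python
import numpy as np

rng = np.random.default_rng(0)
cls_embedding = rng.standard_normal(768)   # stand-in for ViT's pooled [CLS] feature
W = rng.standard_normal((10, 768)) * 0.02  # freshly initialized 768 -> 10 head
b = np.zeros(10)
logits = W @ cls_embedding + b             # one logit per CIFAR-10 class
print(logits.shape)
```

Because only this head changes shape, `ignore_mismatched_sizes=True` lets the pretrained backbone load while the new 10-way layer starts from random initialization.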
Hugging Face Trainer
args = TrainingArguments(
    "test-cifar-10",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir='logs',
    remove_unused_columns=False,
)
def collate_fn(examples):
    pixels = torch.stack([example["pixels"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixels, "labels": labels}

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return dict(accuracy=accuracy_score(labels, predictions))
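`compute_metrics` receives raw logits, so the argmax turns them into class indices before scoring. A tiny standalone example of that logic, with made-up logits:

```python
import numpy as np

logits = np.array([[2.0, 0.1, -1.0],   # argmax -> class 0
                   [0.2, 1.5, 0.3]])   # argmax -> class 1
labels = np.array([0, 2])              # ground truth
preds = np.argmax(logits, axis=1)
accuracy = float((preds == labels).mean())  # one of two correct
print(preds, accuracy)  # [0 1] 0.5
```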
trainer = Trainer(
    model,
    args=args,
    train_dataset=trainds,
    eval_dataset=valds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)
Training the Model
trainer.train()
# Output
TrainOutput(global_step=675, training_loss=0.22329048227380824,
            metrics={'train_runtime': 1357.9833, 'train_samples_per_second': 9.941,
                     'train_steps_per_second': 0.497, 'total_flos': 1.046216869705728e+18,
                     'train_loss': 0.22329048227380824, 'epoch': 3.0})
Evaluation
outputs = trainer.predict(testds)
print(outputs.metrics)
# Output
{'test_loss': 0.07223748415708542, 'test_accuracy': 0.973, 'test_runtime': 28.5169,
 'test_samples_per_second': 35.067, 'test_steps_per_second': 4.383}
itos[np.argmax(outputs.predictions[0])], itos[outputs.label_ids[0]]
y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)
labels = trainds.features['label'].names
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot(xticks_rotation=45)
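Beyond the plot, the confusion matrix also yields per-class recall directly: the diagonal divided by each row sum. A small sketch with a made-up 2x2 matrix (not the actual CIFAR-10 results):

```python
import numpy as np

cm = np.array([[8, 2],    # row = true class, column = predicted class
               [1, 9]])
per_class_recall = cm.diagonal() / cm.sum(axis=1)
print(per_class_recall)   # [0.8 0.9]
```

This is handy for spotting which CIFAR-10 classes (e.g. cat vs. dog) the fine-tuned model still confuses.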