使用PyTorch來進展不平衡數(shù)據(jù)集的圖像分類

作者:Marek Paulik
編譯:ronghuaiyang
來源:AI公園
一個非常簡單和容易上手的例子。
對于教程中使用的大多數(shù)人工數(shù)據(jù)集,每個類都有相同數(shù)量的數(shù)據(jù)。然而,在實際應(yīng)用中,這種情況很少發(fā)生。今天,我將給你介紹來自Kaggle的木薯葉分類,并告訴你當類頻率有很大差異時該怎么做。
處理類別的不平衡
有兩種方法可以解決這個問題。
WeightedRandomSampler loss函數(shù)中的weight參數(shù)
下一步是創(chuàng)建一個有5個方法的CassavaClassifier類:load_data()、load_model()、fit_one_epoch()、val_one_epoch()和fit()。
在load_data()中,將構(gòu)造一個train和驗證數(shù)據(jù)集,并返回數(shù)據(jù)加載器以供進一步使用。
在load_model()中定義了體系結(jié)構(gòu)、損失函數(shù)和優(yōu)化器。
fit方法包含一些初始化和對fit_one_epoch()和val_one_epoch()的循環(huán)。
早期停止
早期停止類有助于根據(jù)驗證損失跟蹤最佳模型,并保存檢查點。
#Callbacks
# Early stopping
class EarlyStopping:
def __init__(self, patience=1, delta=0, path='checkpoint.pt'):
self.patience = patience
self.delta = delta
self.path= path
self.counter = 0
self.best_score = None
self.early_stop = False
def __call__(self, val_loss, model):
if self.best_score is None:
self.best_score = val_loss
self.save_checkpoint(model)
elif val_loss > self.best_score:
self.counter +=1
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = val_loss
self.save_checkpoint(model)
self.counter = 0
def save_checkpoint(self, model):
torch.save(model.state_dict(), self.path)
Init
我們首先初始化CassavaClassifier類。
class CassavaClassifier():
def __init__(self, data_dir, num_classes, device, Transform=None, sample=False, loss_weights=False, batch_size=16,
lr=1e-4, stop_early=True, freeze_backbone=True):
#############################################################################################################
# data_dir - directory with images in subfolders, subfolders name are categories
# Transform - data augmentations
# sample - if the dataset is imbalanced set to true and RandomWeightedSampler will be used
# loss_weights - if the dataset is imbalanced set to true and weight parameter will be passed to loss function
# freeze_backbone - if using pretrained architecture freeze all but the classification layer
###############################################################################################################
self.data_dir = data_dir
self.num_classes = num_classes
self.device = device
self.sample = sample
self.loss_weights = loss_weights
self.batch_size = batch_size
self.lr = lr
self.stop_early = stop_early
self.freeze_backbone = freeze_backbone
self.Transform = Transform
Load Data
訓練圖像被組織在子文件夾中,子文件夾名稱表示圖像的類。這是圖像分類問題的典型情況,幸運的是,不需要編寫自定義數(shù)據(jù)集類。在這種情況下,可以立即使用torchvision中的ImageFolder。如果你想使用WeightedRandomSampler,你需要為數(shù)據(jù)集的每個元素指定一個權(quán)重。通常,總圖像總比上類別數(shù)被用作一個權(quán)重。
def load_data(self):
train_full = torchvision.datasets.ImageFolder(self.data_dir, transform=self.Transform)
train_set, val_set = random_split(train_full, [math.floor(len(train_full)*0.8), math.ceil(len(train_full)*0.2)])
self.train_classes = [label for _, label in train_set]
if self.sample:
# Need to get weight for every image in the dataset
class_count = Counter(self.train_classes)
class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values])
# Can't iterate over class_count because dictionary is unordered
sample_weights = [0] * len(train_set)
for idx, (image, label) in enumerate(train_set):
class_weight = class_weights[label]
sample_weights[idx] = class_weight
sampler = WeightedRandomSampler(weights=sample_weights,
num_samples = len(train_set), replacement=True)
train_loader = DataLoader(train_set, batch_size=self.batch_size, sampler=sampler)
else:
train_loader = DataLoader(train_set, batch_size=self.batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=self.batch_size)
return train_loader, val_loader
Load Model
在該方法中,我使用遷移學習,架構(gòu)參數(shù)從預先訓練的resnet50和efficientnet-b7中選擇。CrossEntropyLoss和許多其他損失函數(shù)都有權(quán)重參數(shù)。這是一個手動調(diào)整參數(shù),用于處理不平衡。在這種情況下,不需要為每個參數(shù)定義權(quán)重,只需為每個類定義權(quán)重。
def load_model(self, arch='resnet'):
##############################################################################################################
# arch - choose the pretrained architecture from resnet or efficientnetb7
##############################################################################################################
if arch == 'resnet':
self.model = torchvision.models.resnet50(pretrained=True)
if self.freeze_backbone:
for param in self.model.parameters():
param.requires_grad = False
self.model.fc = nn.Linear(in_features=self.model.fc.in_features, out_features=self.num_classes)
elif arch == 'efficient-net':
self.model = EfficientNet.from_pretrained('efficientnet-b7')
if self.freeze_backbone:
for param in self.model.parameters():
param.requires_grad = False
self.model._fc = nn.Linear(in_features=self.model._fc.in_features, out_features=self.num_classes)
self.model = self.model.to(self.device)
self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr)
if self.loss_weights:
class_count = Counter(self.train_classes)
class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values])
# Cant iterate over class_count because dictionary is unordered
class_weights = class_weights.to(self.device)
self.criterion = nn.CrossEntropyLoss(class_weights)
else:
self.criterion = nn.CrossEntropyLoss()
Fit One Epoch
這個方法只包含一個經(jīng)典的訓練循環(huán),帶有訓練損失記錄和tqdm進度條。
def fit_one_epoch(self, train_loader, epoch, num_epochs ):
step_train = 0
train_losses = list() # Every epoch check average loss per batch
train_acc = list()
self.model.train()
for i, (images, targets) in enumerate(tqdm(train_loader)):
images = images.to(self.device)
targets = targets.to(self.device)
logits = self.model(images)
loss = self.criterion(logits, targets)
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
train_losses.append(loss.item())
#Calculate running train accuracy
predictions = torch.argmax(logits, dim=1)
num_correct = sum(predictions.eq(targets))
running_train_acc = float(num_correct) / float(images.shape[0])
train_acc.append(running_train_acc)
train_loss = torch.tensor(train_losses).mean()
print(f'Epoch {epoch}/{num_epochs-1}')
print(f'Training loss: {train_loss:.2f}')
Validate one epoch
與上面類似,但此方法在驗證數(shù)據(jù)加載器上迭代。在每一個epoch'之后,平均batch損失和準確性被打印出來。
def val_one_epoch(self, val_loader, scaler):
val_losses = list()
val_accs = list()
self.model.eval()
step_val = 0
with torch.no_grad():
for (images, targets) in val_loader:
images = images.to(self.device)
targets = targets.to(self.device)
logits = self.model(images)
loss = self.criterion(logits, targets)
val_losses.append(loss.item())
predictions = torch.argmax(logits, dim=1)
num_correct = sum(predictions.eq(targets))
running_val_acc = float(num_correct) / float(images.shape[0])
val_accs.append(running_val_acc)
self.val_loss = torch.tensor(val_losses).mean()
val_acc = torch.tensor(val_accs).mean() # Average acc per batch
print(f'Validation loss: {self.val_loss:.2f}')
print(f'Validation accuracy: {val_acc:.2f}')
Fit
Fit方法在訓練和驗證過程中經(jīng)歷了許多階段和循環(huán)。如果預訓練模型的參數(shù)在開始時被凍結(jié),那么unfreeze_after定義了整個模型在多少個epoch之后開始訓練。在此之前,只訓練全連接層(分類器)。
def fit(self, train_loader, val_loader, num_epochs=10, unfreeze_after=5, checkpoint_dir='checkpoint.pt'):
if self.stop_early:
early_stopping = EarlyStopping(
patience=5,
path=checkpoint_dir)
for epoch in range(num_epochs):
if self.freeze_backbone:
if epoch == unfreeze_after: # Unfreeze after x epochs
for param in self.model.parameters():
param.requires_grad = True
self.fit_one_epoch(train_loader, scaler, epoch, num_epochs)
self.val_one_epoch(val_loader, scaler)
if self.stop_early:
early_stopping(self.val_loss, self.model)
if early_stopping.early_stop:
print('Early Stopping')
print(f'Best validation loss: {early_stopping.best_score}')
break
Run
現(xiàn)在,可以初始化CassavaClassifier類、創(chuàng)建dataloaders、設(shè)置模型并運行整個過程了。
Transform = T.Compose(
[T.ToTensor(),
T.Resize((256, 256)),
T.RandomRotation(90),
T.RandomHorizontalFlip(p=0.5),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
data_dir = "Data/cassava-disease/train/train"
classifier = CassavaClassifier(data_dir=data_dir, num_classes=5, device=device, sample=True, Transform=Transform)
train_loader, val_loader = classifier.load_data()
classifier.load_model()
classifier.fit(num_epochs=20, unfreeze_after=5, train_loader=train_loader, val_loader=val_loader)
Inference
使用ImageFolder加載測試數(shù)據(jù)是不可能的,因為顯然沒有帶有類的子文件夾。因此,我創(chuàng)建了一個返回圖像和圖像id的自定義數(shù)據(jù)集。隨后,加載模型檢查點,通過推理循環(huán)運行它,并將預測保存到數(shù)據(jù)幀中。將數(shù)據(jù)幀導出為CSV并提交結(jié)果。
# Inference
model = torchvision.models.resnet50()
#model = EfficientNet.from_name('efficientnet-b7')
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=5)
model = model.to(device)
checkpoint = torch.load('Data/cassava-disease/sampler_checkpoint.pt')
model.load_state_dict(checkpoint)
model.eval()
# Dataset for test data
class Cassava_Test(Dataset):
def __init__(self, dir, transform=None):
self.dir = dir
self.transform = transform
self.images = os.listdir(self.dir)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img = Image.open(os.path.join(self.dir, self.images[idx]))
return self.transform(img), self.images[idx]
test_dir = 'Data/cassava-disease/test/test/0'
test_set = Cassava_Test(test_dir, transform=Transform)
test_loader = DataLoader(test_set, batch_size=4)
# Test loop
sub = pd.DataFrame(columns=['category', 'id'])
id_list = []
pred_list = []
model = model.to(device)
with torch.no_grad():
for (image, image_id) in test_loader:
image = image.to(device)
logits = model(image)
predicted = list(torch.argmax(logits, 1).cpu().numpy())
for id in image_id:
id_list.append(id)
for prediction in predicted:
pred_list.append(prediction)
sub['category'] = pred_list
sub['id'] = id_list
mapping = {0:'cbb', 1:'cbsd', 2:'cgm', 3:'cmd', 4:'healthy'}
sub['category'] = sub['category'].map(mapping)
sub = sub.sort_values(by='id')
sub.to_csv('Cassava_sub.csv', index=False)
如果在方案中包含WeightedRandomSampler或損失權(quán)值,則測試集的精度會提高2%。對于僅僅幾行代碼來說,這是一個很好的改進。對于這個數(shù)據(jù)集,我沒有看到這兩種方法在精度上的巨大差異,但WeightedRandomSampler的表現(xiàn)要好一些。
不同的學習速度、優(yōu)化器和數(shù)據(jù)擴展肯定有自己的發(fā)展空間。然而,對于這種簡單的方法來說,86%的準確率似乎足夠好了。

英文原文:https://marekpaulik.medium.com/imbalanced-dataset-image-classification-with-pytorch-6de864982eb1

喜歡的話,請給我個在看吧!
