<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          使用PyTorch來進展不平衡數(shù)據(jù)集的圖像分類

          共 20674字,需瀏覽 42分鐘

           ·

          2021-05-13 07:55


          作者:Marek Paulik

          編譯:ronghuaiyang

          來源:AI公園

          導讀

          一個非常簡單和容易上手的例子。


          對于教程中使用的大多數(shù)人工數(shù)據(jù)集,每個類都有相同數(shù)量的數(shù)據(jù)。然而,在實際應(yīng)用中,這種情況很少發(fā)生。今天,我將給你介紹來自Kaggle的木薯葉分類,并告訴你當類頻率有很大差異時該怎么做。

          處理類別的不平衡

          有兩種方法可以解決這個問題。

          • WeightedRandomSampler
          • loss函數(shù)中的weight參數(shù)

          下一步是創(chuàng)建一個有5個方法的CassavaClassifier類:load_data()、load_model()、fit_one_epoch()、val_one_epoch()和fit()。

          在load_data()中,將構(gòu)造一個train和驗證數(shù)據(jù)集,并返回數(shù)據(jù)加載器以供進一步使用。

          在load_model()中定義了體系結(jié)構(gòu)、損失函數(shù)和優(yōu)化器。

          fit方法包含一些初始化和對fit_one_epoch()和val_one_epoch()的循環(huán)。

          早期停止

          早期停止類有助于根據(jù)驗證損失跟蹤最佳模型,并保存檢查點。

          #Callbacks
          # Early stopping
          class EarlyStopping:
            def __init__(self, patience=1, delta=0, path='checkpoint.pt'):
              self.patience = patience
              self.delta = delta
              self.path= path
              self.counter = 0
              self.best_score = None
              self.early_stop = False

            def __call__(self, val_loss, model):
              if self.best_score is None:
                self.best_score = val_loss
                self.save_checkpoint(model)
              elif val_loss > self.best_score:
                self.counter +=1
                if self.counter >= self.patience:
                  self.early_stop = True 
              else:
                self.best_score = val_loss
                self.save_checkpoint(model)
                self.counter = 0      

            def save_checkpoint(self, model):
              torch.save(model.state_dict(), self.path)

          Init

          我們首先初始化CassavaClassifier類。

          class CassavaClassifier():
              def __init__(self, data_dir, num_classes, device, Transform=None, sample=False, loss_weights=False, batch_size=16,
               lr=1e-4, stop_early=True, freeze_backbone=True)
          :

              #############################################################################################################
              # data_dir - directory with images in subfolders, subfolders name are categories
              # Transform - data augmentations
              # sample - if the dataset is imbalanced set to true and RandomWeightedSampler will be used
              # loss_weights - if the dataset is imbalanced set to true and weight parameter will be passed to loss function
              # freeze_backbone - if using pretrained architecture freeze all but the classification layer
              ###############################################################################################################
                  self.data_dir = data_dir
                  self.num_classes = num_classes
                  self.device = device
                  self.sample = sample
                  self.loss_weights = loss_weights
                  self.batch_size = batch_size
                  self.lr = lr
                  self.stop_early = stop_early
                  self.freeze_backbone = freeze_backbone
                  self.Transform = Transform

          Load Data

          訓練圖像被組織在子文件夾中,子文件夾名稱表示圖像的類。這是圖像分類問題的典型情況,幸運的是,不需要編寫自定義數(shù)據(jù)集類。在這種情況下,可以立即使用torchvision中的ImageFolder。如果你想使用WeightedRandomSampler,你需要為數(shù)據(jù)集的每個元素指定一個權(quán)重。通常,總圖像總比上類別數(shù)被用作一個權(quán)重。

          def load_data(self):
              train_full = torchvision.datasets.ImageFolder(self.data_dir, transform=self.Transform)
              train_set, val_set = random_split(train_full, [math.floor(len(train_full)*0.8), math.ceil(len(train_full)*0.2)])

              self.train_classes = [label for _, label in train_set]
              if self.sample:
                  # Need to get weight for every image in the dataset
                  class_count = Counter(self.train_classes)
                  class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values]) 
                  # Can't iterate over class_count because dictionary is unordered

                  sample_weights = [0] * len(train_set)
                  for idx, (image, label) in enumerate(train_set):
                      class_weight = class_weights[label]
                      sample_weights[idx] = class_weight

                  sampler = WeightedRandomSampler(weights=sample_weights,
                                                  num_samples = len(train_set), replacement=True)  
                  train_loader = DataLoader(train_set, batch_size=self.batch_size, sampler=sampler)
              else:
                  train_loader = DataLoader(train_set, batch_size=self.batch_size, shuffle=True)

              val_loader = DataLoader(val_set, batch_size=self.batch_size)

              return train_loader, val_loader

          Load Model

          在該方法中,我使用遷移學習,架構(gòu)參數(shù)從預先訓練的resnet50和efficientnet-b7中選擇。CrossEntropyLoss和許多其他損失函數(shù)都有權(quán)重參數(shù)。這是一個手動調(diào)整參數(shù),用于處理不平衡。在這種情況下,不需要為每個參數(shù)定義權(quán)重,只需為每個類定義權(quán)重。

          def load_model(self, arch='resnet'):
              ##############################################################################################################
              # arch - choose the pretrained architecture from resnet or efficientnetb7
              ############################################################################################################## 
              if arch == 'resnet':
                  self.model = torchvision.models.resnet50(pretrained=True)
                  if self.freeze_backbone:
                      for param in self.model.parameters():
                          param.requires_grad = False
                  self.model.fc = nn.Linear(in_features=self.model.fc.in_features, out_features=self.num_classes)
              elif arch == 'efficient-net':
                  self.model = EfficientNet.from_pretrained('efficientnet-b7')
                  if self.freeze_backbone:
                      for param in self.model.parameters():
                          param.requires_grad = False
                  self.model._fc = nn.Linear(in_features=self.model._fc.in_features, out_features=self.num_classes)    

              self.model = self.model.to(self.device)

              self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr) 

              if self.loss_weights:
                  class_count = Counter(self.train_classes)
                  class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values])
                  # Cant iterate over class_count because dictionary is unordered
                  class_weights = class_weights.to(self.device)  
                  self.criterion = nn.CrossEntropyLoss(class_weights)
              else:
                  self.criterion = nn.CrossEntropyLoss() 

          Fit One Epoch

          這個方法只包含一個經(jīng)典的訓練循環(huán),帶有訓練損失記錄和tqdm進度條。

          def fit_one_epoch(self, train_loader, epoch, num_epochs ): 
              step_train = 0

              train_losses = list() # Every epoch check average loss per batch 
              train_acc = list()
              self.model.train()
              for i, (images, targets) in enumerate(tqdm(train_loader)):
                  images = images.to(self.device)
                  targets = targets.to(self.device)

                  logits = self.model(images)
                  loss = self.criterion(logits, targets)

                  loss.backward()
                  self.optimizer.step()

                  self.optimizer.zero_grad()

                  train_losses.append(loss.item())

                  #Calculate running train accuracy
                  predictions = torch.argmax(logits, dim=1)
                  num_correct = sum(predictions.eq(targets))
                  running_train_acc = float(num_correct) / float(images.shape[0])
                  train_acc.append(running_train_acc)
                  
              train_loss = torch.tensor(train_losses).mean()    
              print(f'Epoch {epoch}/{num_epochs-1}')  
              print(f'Training loss: {train_loss:.2f}')

          Validate one epoch

          與上面類似,但此方法在驗證數(shù)據(jù)加載器上迭代。在每一個epoch'之后,平均batch損失和準確性被打印出來。

          def val_one_epoch(self, val_loader, scaler):
                  val_losses = list()
                  val_accs = list()
                  self.model.eval()
                  step_val = 0
                  with torch.no_grad():
                      for (images, targets) in val_loader:
                          images = images.to(self.device)
                          targets = targets.to(self.device)

                          logits = self.model(images)
                          loss = self.criterion(logits, targets)
                          val_losses.append(loss.item())      
                      
                          predictions = torch.argmax(logits, dim=1)
                          num_correct = sum(predictions.eq(targets))
                          running_val_acc = float(num_correct) / float(images.shape[0])

                          val_accs.append(running_val_acc)
                      

                      self.val_loss = torch.tensor(val_losses).mean()
                      val_acc = torch.tensor(val_accs).mean() # Average acc per batch
                  
                      print(f'Validation loss: {self.val_loss:.2f}')  
                      print(f'Validation accuracy: {val_acc:.2f}'

          Fit

          Fit方法在訓練和驗證過程中經(jīng)歷了許多階段和循環(huán)。如果預訓練模型的參數(shù)在開始時被凍結(jié),那么unfreeze_after定義了整個模型在多少個epoch之后開始訓練。在此之前,只訓練全連接層(分類器)。

          def fit(self, train_loader, val_loader, num_epochs=10, unfreeze_after=5, checkpoint_dir='checkpoint.pt'):
              if self.stop_early:
                  early_stopping = EarlyStopping(
                  patience=5
                  path=checkpoint_dir)
            
              for epoch in range(num_epochs):
                  if self.freeze_backbone:
                      if epoch == unfreeze_after:  # Unfreeze after x epochs
                          for param in self.model.parameters():
                              param.requires_grad = True
                  self.fit_one_epoch(train_loader, scaler, epoch, num_epochs)
                  self.val_one_epoch(val_loader, scaler)
                  if self.stop_early:
                      early_stopping(self.val_loss, self.model)
                      if early_stopping.early_stop:
                          print('Early Stopping')
                          print(f'Best validation loss: {early_stopping.best_score}')
                          break

          Run

          現(xiàn)在,可以初始化CassavaClassifier類、創(chuàng)建dataloaders、設(shè)置模型并運行整個過程了。

          Transform = T.Compose(
                              [T.ToTensor(),
                              T.Resize((256256)),
                              T.RandomRotation(90),
                              T.RandomHorizontalFlip(p=0.5),
                              T.Normalize((0.4850.4560.406), (0.2290.2240.225))])

          device = torch.device('cuda'if torch.cuda.is_available() else torch.device('cpu')
          data_dir = "Data/cassava-disease/train/train"

          classifier = CassavaClassifier(data_dir=data_dir, num_classes=5, device=device, sample=True, Transform=Transform)
          train_loader, val_loader = classifier.load_data()
          classifier.load_model()
          classifier.fit(num_epochs=20, unfreeze_after=5, train_loader=train_loader, val_loader=val_loader)

          Inference

          使用ImageFolder加載測試數(shù)據(jù)是不可能的,因為顯然沒有帶有類的子文件夾。因此,我創(chuàng)建了一個返回圖像和圖像id的自定義數(shù)據(jù)集。隨后,加載模型檢查點,通過推理循環(huán)運行它,并將預測保存到數(shù)據(jù)幀中。將數(shù)據(jù)幀導出為CSV并提交結(jié)果。

          # Inference
          model = torchvision.models.resnet50()
          #model = EfficientNet.from_name('efficientnet-b7')
          model.fc = nn.Linear(in_features=model.fc.in_features, out_features=5)
          model = model.to(device)
          checkpoint = torch.load('Data/cassava-disease/sampler_checkpoint.pt')
          model.load_state_dict(checkpoint)
          model.eval()


          # Dataset for test data
          class Cassava_Test(Dataset):
            def __init__(self, dir, transform=None):
              self.dir = dir
              self.transform = transform

              self.images = os.listdir(self.dir)  

            def __len__(self):
              return len(self.images)

            def __getitem__(self, idx):
              img = Image.open(os.path.join(self.dir, self.images[idx]))
              return self.transform(img), self.images[idx] 


          test_dir = 'Data/cassava-disease/test/test/0'
          test_set = Cassava_Test(test_dir, transform=Transform)
          test_loader = DataLoader(test_set, batch_size=4)  

          # Test loop
          sub = pd.DataFrame(columns=['category''id'])
          id_list = []
          pred_list = []

          model = model.to(device)

          with torch.no_grad():
            for (image, image_id) in test_loader:
              image = image.to(device)

              logits = model(image)
              predicted = list(torch.argmax(logits, 1).cpu().numpy())

              for id in image_id:
                id_list.append(id)
            
              for prediction in predicted:
                pred_list.append(prediction)
          sub['category'] = pred_list
          sub['id'] = id_list

          mapping = {0:'cbb'1:'cbsd'2:'cgm'3:'cmd'4:'healthy'}

          sub['category'] = sub['category'].map(mapping)
          sub = sub.sort_values(by='id')

          sub.to_csv('Cassava_sub.csv', index=False)

          如果在方案中包含WeightedRandomSampler或損失權(quán)值,則測試集的精度會提高2%。對于僅僅幾行代碼來說,這是一個很好的改進。對于這個數(shù)據(jù)集,我沒有看到這兩種方法在精度上的巨大差異,但WeightedRandomSampler的表現(xiàn)要好一些。

          不同的學習速度、優(yōu)化器和數(shù)據(jù)擴展肯定有自己的發(fā)展空間。然而,對于這種簡單的方法來說,86%的準確率似乎足夠好了。


          END

          英文原文:https://marekpaulik.medium.com/imbalanced-dataset-image-classification-with-pytorch-6de864982eb1


          喜歡的話,請給我個在看吧

          瀏覽 33
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  有码精品一区二区 | 狼友视频官网 | 麻豆人妻少妇精品无码专区 | 91亚洲国产成人精品一区 | 熟女少妇内射日韩亚洲 |