<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          圖像分類訓(xùn)練技巧之?dāng)?shù)據(jù)增強(qiáng)總結(jié)

          共 38479字,需瀏覽 77分鐘

           ·

          2023-08-18 20:19

          點(diǎn)擊上方小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂

          重磅干貨,第一時(shí)間送達(dá)

          僅作學(xué)術(shù)分享,不代表本公眾號立場,侵權(quán)聯(lián)系刪除
          轉(zhuǎn)載于:作者丨小小將@知乎(已授權(quán))
          來源丨h(huán)ttps://zhuanlan.zhihu.com/p/430563265
          編輯丨極市平臺
          一個(gè)模型的性能除了和網(wǎng)絡(luò)結(jié)構(gòu)本身有關(guān),還非常依賴具體的訓(xùn)練策略,比如優(yōu)化器,數(shù)據(jù)增強(qiáng)以及正則化策略等(當(dāng)然也很訓(xùn)練數(shù)據(jù)強(qiáng)相關(guān),訓(xùn)練數(shù)據(jù)量往往決定模型性能的上線)。近年來,圖像分類模型在ImageNet數(shù)據(jù)集的top1 acc已經(jīng)由原來的56.5(AlexNet,2012)提升至90.88(CoAtNet,2021,用了額外的數(shù)據(jù)集JFT-3B),這進(jìn)步除了主要?dú)w功于模型,算力和數(shù)據(jù)的提升,也與訓(xùn)練策略的提升緊密相關(guān)。最近剛興起的vision transformer相比CNN模型往往也需要更heavy的數(shù)據(jù)增強(qiáng)和正則化策略。這里簡單介紹圖像分類訓(xùn)練技巧中的常用數(shù)據(jù)增強(qiáng)策略。

          baseline

          ImageNet數(shù)據(jù)集訓(xùn)練常用的數(shù)據(jù)增強(qiáng)策略如下,訓(xùn)練過程的數(shù)據(jù)增強(qiáng)包括隨機(jī)縮放裁剪(RandomResizedCrop,這種處理方式源自谷歌的Inception,所以稱為 Inception-style pre-processing)和水平翻轉(zhuǎn)(RandomHorizontalFlip),而測試階段是執(zhí)行縮放和中心裁剪。這其實(shí)是一種輕量級的策略,這里稱之為baseline。torchvision的實(shí)現(xiàn)的ResNet50訓(xùn)練采用的策略就是這個(gè),在ImageNet上的top1 acc可以達(dá)到76.1。

          from torchvision import transforms

          normalize = transforms.Normalize(mean=[0.4850.4560.406],
                                           std=[0.2290.2240.225])
          # 訓(xùn)練
          train_transform = transforms.Compose([
              # 這里的scale指的是面積,ratio是寬高比
              # 具體實(shí)現(xiàn)每次先隨機(jī)確定scale和ratio,可以生成w和h,然后隨機(jī)確定裁剪位置進(jìn)行crop
              # 最后是resize到target size
              transforms.RandomResizedCrop(224, scale=(0.081.0), ratio=(3. / 4.4. / 3.)),
              transforms.RandomHorizontalFlip(),
              transforms.ToTensor(),
              normalize
           ])
          # 測試
          test_transform = transforms.Compose([
              transforms.Resize(256),
              transforms.CenterCrop(224),
              transforms.ToTensor(),
              normalize,
           ])

          AutoAugment

          谷歌在2018年提出通過AutoML來自動搜索數(shù)據(jù)增強(qiáng)策略,稱之為AutoAugment(算是自動數(shù)據(jù)增強(qiáng)開山之作)。搜索方法采用強(qiáng)化學(xué)習(xí),和NAS類似,只不過搜索空間是數(shù)據(jù)增強(qiáng)策略,而不是網(wǎng)絡(luò)架構(gòu)。在搜索空間里,一個(gè)policy包含5個(gè)sub-policies,每個(gè)sub-policy包含兩個(gè)串行的圖像增強(qiáng)操作,每個(gè)增強(qiáng)操作有兩個(gè)超參數(shù):進(jìn)行該操作的概率圖像增強(qiáng)的幅度(magnitude,這個(gè)表示數(shù)據(jù)增強(qiáng)的強(qiáng)度,比如對于旋轉(zhuǎn),旋轉(zhuǎn)的角度就是增強(qiáng)幅度,旋轉(zhuǎn)角度越大,增強(qiáng)越大)。每個(gè)policy在執(zhí)行時(shí),首先隨機(jī)從5個(gè)策略中隨機(jī)選擇一個(gè)sub-policy,然后序列執(zhí)行兩個(gè)圖像操作。

          搜索空間一共有16種圖像增強(qiáng)類型,具體如下所示,大部分操作都定義了圖像增強(qiáng)的幅度范圍,在搜索時(shí)需要將幅度值離散化,具體地是將幅度值在定義范圍內(nèi)均勻地取10個(gè)值。

          論文在不同的數(shù)據(jù)集上( CIFAR-10 , SVHN, ImageNet)做了實(shí)驗(yàn),這里給出在ImageNet數(shù)據(jù)集上搜索得到的最優(yōu)policy(最后實(shí)際上是將搜索得到的前5個(gè)最好的policies合成了一個(gè)policy,所以這里包含25個(gè)sub-policies):

          # operation, probability, magnitude
          (("Posterize"0.48), ("Rotate"0.69)),
          (("Solarize"0.65), ("AutoContrast"0.6None)),                                                          
          (("Equalize"0.8None), ("Equalize"0.6None)),
          (("Posterize"0.67), ("Posterize"0.66)),
          (("Equalize"0.4None), ("Solarize"0.24)),
          (("Equalize"0.4None), ("Rotate"0.88)),
          (("Solarize"0.63), ("Equalize"0.6None)),
          (("Posterize"0.85), ("Equalize"1.0None)),
          (("Rotate"0.23), ("Solarize"0.68)),
          (("Equalize"0.6None), ("Posterize"0.46)),
          (("Rotate"0.88), ("Color"0.40)),
          (("Rotate"0.49), ("Equalize"0.6None)),
          (("Equalize"0.0None), ("Equalize"0.8None)),
          (("Invert"0.6None), ("Equalize"1.0None)),
          (("Color"0.64), ("Contrast"1.08)),
          (("Rotate"0.88), ("Color"1.02)),
          (("Color"0.88), ("Solarize"0.87)),
          (("Sharpness"0.47), ("Invert"0.6None)),
          (("ShearX"0.65), ("Equalize"1.0None)),
          (("Color"0.40), ("Equalize"0.6None)),
          (("Equalize"0.4None), ("Solarize"0.24)),
          (("Solarize"0.65), ("AutoContrast"0.6None)),
          (("Invert"0.6None), ("Equalize"1.0None)),
          (("Color"0.64), ("Contrast"1.08)),
          (("Equalize"0.8None), ("Equalize"0.6None))

          基于搜索得到的AutoAugment訓(xùn)練可以將ResNet50在ImageNet數(shù)據(jù)集上的top1 acc從76.3提升至77.6。一個(gè)比較重要的問題,這些從某一個(gè)數(shù)據(jù)集搜索得到的策略是否只對固定的數(shù)據(jù)集有效,論文也通過具體實(shí)驗(yàn)證明了AutoAugment的遷移能力,比如將ImageNet數(shù)據(jù)集上得到的策略用在5個(gè) FGVC數(shù)據(jù)集(與ImageNet圖像輸入大小相似)也均有提升。

          目前torchvision庫已經(jīng)實(shí)現(xiàn)了AutoAugment,具體使用如下所示(注意AutoAug前也需要包括一個(gè)RandomResizedCrop):

          from torchvision.transforms import autoaugment, transforms

          train_transform = transforms.Compose([
              transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
              transforms.RandomHorizontalFlip(hflip_prob),
              # 這里policy屬于torchvision.transforms.autoaugment.AutoAugmentPolicy,
              # 對于ImageNet就是 AutoAugmentPolicy.IMAGENET
              # 此時(shí)aa_policy = autoaugment.AutoAugmentPolicy('imagenet')
              autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation),
           transforms.PILToTensor(),
              transforms.ConvertImageDtype(torch.float),
              transforms.Normalize(mean=mean, std=std)
           ])

          RandAugment

          AutoAugment存在的一個(gè)問題是搜索空間巨大,這使得搜索只能在代理任務(wù)中進(jìn)行:使用小的模型在ImageNet的一個(gè)小的子集( 120類和6000圖片)搜索。谷歌在2019年又提出了一個(gè)更簡單的數(shù)據(jù)增強(qiáng)策略:RandAugment。這篇論文首先發(fā)現(xiàn)AutoAugment這樣在小數(shù)據(jù)集上搜索出來的策略在大的數(shù)據(jù)集上應(yīng)用會存在問題,這主要是因?yàn)閿?shù)據(jù)增強(qiáng)策略和模型大小和數(shù)據(jù)量大小存在強(qiáng)相關(guān),如下圖所示可以看到模型或者訓(xùn)練數(shù)據(jù)量越大,其最優(yōu)的數(shù)據(jù)增強(qiáng)的幅度越大,這說明AutoAugment得到的結(jié)果應(yīng)該是次優(yōu)的。另外,Population Based Augmentation這篇論文發(fā)現(xiàn)最優(yōu)的數(shù)據(jù)增強(qiáng)幅度是隨訓(xùn)練過程增加,而且不同的增強(qiáng)操作遵循類似的規(guī)律,這啟發(fā)作者采用固定的增強(qiáng)幅度而不是去搜索。RandAugment相比AutoAugment的策略空間很?。?span style="outline: 0px;max-width: 100%;cursor: pointer;box-sizing: border-box !important;overflow-wrap: break-word !important;">  vs  ),所以它不需要采用代理任務(wù),甚至直接采用簡單的網(wǎng)格搜索。

          具體地,RandAugment共包含兩個(gè)超參數(shù):圖像增強(qiáng)操作的數(shù)量N和一個(gè)全局的增強(qiáng)幅度M,其實(shí)現(xiàn)代碼如下所示,每次從候選操作集合(共14種策略)隨機(jī)選擇N個(gè)操作(等概率),然后串行執(zhí)行(這里沒有判斷概率,是一定執(zhí)行)。這里的M取值范圍為{0, . . . , 30}(每個(gè)圖像增強(qiáng)操作歸一化到同樣的幅度范圍),而N取值范圍一般為 {1, 2, 3}。

          # Identity是恒等變換,不做任何增強(qiáng)
          transforms = ['Identity''AutoContrast''Equalize''Rotate''Solarize''Color''Posterize'
                        'Contrast''Brightness''Sharpness''ShearX''ShearY''TranslateX''TranslateY']

          def randaugment(N, M):
           """Generate a set of distortions.
           Args:
           N: Number of augmentation transformations to
           apply sequentially.
           M: Magnitude for all the transformations.
           """

           sampled_ops = np.random.choice(transforms, N)
           return [(op, M) for op in sampled_ops]

          對于ResNet50,其搜索得到的N=2,M=9,RandAugment相比AutoAugment可以在ImageNet得到相似的效果(77.6),不過DeiT中發(fā)現(xiàn)使用RandAugment效果更好一些( DeiT-B:81.8 vs 81.2)。目前torchvision庫也已經(jīng)實(shí)現(xiàn)了RandAugment,具體使用如下所示:

          from torchvision.transforms import autoaugment, transforms

          train_transform = transforms.Compose([
              transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
              transforms.RandomHorizontalFlip(hflip_prob),
              autoaugment.RandAugment(interpolation=interpolation),
           transforms.PILToTensor(),
              transforms.ConvertImageDtype(torch.float),
              transforms.Normalize(mean=mean, std=std)
           ])

          TrivialAugment

          雖然RandAugment的搜索空間極小,但是對于不同的數(shù)據(jù)集還是需要確定最優(yōu)的N和M,這依然有較大的實(shí)驗(yàn)成本。RandAugment后,華為提出了UniformAugment,這種策略不需要搜索也能取得較好的結(jié)果。不過這里我們介紹一項(xiàng)更新的工作:TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation。TrivialAugment也不需要任何搜索,整個(gè)方法非常簡單:每次隨機(jī)選擇一個(gè)圖像增強(qiáng)操作,然后隨機(jī)確定它的增強(qiáng)幅度,并對圖像進(jìn)行增強(qiáng)。由于沒有任何超參數(shù),所以不需要任何搜索。從實(shí)驗(yàn)結(jié)果上看,TA可以在多個(gè)數(shù)據(jù)集上取得更好的結(jié)果,如在ImageNet數(shù)據(jù)集上,ResNet50的top1 acc可以達(dá)到78.1,超過RandAugment。

          TrivialAugment的圖像增強(qiáng)集合和RandAugment基本一樣,不過TA也定義了一套更寬的增強(qiáng)幅度,目前torchvision中已經(jīng)實(shí)現(xiàn)了TrivialAugmentWide,具體使用代碼如下所示:

          from torchvision.transforms import autoaugment, transforms

          augmentation_space = {
              # op_name: (magnitudes, signed)
              "Identity": (torch.tensor(0.0), False),
              "ShearX": (torch.linspace(0.00.99, num_bins), True),
              "ShearY": (torch.linspace(0.00.99, num_bins), True),
              "TranslateX": (torch.linspace(0.032.0, num_bins), True),
              "TranslateY": (torch.linspace(0.032.0, num_bins), True),
              "Rotate": (torch.linspace(0.0135.0, num_bins), True),
              "Brightness": (torch.linspace(0.00.99, num_bins), True),
              "Color": (torch.linspace(0.00.99, num_bins), True),
              "Contrast": (torch.linspace(0.00.99, num_bins), True),
              "Sharpness": (torch.linspace(0.00.99, num_bins), True),
              "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
              "Solarize": (torch.linspace(255.00.0, num_bins), False),
              "AutoContrast": (torch.tensor(0.0), False),
              "Equalize": (torch.tensor(0.0), False),
          }

          train_transform = transforms.Compose([
              transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
              transforms.RandomHorizontalFlip(hflip_prob),
              autoaugment.TrivialAugmentWide(interpolation=interpolation),
           transforms.PILToTensor(),
              transforms.ConvertImageDtype(torch.float),
              transforms.Normalize(mean=mean, std=std)
           ])

          RandomErasing

          RandomErasing是廈門大學(xué)在2017年提出的一種簡單的數(shù)據(jù)增強(qiáng)(這個(gè)策略和同期的CutOut基本一樣),基本原理是:隨機(jī)從圖像中擦除一個(gè)矩形區(qū)域而不改變圖像的原始標(biāo)簽。DeiT的訓(xùn)練策略中也包括了RandomErasing。

          目前torchvision也實(shí)現(xiàn)了RandomErasing,其具體使用代碼如下(注意這個(gè)op不支持PIL圖像,需要在轉(zhuǎn)換為torch.tensor后使用):

          train_transform = transforms.Compose([
              transforms.RandomResizedCrop(224, scale=(0.081.0), ratio=(3. / 4.4. / 3.)),
              transforms.RandomHorizontalFlip(),
              transforms.PILToTensor()
              transforms.ConvertImageDtype(torch.float),
              normalize,
              # scale是指相對于原圖的擦除面積范圍
              # ratio是指擦除區(qū)域的寬高比
              # value是指擦除區(qū)域的值,如果是int,也可以是tuple(RGB3個(gè)通道值),或者是str,需為'random',表示隨機(jī)生成
              transforms.RandomErasing(p=0.5, scale=(0.020.33), ratio=(0.33.3), value=0, inplace=False),
           ])

          MixUp

          MixUp在FAIR在2017年提出的一種數(shù)據(jù)增強(qiáng)方法:兩張不同的圖像隨機(jī)線性組合,而同時(shí)生成線性組合的標(biāo)簽。



          這里的 是兩張不同的圖像, 是它們對應(yīng)的one-hot標(biāo)簽,而 是線性組合系數(shù),每次執(zhí)行時(shí)隨機(jī)生成。假定圖像分類任務(wù)是2分類(區(qū)分狗和貓),兩張輸入圖像分別是狗和貓(如下圖所示),它們對應(yīng)的one-hot標(biāo)簽分別是[1,0]和[0, 1]。在進(jìn)行mixup之前,首先對它們進(jìn)行必要的數(shù)據(jù)增強(qiáng)得到aug_img1和aug_img2,然后隨機(jī)生成線性組合系數(shù),對于 得到的圖像是mix_img1,標(biāo)簽變?yōu)閇0.7, 0.3],而 得到的圖像是mix_img2,標(biāo)簽變?yōu)閇0.3, 0.7]。

          目前timm和torchvision中已經(jīng)實(shí)現(xiàn)了mixup,這里以torchvision為例來講述具體的代碼實(shí)現(xiàn)。由于mixup需要兩個(gè)輸入,而不單單是對當(dāng)前圖像進(jìn)行操作,所以一般是在得到batch數(shù)據(jù)后再進(jìn)行mixup,這也意味著圖像也已經(jīng)完成了其它的數(shù)據(jù)增強(qiáng)如RandAugment,對于batch中的每個(gè)樣本可以隨機(jī)選擇另外一個(gè)樣本進(jìn)行mixup。具體的實(shí)現(xiàn)代碼如下所示:

          # from https://github.com/pytorch/vision/blob/main/references/classification/transforms.py
          class RandomMixup(torch.nn.Module):
              """Randomly apply Mixup to the provided batch and targets.
              The class implements the data augmentations as described in the paper
              `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
              Args:
                  num_classes (int): number of classes used for one-hot encoding.
                  p (float): probability of the batch being transformed. Default value is 0.5.
                  alpha (float): hyperparameter of the Beta distribution used for mixup.
                      Default value is 1.0. # beta分布超參數(shù)
                  inplace (bool): boolean to make this transform inplace. Default set to False.
              """


              def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
                  super().__init__()
                  assert num_classes > 0"Please provide a valid positive value for the num_classes."
                  assert alpha > 0"Alpha param can't be zero."

                  self.num_classes = num_classes
                  self.p = p
                  self.alpha = alpha
                  self.inplace = inplace

              def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
                  """
                  Args:
                      batch (Tensor): Float tensor of size (B, C, H, W)
                      target (Tensor): Integer tensor of size (B, )
                  Returns:
                      Tensor: Randomly transformed batch.
                  """

                  if batch.ndim != 4:
                      raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
                  if target.ndim != 1:
                      raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
                  if not batch.is_floating_point():
                      raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
                  if target.dtype != torch.int64:
                      raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

                  if not self.inplace:
                      batch = batch.clone()
                      target = target.clone()
            
                  # 建立one-hot標(biāo)簽
                  if target.ndim == 1:
                      target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
            
                  # 判斷是否進(jìn)行mixup
                  if torch.rand(1).item() >= self.p:
                      return batch, target
            
                  # 這里將batch數(shù)據(jù)平移一個(gè)單位,產(chǎn)生mixup的圖像對,這意味著每個(gè)圖像與相鄰的下一個(gè)圖像進(jìn)行mixup
                  # timm實(shí)現(xiàn)是通過flip來做的,這意味著第一個(gè)圖像和最后一個(gè)圖像進(jìn)行mixup
                  # It's faster to roll the batch by one instead of shuffling it to create image pairs
                  batch_rolled = batch.roll(10)
                  target_rolled = target.roll(10)
            
                  # 隨機(jī)生成組合系數(shù)
                  # Implemented as on mixup paper, page 3.
                  lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
                  batch_rolled.mul_(1.0 - lambda_param)
                  batch.mul_(lambda_param).add_(batch_rolled) # 得到mixup后的圖像

                  target_rolled.mul_(1.0 - lambda_param)
                  target.mul_(lambda_param).add_(target_rolled) # 得到mixup后的標(biāo)簽

                  return batch, target

          然后可以將MixUp操作放在DataLoader的collate_fn中,這個(gè)函數(shù)要實(shí)現(xiàn)的是將多個(gè)樣本合并成一個(gè)mini-batch,所以可以將MixUp插在得到mini-batch后,具體實(shí)現(xiàn)如下所示:

          from torch.utils.data.dataloader import default_collate

          mixup_transform = RandomMixup(num_classes, p=1.0, alpha=mixup_alpha)
          collate_fn = lambda batch: mixup_transform(*default_collate(batch))
          data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
              sampler=train_sampler, collate_fn=collate_fn)

          對于MixUp,還要注意兩個(gè)兩點(diǎn)。第一個(gè)是如果同時(shí)采用了label smoothing,那么在創(chuàng)建one-hot標(biāo)簽時(shí)要直接得到smooth后的標(biāo)簽,具體實(shí)現(xiàn)如下(參考timm):

          def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
              x = x.long().view(-11)
              return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)

          off_value = smoothing / num_classes
          on_value = 1. - smoothing + off_value
          smooth_one_hot = one_hot(target, num_classes, on_value=on_value, off_value=off_value)

          第二個(gè)要注意的是MixUp后得到標(biāo)簽時(shí)soft label,不能直接采用torch.nn.CrossEntropyLoss來計(jì)算loss,而是直接計(jì)算交叉熵(參考timm):

          class SoftTargetCrossEntropy(nn.Module):

              def __init__(self):
                  super(SoftTargetCrossEntropy, self).__init__()

              def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                  loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
                  return loss.mean()

          注意在PyTorch1.10版本之后,torch.nn.CrossEntropyLoss已經(jīng)支持直接送入的target是probabilities for each class,原來只支持target是class indices;而且也支持label_smoothing參數(shù),所以上述兩個(gè)注意點(diǎn)就不再需要了。

          說到計(jì)算loss,timm作者近期在ResNet strikes back: An improved training procedure in timm指出采用MixUp后可以將多分類改成多標(biāo)簽分類(multi-label classification),即從N分類變成N個(gè)2分類(直接采用BinaryCrossEntropy),這應(yīng)該更符合MixUp后圖像的語義,從對比實(shí)驗(yàn)來看效果有微弱的提升。MixUp除了可以用于圖像分類任務(wù),還可以用于物體檢測任務(wù)中,比如YOLOX就采用了MixUp,這里面的做法是對圖像mixup后,其box為兩個(gè)圖像的box的合并集合,而沒有對標(biāo)簽軟化,這塊也可以見論文Bag of Freebies for Training Object Detection Neural Networks。

          CutMix

          CutMix是2019年提出的一項(xiàng)和MixUp和類似的數(shù)據(jù)增強(qiáng)策略,它也是同時(shí)對兩個(gè)圖像和標(biāo)簽進(jìn)行混合,與MixUp不同的是它的圖像混合方式。CutMix不是對兩個(gè)圖像線性組合,而是從另外一張圖像隨機(jī)剪切一個(gè)patch并粘貼到第一張圖像上,patch的起始坐標(biāo)隨機(jī)生成,而寬高是由 來控制:



          這里 是原始圖像的寬和高,所以 其實(shí)決定的是patch和原圖的面積比: 。下圖展示了 分別取0.7和0.3的混合效果, 越小,粘貼的patch越大。對于標(biāo)簽,其處理方式和MixUp一樣,通過 來得到兩張圖像的線性組合。

          CutMix做了ImageNet上的對比實(shí)驗(yàn),相比MixUp,ResNet50的top1 acc大約能提升一個(gè)點(diǎn)(77.4 vs 78.6):

          目前timm和torchvision中也已經(jīng)實(shí)現(xiàn)了CutMix,這里還是以torchvision為例來講述具體的代碼實(shí)現(xiàn),如下所示(和MixUp基本類似,只不過內(nèi)部處理存在差異):

          class RandomCutmix(torch.nn.Module):
              """Randomly apply Cutmix to the provided batch and targets.
              The class implements the data augmentations as described in the paper
              `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
              <https://arxiv.org/abs/1905.04899>`_.
              Args:
                  num_classes (int): number of classes used for one-hot encoding.
                  p (float): probability of the batch being transformed. Default value is 0.5.
                  alpha (float): hyperparameter of the Beta distribution used for cutmix.
                      Default value is 1.0.
                  inplace (bool): boolean to make this transform inplace. Default set to False.
              """


              def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
                  super().__init__()
                  assert num_classes > 0"Please provide a valid positive value for the num_classes."
                  assert alpha > 0"Alpha param can't be zero."

                  self.num_classes = num_classes
                  self.p = p
                  self.alpha = alpha
                  self.inplace = inplace

              def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
                  """
                  Args:
                      batch (Tensor): Float tensor of size (B, C, H, W)
                      target (Tensor): Integer tensor of size (B, )
                  Returns:
                      Tensor: Randomly transformed batch.
                  """

                  if batch.ndim != 4:
                      raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
                  if target.ndim != 1:
                      raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
                  if not batch.is_floating_point():
                      raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
                  if target.dtype != torch.int64:
                      raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

                  if not self.inplace:
                      batch = batch.clone()
                      target = target.clone()

                  if target.ndim == 1:
                      target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

                  if torch.rand(1).item() >= self.p:
                      return batch, target

                  # It's faster to roll the batch by one instead of shuffling it to create image pairs
                  batch_rolled = batch.roll(10)
                  target_rolled = target.roll(10)

                  # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
                  lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
                  W, H = F.get_image_size(batch)
            
                  # 確定patch的起點(diǎn)
                  r_x = torch.randint(W, (1,))
                  r_y = torch.randint(H, (1,))
            
                  # 確定patch的w和h(其實(shí)是一半大小)
                  r = 0.5 * math.sqrt(1.0 - lambda_param)
                  r_w_half = int(r * W)
                  r_h_half = int(r * H)
            
                  # 越界處理
                  x1 = int(torch.clamp(r_x - r_w_half, min=0))
                  y1 = int(torch.clamp(r_y - r_h_half, min=0))
                  x2 = int(torch.clamp(r_x + r_w_half, max=W))
                  y2 = int(torch.clamp(r_y + r_h_half, max=H))

                  batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
                  # 由于越界處理, λ可能發(fā)生改變,所以要重新計(jì)算
                  lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

                  target_rolled.mul_(1.0 - lambda_param)
                  target.mul_(lambda_param).add_(target_rolled)

                  return batch, target

          其它使用和MixUp一樣。

          Repeated Augmentation

          Repeated Augmentation (RA)是FAIR在MultiGrain提出的一種抽樣策略,一般情況下,訓(xùn)練的mini-batch包含的增強(qiáng)過的sample都是來自不同的圖像,但是RA這種抽樣策略允許一個(gè)mini-batch中包含來自同一個(gè)圖像的不同增強(qiáng)版本,此時(shí)mini-batch的各個(gè)樣本并非是完全獨(dú)立的,這相當(dāng)于對同一個(gè)樣本進(jìn)行重復(fù)抽樣,所以稱為Repeated Augmentation。這篇論文認(rèn)為在一個(gè)mini-batch學(xué)習(xí)來自同一個(gè)圖像的不同增強(qiáng)版本能讓模型更容易學(xué)習(xí)到增強(qiáng)不變的特征。關(guān)于RA,其實(shí)另外一篇較早的論文Augment your batch: better training with larger batches也提出了類似的策略,另外DeepMind在最近的論文Drawing Multiple Augmentation Samples Per Image During Training Efficiently Decreases Test Error也進(jìn)一步通過實(shí)驗(yàn)來證明這種策略的效果。

          DeiT的訓(xùn)練也采用了RA,嚴(yán)格來說RA不屬于數(shù)據(jù)增強(qiáng)策略,而是一種mini-batch抽樣方法,這里也簡單給出DeiT實(shí)現(xiàn)的RA(可以替換torch.utils.data.DistributedSampler):

          class RASampler(torch.utils.data.Sampler):
              """Sampler that restricts data loading to a subset of the dataset for distributed,
              with repeated augmentation.
              It ensures that different each augmented version of a sample will be visible to a
              different process (GPU)
              Heavily based on torch.utils.data.DistributedSampler
              """


              def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
                  if num_replicas is None:
                      if not dist.is_available():
                          raise RuntimeError("Requires distributed package to be available")
                      num_replicas = dist.get_world_size()
                  if rank is None:
                      if not dist.is_available():
                          raise RuntimeError("Requires distributed package to be available")
                      rank = dist.get_rank()
                  self.dataset = dataset
                  self.num_replicas = num_replicas
                  self.rank = rank
                  self.epoch = 0
                  # 重復(fù)采樣后每個(gè)replica的樣本量
                  self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
                  # 重復(fù)采樣后的總樣本量
                  self.total_size = self.num_samples * self.num_replicas
                  # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
                  # 每個(gè)replica實(shí)際樣本量,即不重復(fù)采樣時(shí)的每個(gè)replica的樣本量
                  self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
                  self.shuffle = shuffle

              def __iter__(self):
                  # deterministically shuffle based on epoch
                  g = torch.Generator()
                  g.manual_seed(self.epoch)
                  if self.shuffle:
                      indices = torch.randperm(len(self.dataset), generator=g).tolist()
                  else:
                      indices = list(range(len(self.dataset)))

                  # add extra samples to make it evenly divisible
                  indices = [ele for ele in indices for i in range(3)] # 重復(fù)3次
                  indices += indices[:(self.total_size - len(indices))]
                  assert len(indices) == self.total_size

                  # subsample: 使得同一個(gè)樣本的重復(fù)版本進(jìn)入不同的進(jìn)程(GPU)
                  indices = indices[self.rank:self.total_size:self.num_replicas]
                  assert len(indices) == self.num_samples

                  return iter(indices[:self.num_selected_samples]) # 截取實(shí)際樣本量

              def __len__(self):
                  return self.num_selected_samples

              def set_epoch(self, epoch):
                  self.epoch = epoch

          小結(jié)

          這里簡單介紹了幾種常用且有效的數(shù)據(jù)增強(qiáng)策略,這些策略在vision transformer模型被使用,而且timm訓(xùn)練的ResNet新baseline也使用了這些策略。

          參考

          1. Training data-efficient image transformers & distillation through attention  (https://arxiv.org/abs/2012.12877)
          2. AutoAugment: Learning Augmentation Policies from Data  (https://arxiv.org/abs/1805.09501)
          3. RandAugment: Practical automated data augmentation with a reduced search space  (https://arxiv.org/abs/1909.13719)
          4. TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation  (https://arxiv.org/abs/2103.10158)
          5. Random Erasing Data Augmentation(https://arxiv.org/abs/1708.04896)
          6. Augment your batch: better training with larger batches  (https://arxiv.org/abs/1901.09335)
          7. MultiGrain: a unified image embedding for classes and instances(https://arxiv.org/abs/1902.05509)
          8. mixup: Beyond Empirical Risk Minimization  (https://arxiv.org/abs/1710.09412)
          9. CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features  (https://arxiv.org/abs/1905.04899)
                
          下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程
          在「小白學(xué)視覺」公眾號后臺回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺、目標(biāo)跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。

          下載2:Python視覺實(shí)戰(zhàn)項(xiàng)目52講
          小白學(xué)視覺公眾號后臺回復(fù):Python視覺實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計(jì)數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個(gè)視覺實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺。

          下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講
          小白學(xué)視覺公眾號后臺回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。

          交流群


          歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器自動駕駛、計(jì)算攝影、檢測、分割、識別、醫(yī)學(xué)影像、GAN、算法競賽等微信群(以后會逐漸細(xì)分),請掃描下面微信號加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~


          瀏覽 224
          點(diǎn)贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評論
          圖片
          表情
          推薦
          點(diǎn)贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  亚洲国产色婷婷 | 超碰成人福利在线 | 精品国产成人 | 俺也去新网 | 久草综合在线视频 |