<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          再也不用擔(dān)心過擬合的問題了

          共 7184字,需瀏覽 15分鐘

           ·

          2021-03-31 20:32

          點擊上方“程序員大白”,選擇“星標(biāo)”公眾號

          重磅干貨,第一時間送達(dá)

          作者 | Sean Benhur J

          編譯 | ronghuaiyang

          轉(zhuǎn)自 | AI公園


          使用SAM(銳度感知最小化),優(yōu)化到損失的最平坦的最小值的地方,增強(qiáng)泛化能力。


          論文:https://arxiv.org/pdf/2010.01412.pdf

          代碼:https://github.com/moskomule/sam.pytorch

          動機(jī)來自先前的工作,在此基礎(chǔ)上,我們提出了一種新的、有效的方法來同時減小損失值和損失的銳度。具體來說,在我們的處理過程中,進(jìn)行銳度感知最小化(SAM),在領(lǐng)域內(nèi)尋找具有均勻的低損失值的參數(shù)。這個公式產(chǎn)生了一個最小-最大優(yōu)化問題,在這個問題上梯度下降可以有效地執(zhí)行。我們提出的實證結(jié)果表明,SAM在各種基準(zhǔn)數(shù)據(jù)集上都改善了的模型泛化。

          在深度學(xué)習(xí)中,我們使用SGD/Adam等優(yōu)化算法在我們的模型中實現(xiàn)收斂,從而找到全局最小值,即訓(xùn)練數(shù)據(jù)集中損失較低的點。但等幾種研究表明,許多網(wǎng)絡(luò)可以很容易地記住訓(xùn)練數(shù)據(jù)并有能力隨時overfit,為了防止這個問題,增強(qiáng)泛化能力,谷歌研究人員發(fā)表了一篇新論文叫做Sharpness Awareness Minimization,在CIFAR10上以及其他的數(shù)據(jù)集上達(dá)到了最先進(jìn)的結(jié)果。

          在本文中,我們將看看為什么SAM可以實現(xiàn)更好的泛化,以及我們?nèi)绾卧赑ytorch中實現(xiàn)SAM。

          SAM的原理是什么?

          在梯度下降或任何其他優(yōu)化算法中,我們的目標(biāo)是找到一個具有低損失值的參數(shù)。但是,與其他常規(guī)的優(yōu)化方法相比,SAM實現(xiàn)了更好的泛化,它將重點放在領(lǐng)域內(nèi)尋找具有均勻的低損失值的參數(shù)(而不是只有參數(shù)本身具有低損失值)上。

          由于計算鄰域參數(shù)而不是計算單個參數(shù),損失超平面比其他優(yōu)化方法更平坦,這反過來增強(qiáng)了模型的泛化。

          (左))用SGD訓(xùn)練的ResNet收斂到的一個尖銳的最小值。(右)用SAM訓(xùn)練的相同的ResNet收斂到的一個平坦的最小值。

          注意:SAM不是一個新的優(yōu)化器,它與其他常見的優(yōu)化器一起使用,比如SGD/Adam。

          在Pytorch中實現(xiàn)SAM

          在Pytorch中實現(xiàn)SAM非常簡單和直接

          import torch

          class SAM(torch.optim.Optimizer):
              def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
                  assert rho >= 0.0f"Invalid rho, should be non-negative: {rho}"

                  defaults = dict(rho=rho, **kwargs)
                  super(SAM, self).__init__(params, defaults)

                  self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
                  self.param_groups = self.base_optimizer.param_groups

              @torch.no_grad()
              def first_step(self, zero_grad=False):
                  grad_norm = self._grad_norm()
                  for group in self.param_groups:
                      scale = group["rho"] / (grad_norm + 1e-12)

                      for p in group["params"]:
                          if p.grad is Nonecontinue
                          e_w = p.grad * scale.to(p)
                          p.add_(e_w)  # climb to the local maximum "w + e(w)"
                          self.state[p]["e_w"] = e_w

                  if zero_grad: self.zero_grad()

              @torch.no_grad()
              def second_step(self, zero_grad=False):
                  for group in self.param_groups:
                      for p in group["params"]:
                          if p.grad is Nonecontinue
                          p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

                  self.base_optimizer.step()  # do the actual "sharpness-aware" update

                  if zero_grad: self.zero_grad()


              def _grad_norm(self):
                  shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
                  norm = torch.norm(
                              torch.stack([
                                  p.grad.norm(p=2).to(shared_device)
                                  for group in self.param_groups for p in group["params"]
                                  if p.grad is not None
                              ]),
                              p=2
                         )
                  return norm

          代碼取自非官方的Pytorch實現(xiàn)。

          代碼解釋:

          • 首先,我們從Pytorch繼承優(yōu)化器類來創(chuàng)建一個優(yōu)化器,盡管SAM不是一個新的優(yōu)化器,而是在需要繼承該類的每一步更新梯度(在基礎(chǔ)優(yōu)化器的幫助下)。
          • 該類接受模型參數(shù)、基本優(yōu)化器和rho, rho是計算最大損失的鄰域大小。
          • 在進(jìn)行下一步之前,讓我們先看看文中提到的偽代碼,它將幫助我們在沒有數(shù)學(xué)的情況下理解上述代碼。

          • 正如我們在計算第一次反向傳遞后的偽代碼中看到的,我們計算epsilon并將其添加到參數(shù)中,這些步驟是在上述python代碼的方法first_step中實現(xiàn)的。
          • 現(xiàn)在在計算了第一步之后,我們必須回到之前的權(quán)重來計算基礎(chǔ)優(yōu)化器的實際步驟,這些步驟在函數(shù)second_step中實現(xiàn)。
          • 函數(shù)_grad_norm用于返回矩陣向量的norm,即偽代碼的第10行
          • 在構(gòu)建這個類后,你可以簡單地使用它為你的深度學(xué)習(xí)項目通過以下的訓(xùn)練函數(shù)片段。
          from sam import SAM
          ...

          model = YourModel()
          base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
          optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
          ...

          for input, output in data:

            # first forward-backward pass
            loss = loss_function(output, model(input))  # use this loss for any training statistics
            loss.backward()
            optimizer.first_step(zero_grad=True)
            
            # second forward-backward pass
            loss_function(output, model(input)).backward()  # make sure to do a full forward pass
            optimizer.second_step(zero_grad=True)
          ...

          總結(jié)

          雖然SAM的泛化效果較好,但是這種方法的主要缺點是,由于前后兩次計算銳度感知梯度,需要花費(fèi)兩倍的訓(xùn)練時間。除此之外,SAM還在最近發(fā)布的NFNETS上證明了它的效果,這是ImageNet目前的最高水平,在未來,我們可以期待越來越多的論文利用這一技術(shù)來實現(xiàn)更好的泛化。


          英文原文:https://pub.towardsai.net/we-dont-need-to-worry-about-overfitting-anymore-9fb31a154c81


          國產(chǎn)小眾瀏覽器因屏蔽視頻廣告,被索賠100萬(后續(xù))

          年輕人“不講武德”:因看黃片上癮,把網(wǎng)站和786名女主播起訴了

          中國聯(lián)通官網(wǎng)被發(fā)現(xiàn)含木馬腳本,可向用戶推廣色情APP

          張一鳴:每個逆襲的年輕人,都具備的底層能力


          關(guān)


          學(xué)西學(xué)學(xué)運(yùn)護(hù)質(zhì)結(jié)關(guān)[]學(xué)習(xí)進(jìn)


          瀏覽 160
          點贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  97超碰在线播放 | 国产精品久久无码 | 秋霞网一区二区 | 东方欧美色图东方亚洲色图 | 国产一级婬乱片免费 |