<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          EMA在detectron2中的實(shí)現(xiàn)

          共 20299字,需瀏覽 41分鐘

           ·

          2021-09-18 21:04

          點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師

          設(shè)為星標(biāo),干貨直達(dá)!


          近期很流行的一些檢測模型如YOLOv5和YOLOX都包含了很多的tricks,如數(shù)據(jù)增強(qiáng)(MixUp, Mosaic)等,其中EMA也是一種常采用的trick。EMA全稱為Exponential Moving Average,最早是在TensorFlow中出現(xiàn)(具體實(shí)現(xiàn)為tf.train.ExponentialMovingAverage),簡單來說,在模型訓(xùn)練過程中對模型參數(shù)計(jì)算指數(shù)移動(dòng)平均,得到的模型參數(shù)要比最后訓(xùn)練得到的模型參數(shù)在效果上可能要好一點(diǎn)。從某種意義上來看,EMA有點(diǎn)像模型集成,但是它在測試時(shí)不需要額外的負(fù)擔(dān),在訓(xùn)練過程只是多消耗一份顯存(多一份模型參數(shù))以及訓(xùn)練過程稍多一點(diǎn)開銷(對參數(shù)進(jìn)行移動(dòng)平均,耗時(shí)很小)。

          EMA的實(shí)現(xiàn)也很簡單,對模型參數(shù)params只需要多維護(hù)一份參數(shù)ema_params就好,然后在每個(gè)訓(xùn)練step后,對每一個(gè)模型參數(shù)進(jìn)行移動(dòng)平均:

          這里的decay是一個(gè)超參數(shù),一般取值接近1,比如設(shè)置為0.999。可以看到EMA比較通用,幾乎適用于任何模型訓(xùn)練中。

          目前商湯開源的mmdet框架已經(jīng)復(fù)現(xiàn)了YOLOX,里面也包含了EMA的實(shí)現(xiàn)。而目前Facebook AI的detectron2還沒有包含EMA的實(shí)現(xiàn),但是其移動(dòng)端版本D2Go已經(jīng)實(shí)現(xiàn)了EMA,兩個(gè)版本其實(shí)是互通的,只有略微的差別。這里就講一下如何將D2Go的EMA應(yīng)用到detectron2中,這主要包括三個(gè)部分:模型中添加EMA參數(shù)、訓(xùn)練過程中進(jìn)行更新以及測試時(shí)使用EMA參數(shù)。

          EMA需要多維護(hù)一份模型參數(shù),就是EMA參數(shù),這里定義一個(gè)EMAState類來存儲EMA參數(shù),這個(gè)類里面的state字典存儲EMA參數(shù)。這里的get_model_state_iterator方法是獲得模型的參數(shù),包括訓(xùn)練參數(shù)params以及buffers,BN的一些參數(shù)moving_mean和moving_var屬于buffers,一般情況下對BN的moving_mean和moving_var也進(jìn)行EMA效果會(huì)更好一點(diǎn)。

          class EMAState(object):
              def __init__(self):
                  self.state = {}

              @classmethod
              def FromModel(cls, model: torch.nn.Module, device: str = ""):
                  ret = cls()
                  ret.save_from(model, device)
                  return ret

              def save_from(self, model: torch.nn.Module, device: str = ""):
                  """Save model state from `model` to this object"""
                  for name, val in self.get_model_state_iterator(model):
                      val = val.detach().clone()
                      self.state[name] = val.to(device) if device else val

              def apply_to(self, model: torch.nn.Module):
                  """Apply state to `model` from this object"""
                  with torch.no_grad():
                      for name, val in self.get_model_state_iterator(model):
                          assert (
                              name in self.state
                          ), f"Name {name} not existed, available names {self.state.keys()}"
                          val.copy_(self.state[name])

              def get_ema_model(self, model):
                  ret = copy.deepcopy(model)
                  self.apply_to(ret)
                  return ret

              @property
              def device(self):
                  if not self.has_inited():
                      return None
                  return next(iter(self.state.values())).device

              def to(self, device):
                  for name in self.state:
                      self.state[name] = self.state[name].to(device)
                  return self

              def has_inited(self):
                  return self.state

              def clear(self):
                  self.state.clear()
                  return self

              def get_model_state_iterator(self, model):
                  param_iter = model.named_parameters()
                  buffer_iter = model.named_buffers()
                  return itertools.chain(param_iter, buffer_iter)

              def state_dict(self):
                  return self.state

              def load_state_dict(self, state_dict, strict: bool = True):
                  self.clear()
                  for x, y in state_dict.items():
                      self.state[x] = y
                  return torch.nn.modules.module._IncompatibleKeys(
                      missing_keys=[], unexpected_keys=[]
                  )

              def __repr__(self):
                  ret = f"EMAState(state=[{','.join(self.state.keys())}])"
                  return ret

          這樣在d2的Trainer中,創(chuàng)建model的同時(shí)也定義EMA,添加后model會(huì)多一個(gè)model_ema屬性,它是EMAState的一個(gè)實(shí)例:

          def may_build_model_ema(cfg, model):
              if not cfg.MODEL_EMA.ENABLED:
                  return
              model = _remove_ddp(model)
              assert not hasattr(
                  model, "ema_state"
              ), "Name `ema_state` is reserved for model ema."
              model.ema_state = EMAState() # 添加到model的屬性中
              logger.info("Using Model EMA.")
              
              
          class Trainer(DefaultTrainer):
              
              # override build_model,在里面添加ema
              @classmethod
              def build_model(cls, cfg):
                  """
                  Returns:
                      torch.nn.Module:
                  It now calls :func:`detectron2.modeling.build_model`.
                  Overwrite it if you'd like a different model.
                  """

                  model = build_model(cfg)
                  logger = logging.getLogger(__name__)
                  logger.info("Model:\n{}".format(model))
                  # add model EMA if enabled
                  model_ema.may_build_model_ema(cfg, model)
                  return model

          上面實(shí)現(xiàn)了ema的添加,但是在訓(xùn)練后還需要保存ema參數(shù),這可以通過d2的DetectionCheckpointer來實(shí)現(xiàn),DetectionCheckpointer在創(chuàng)建時(shí)可以傳入額外的checkpointable objects,在save和load時(shí)除了模型參數(shù)也會(huì)同步對這些objects進(jìn)行保存和加載。checkpointable objects需要實(shí)現(xiàn)兩個(gè)方法:state_dict()和load_state_dict(),而前面定義的EMAState類也包含了這兩個(gè)方法,用于save和load對應(yīng)的ema參數(shù)。具體的實(shí)現(xiàn)代碼如下:

          class Trainer(DefaultTrainer):
              
             def __init__(self, cfg):
               
               # add model EMA
                  kwargs = {
                      'trainer': weakref.proxy(self),
                  }
                  kwargs.update(model_ema.may_get_ema_checkpointer(cfg, model)) # 添加ema到checkpointables
                  self.checkpointer = DetectionCheckpointer(
                      # Assume you want to save checkpoints together with logs/statistics
                      model,
                      cfg.OUTPUT_DIR,
                      **kwargs,
                  )

          上面完成了第一個(gè)部分,就是在模型中添加ema參數(shù),第二個(gè)要做的工作就是實(shí)現(xiàn)ema參數(shù)在訓(xùn)練過程的更新,首先定義一個(gè)EMAUpdater,其中update方法用來進(jìn)行一次ema更新:

          class EMAUpdater(object):
              """Model Exponential Moving Average
              Keep a moving average of everything in the model state_dict (parameters and
              buffers). This is intended to allow functionality like
              https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
              Note:  It's very important to set EMA for ALL network parameters (instead of
              parameters that require gradient), including batch-norm moving average mean
              and variance.  This leads to significant improvement in accuracy.
              For example, for EfficientNetB3, with default setting (no mixup, lr exponential
              decay) without bn_sync, the EMA accuracy with EMA on params that requires
              gradient is 79.87%, while the corresponding accuracy with EMA on all params
              is 80.61%.
              Also, bn sync should be switched on for EMA.
              """


              def __init__(self, state: EMAState, decay: float = 0.999, device: str = ""):
                  self.decay = decay
                  self.device = device

                  self.state = state

              def init_state(self, model):
                  self.state.clear()
                  self.state.save_from(model, self.device)

              def update(self, model):
                  with torch.no_grad():
                      for name, val in self.state.get_model_state_iterator(model):
                          ema_val = self.state.state[name]
                          if self.device:
                              val = val.to(self.device)
                          # 指數(shù)移動(dòng)平均
                          ema_val.copy_(ema_val * self.decay + val * (1.0 - self.decay))

          要實(shí)現(xiàn)訓(xùn)練過程中的更新,可以采用hook的方式,這里定義一個(gè)EMAHook,這里主要是在after_step方法中加入ema的update:

          class EMAHook(HookBase):
              def __init__(self, cfg, model):
                  model = _remove_ddp(model)
                  assert cfg.MODEL_EMA.ENABLED
                  assert hasattr(
                      model, "ema_state"
                  ), "Call `may_build_model_ema` first to initilaize the model ema"
                  self.model = model
                  self.ema = self.model.ema_state
                  self.device = cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE
                  self.ema_updater = EMAUpdater(
                      self.model.ema_state, decay=cfg.MODEL_EMA.DECAY, device=self.device
                  )

              def before_train(self):
                  if self.ema.has_inited():
                      self.ema.to(self.device)
                  else:
                      self.ema_updater.init_state(self.model)

              def after_train(self):
                  pass

              def before_step(self):
                  pass

              def after_step(self):
                  if not self.model.train:
                      return
                  self.ema_updater.update(self.model)

          然后把EMAHook加到trainer中的hooks里:

             def build_hooks(self):
                  """
                  Build a list of default hooks, including timing, evaluation,
                  checkpointing, lr scheduling, precise BN, writing events.
                  Returns:
                      list[HookBase]:
                  """

                  cfg = self.cfg.clone()
                  cfg.defrost()
                  cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

                  ret = [
                      hooks.IterationTimer(),
                      model_ema.EMAHook(self.cfg, self.model) if cfg.MODEL_EMA.ENABLED else None# add EMA hook
                      hooks.LRScheduler(),
                      hooks.PreciseBN(
                          # Run at the same freq as (but before) evaluation.
                          cfg.TEST.EVAL_PERIOD,
                          self.model,
                          # Build a new data loader to not affect training
                          self.build_train_loader(cfg),
                          cfg.TEST.PRECISE_BN.NUM_ITER,
                      )
                      if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
                      else None,
                  ]

          最后一個(gè)要實(shí)現(xiàn)的就是如何在測試時(shí)采用ema參數(shù),這里采用的方法是每次進(jìn)行test時(shí),先將model參數(shù)保存一個(gè)副本,然后用ema參數(shù)替換,完成測試后再用保存的副本復(fù)原回來,在實(shí)現(xiàn)上,可以采用python的上下文管理器來巧妙地實(shí)現(xiàn):

          @contextmanager
          def apply_model_ema_and_restore(model, state=None):
              """Apply ema stored in `model` to model and returns a function to restore
              the weights are applied
              """

              model = _remove_ddp(model)

              if state is None:
                  state = get_model_ema_state(model)

              old_state = EMAState.FromModel(model, state.device) # 創(chuàng)建當(dāng)前模型參數(shù)副本
              state.apply_to(model) # 用ema替換模型參數(shù)
              yield old_state
              old_state.apply_to(model) # 恢復(fù)模型參數(shù)

          用這個(gè)上下文管理器對test進(jìn)行包裝,就可以實(shí)現(xiàn)想要的效果了:

              @classmethod
              def do_test(cls, cfg, model, evaluators=None):
                  # model with ema weights
                  logger = logging.getLogger("detectron2")
                  if cfg.MODEL_EMA.ENABLED:
                      logger.info("Run evaluation with EMA.")
                      with model_ema.apply_model_ema_and_restore(model):
                          results = cls.test(cfg, model, evaluators=evaluators)
                  else:
                      results = cls.test(cfg, model, evaluators=evaluators)
                  return results

          完整的代碼放在了github上,歡迎試用和star(https://github.com/xiaohu2015/detectron2_ema)。我初步用RetinaNet_R_50_FPN_1x測試的話,采用ema比原始效果要好一點(diǎn)(37.23 vs 37.18),而YOLOv5采用ema能提升1~2個(gè)點(diǎn)的。在YOLOv5中,ema的實(shí)現(xiàn)有一個(gè)額外的trick,那就是在訓(xùn)練前期,采用較小的decay,然后逐步增到默認(rèn)值,因?yàn)榍捌谀P陀?xùn)練速度快,應(yīng)該對ema參數(shù)更新更激進(jìn)一些,具體的實(shí)現(xiàn)如下:

          self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)

          這個(gè)實(shí)現(xiàn)應(yīng)該很容易在d2的EMA中添加,有時(shí)間再更新(mmdet的ema已經(jīng)實(shí)現(xiàn)這個(gè)功能了)。

          參考

          1. fvcore
          2. d2go
          3. yolov5



          推薦閱讀

          CPVT:一個(gè)卷積就可以隱式編碼位置信息

          SOTA模型Swin Transformer是如何煉成的!

          谷歌AI用30億數(shù)據(jù)訓(xùn)練了一個(gè)20億參數(shù)Vision Transformer模型,在ImageNet上達(dá)到新的SOTA!

          BatchNorm的避坑指南(上)

          BatchNorm的避坑指南(下)

          目標(biāo)跟蹤入門篇-相關(guān)濾波

          SOTA模型Swin Transformer是如何煉成的!

          MoCo V3:我并不是你想的那樣!

          Transformer在語義分割上的應(yīng)用

          "未來"的經(jīng)典之作ViT:transformer is all you need!

          PVT:可用于密集任務(wù)backbone的金字塔視覺transformer!

          漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA

          Transformer為何能闖入CV界秒殺CNN?

          不妨試試MoCo,來替換ImageNet上pretrain模型!


          機(jī)器學(xué)習(xí)算法工程師


                                              一個(gè)用心的公眾號

          瀏覽 74
          點(diǎn)贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評論
          圖片
          表情
          推薦
          點(diǎn)贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  成人做爰A片免费看网站 | 51成人做爰www免费看网站 | 91亚洲精品久久久久蜜桃 | 性无码一区二区三区 | 午夜性爱在线 |