Implementing EMA in detectron2
Recently popular detection models such as YOLOv5 and YOLOX include many tricks, e.g. data augmentation (MixUp, Mosaic), and EMA is another commonly used one. EMA stands for Exponential Moving Average; it first appeared in TensorFlow (implemented as tf.train.ExponentialMovingAverage). In short, during training an exponential moving average of the model parameters is maintained, and the averaged parameters often perform slightly better than the parameters obtained at the very end of training. In a sense EMA resembles model ensembling, yet it adds no overhead at test time; during training it only costs one extra copy of the model parameters in memory plus a small amount of extra computation (the moving-average update, which is very cheap).
Implementing EMA is also simple: besides the model parameters params, we only need to maintain one extra copy ema_params, and after every training step each parameter is updated with a moving average:
ema_params = decay * ema_params + (1 - decay) * params
Here decay is a hyperparameter whose value is usually close to 1, e.g. 0.999. As you can see, EMA is quite generic and can be applied to the training of almost any model.
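As an illustration, here is a minimal framework-agnostic sketch of this update in PyTorch (the helper name ema_update and the ema_params dict are just for illustration, not part of the code discussed below):

import itertools
import torch

@torch.no_grad()
def ema_update(model, ema_params, decay=0.999):
    # ema_params: dict mapping parameter/buffer name -> averaged tensor
    for name, val in itertools.chain(model.named_parameters(), model.named_buffers()):
        ema_params[name].copy_(ema_params[name] * decay + val.detach() * (1.0 - decay))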
The mmdet framework open-sourced by SenseTime has already reproduced YOLOX, and its implementation includes EMA. Facebook AI's detectron2, however, does not yet contain an EMA implementation, but its mobile-oriented counterpart D2Go does; the two frameworks are largely interoperable with only minor differences. This post explains how to bring D2Go's EMA into detectron2, which mainly involves three parts: adding the EMA parameters to the model, updating them during training, and using them at test time.
EMA needs to maintain an extra copy of the model parameters, i.e. the EMA parameters. Here an EMAState class is defined to store them; its state dict holds the EMA parameters. The get_model_state_iterator method iterates over the model's state, including both the trainable parameters and the buffers; BN's running statistics (running_mean and running_var) belong to the buffers, and in general applying EMA to them as well gives slightly better results.
import copy
import itertools

import torch


class EMAState(object):
    def __init__(self):
        self.state = {}

    @classmethod
    def FromModel(cls, model: torch.nn.Module, device: str = ""):
        ret = cls()
        ret.save_from(model, device)
        return ret

    def save_from(self, model: torch.nn.Module, device: str = ""):
        """Save model state from `model` to this object"""
        for name, val in self.get_model_state_iterator(model):
            val = val.detach().clone()
            self.state[name] = val.to(device) if device else val

    def apply_to(self, model: torch.nn.Module):
        """Apply state to `model` from this object"""
        with torch.no_grad():
            for name, val in self.get_model_state_iterator(model):
                assert (
                    name in self.state
                ), f"Name {name} not existed, available names {self.state.keys()}"
                val.copy_(self.state[name])

    def get_ema_model(self, model):
        ret = copy.deepcopy(model)
        self.apply_to(ret)
        return ret

    @property
    def device(self):
        if not self.has_inited():
            return None
        return next(iter(self.state.values())).device

    def to(self, device):
        for name in self.state:
            self.state[name] = self.state[name].to(device)
        return self

    def has_inited(self):
        return self.state

    def clear(self):
        self.state.clear()
        return self

    def get_model_state_iterator(self, model):
        # iterate over both trainable parameters and buffers (e.g. BN running stats)
        param_iter = model.named_parameters()
        buffer_iter = model.named_buffers()
        return itertools.chain(param_iter, buffer_iter)

    def state_dict(self):
        return self.state

    def load_state_dict(self, state_dict, strict: bool = True):
        self.clear()
        for x, y in state_dict.items():
            self.state[x] = y
        return torch.nn.modules.module._IncompatibleKeys(
            missing_keys=[], unexpected_keys=[]
        )

    def __repr__(self):
        ret = f"EMAState(state=[{','.join(self.state.keys())}])"
        return ret
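As a quick standalone illustration of how EMAState is used (the toy model here is just a placeholder):

# a toy model just for illustration
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.BatchNorm1d(4))

ema_state = EMAState.FromModel(model)        # snapshot current parameters and buffers
# ... training would update `model` here ...
ema_model = ema_state.get_ema_model(model)   # deep copy of `model` carrying the stored EMA weights
ema_state.apply_to(model)                    # or overwrite `model` in place with the EMA weights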
Then in the d2 Trainer, the EMA state is created together with the model. After may_build_model_ema is called, the model gains an ema_state attribute, which is an instance of EMAState:
def may_build_model_ema(cfg, model):
    if not cfg.MODEL_EMA.ENABLED:
        return
    model = _remove_ddp(model)
    assert not hasattr(
        model, "ema_state"
    ), "Name `ema_state` is reserved for model ema."
    model.ema_state = EMAState()  # attach the EMA state as a model attribute
    logger.info("Using Model EMA.")


class Trainer(DefaultTrainer):
    # override build_model and add EMA inside it
    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:
        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        # add model EMA if enabled
        model_ema.may_build_model_ema(cfg, model)
        return model
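Note that the cfg.MODEL_EMA.* keys used above are not part of detectron2's default config, so they have to be registered first. A minimal sketch, with key names assumed to follow D2Go's convention:

from detectron2.config import CfgNode as CN

def add_model_ema_configs(cfg):
    cfg.MODEL_EMA = CN()
    cfg.MODEL_EMA.ENABLED = False
    cfg.MODEL_EMA.DECAY = 0.999
    # device on which to keep the EMA weights; "" means follow cfg.MODEL.DEVICE
    cfg.MODEL_EMA.DEVICE = ""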
The above adds the EMA state to the model, but the EMA parameters also need to be saved during training. This can be done through d2's DetectionCheckpointer: when constructed, it accepts extra checkpointable objects, which are saved and loaded alongside the model parameters in save and load. A checkpointable object must implement two methods, state_dict() and load_state_dict(), and the EMAState class defined above already provides both, so the EMA parameters can be saved and loaded through it. The implementation is as follows:
class Trainer(DefaultTrainer):
    def __init__(self, cfg):
        # ... build model, optimizer and data loader as in DefaultTrainer ...
        # add model EMA
        kwargs = {
            'trainer': weakref.proxy(self),
        }
        kwargs.update(model_ema.may_get_ema_checkpointer(cfg, model))  # register ema_state as a checkpointable
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            **kwargs,
        )
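Here may_get_ema_checkpointer comes from D2Go; roughly, it returns the ema_state as an extra checkpointable when EMA is enabled and an empty dict otherwise. A sketch of the idea:

def may_get_ema_checkpointer(cfg, model):
    if not cfg.MODEL_EMA.ENABLED:
        return {}
    model = _remove_ddp(model)
    return {"ema_state": model.ema_state}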
That completes the first part, adding the EMA parameters to the model. The second task is to update the EMA parameters during training. First define an EMAUpdater, whose update method performs one EMA update:
class EMAUpdater(object):
    """Model Exponential Moving Average
    Keep a moving average of everything in the model state_dict (parameters and
    buffers). This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    Note: It's very important to set EMA for ALL network parameters (instead of
    parameters that require gradient), including batch-norm moving average mean
    and variance. This leads to significant improvement in accuracy.
    For example, for EfficientNetB3, with default setting (no mixup, lr exponential
    decay) without bn_sync, the EMA accuracy with EMA on params that requires
    gradient is 79.87%, while the corresponding accuracy with EMA on all params
    is 80.61%.
    Also, bn sync should be switched on for EMA.
    """

    def __init__(self, state: EMAState, decay: float = 0.999, device: str = ""):
        self.decay = decay
        self.device = device
        self.state = state

    def init_state(self, model):
        self.state.clear()
        self.state.save_from(model, self.device)

    def update(self, model):
        with torch.no_grad():
            for name, val in self.state.get_model_state_iterator(model):
                ema_val = self.state.state[name]
                if self.device:
                    val = val.to(self.device)
                # exponential moving average update
                ema_val.copy_(ema_val * self.decay + val * (1.0 - self.decay))
To perform the update during training we can use a hook. Here an EMAHook is defined; the key part is the EMA update added in after_step:
class EMAHook(HookBase):
    def __init__(self, cfg, model):
        model = _remove_ddp(model)
        assert cfg.MODEL_EMA.ENABLED
        assert hasattr(
            model, "ema_state"
        ), "Call `may_build_model_ema` first to initialize the model ema"
        self.model = model
        self.ema = self.model.ema_state
        self.device = cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE
        self.ema_updater = EMAUpdater(
            self.model.ema_state, decay=cfg.MODEL_EMA.DECAY, device=self.device
        )

    def before_train(self):
        if self.ema.has_inited():
            self.ema.to(self.device)
        else:
            self.ema_updater.init_state(self.model)

    def after_train(self):
        pass

    def before_step(self):
        pass

    def after_step(self):
        # only update the EMA state while the model is in training mode
        if not self.model.training:
            return
        self.ema_updater.update(self.model)
Then add the EMAHook to the trainer's hooks:
    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
        ret = [
            hooks.IterationTimer(),
            model_ema.EMAHook(self.cfg, self.model) if cfg.MODEL_EMA.ENABLED else None,  # add EMA hook
            hooks.LRScheduler(),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]
        # ... the remaining default hooks (checkpointer, evaluation, writers) are kept as in DefaultTrainer ...
The last piece is how to use the EMA parameters at test time. The approach used here is: every time test is run, first save a copy of the current model parameters, replace them with the EMA parameters, and after the evaluation restore the saved copy. This can be implemented neatly with a Python context manager:
@contextmanager
def apply_model_ema_and_restore(model, state=None):
    """Apply the EMA state stored in `model` inside the context,
    and restore the original weights when the context exits.
    """
    model = _remove_ddp(model)
    if state is None:
        state = get_model_ema_state(model)
    old_state = EMAState.FromModel(model, state.device)  # save a copy of the current model weights
    state.apply_to(model)  # replace the model weights with the EMA weights
    yield old_state
    old_state.apply_to(model)  # restore the original model weights
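get_model_ema_state is a small helper from D2Go; its behavior is roughly just fetching the ema_state attribute attached earlier (sketch):

def get_model_ema_state(model):
    """Return the EMAState attached by may_build_model_ema."""
    model = _remove_ddp(model)
    assert hasattr(model, "ema_state"), "Call `may_build_model_ema` first"
    return model.ema_state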
Wrapping test with this context manager gives the desired behavior:
    @classmethod
    def do_test(cls, cfg, model, evaluators=None):
        # model with ema weights
        logger = logging.getLogger("detectron2")
        if cfg.MODEL_EMA.ENABLED:
            logger.info("Run evaluation with EMA.")
            with model_ema.apply_model_ema_and_restore(model):
                results = cls.test(cfg, model, evaluators=evaluators)
        else:
            results = cls.test(cfg, model, evaluators=evaluators)
        return results
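If you additionally want to export the EMA weights as a regular standalone checkpoint (e.g. for deployment), a minimal sketch using get_ema_model (the helper name export_ema_checkpoint is just for illustration):

def export_ema_checkpoint(model, path):
    # materialize a deep copy of the model carrying the EMA weights and save its state_dict
    ema_model = model.ema_state.get_ema_model(_remove_ddp(model))
    torch.save({"model": ema_model.state_dict()}, path)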
The complete code is on GitHub; feel free to try it out and star it (https://github.com/xiaohu2015/detectron2_ema). In a preliminary test with RetinaNet_R_50_FPN_1x, EMA is slightly better than the baseline (37.23 vs 37.18), while for YOLOv5 EMA is reported to bring a gain of 1~2 points. YOLOv5's EMA implementation has one extra trick: in the early phase of training a smaller decay is used and is gradually ramped up to the default value, because the model changes quickly early on and the EMA parameters should track it more aggressively. The concrete implementation is:
self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
This should be straightforward to add to the d2 EMA; I will update the repo when I have time (mmdet's EMA already implements this).
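For reference, one way this ramp could be bolted onto the EMAUpdater above (the class name RampedEMAUpdater and the ramp_iters parameter are just for illustration):

import math

class RampedEMAUpdater(EMAUpdater):
    """EMAUpdater with a YOLOv5-style decay ramp (illustrative sketch)."""

    def __init__(self, state, decay=0.999, device="", ramp_iters=2000):
        super().__init__(state, decay=decay, device=device)
        self.base_decay = decay
        self.ramp_iters = ramp_iters
        self.num_updates = 0

    def update(self, model):
        self.num_updates += 1
        # ramp the effective decay from ~0 up to base_decay over the first iterations
        self.decay = self.base_decay * (1 - math.exp(-self.num_updates / self.ramp_iters))
        super().update(model)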
References
fvcore, d2go, yolov5