Object Detection in Depth (MMdetection): the HOOK Mechanism

Jishi Editor's Note
In this article the author explains the hook (HOOK) mechanism in MMdetection, covering what it is for and a typical example.
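I have been working on object detection for a while. Detection is a relatively complex area, and while getting familiar with the codebase I learned a lot of useful things. MMdetection is one of the best-known detection training frameworks, with a very large collection of algorithms and a large user base, and its source code is well worth studying. Today I summarize my understanding of its HOOK mechanism.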
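MMdetection has changed a lot recently. My walkthrough is based on version 2.4.0; I share my own understanding and welcome comments from readers. HOOK and Runner are defined in MMCV, whose versions are matched to MMdetection releases; I use the code of MMCV 1.1.2. (The HOOK-related definitions mainly live in MMCV, and the framework code quoted below is taken from MMCV.)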
1. What the HOOK mechanism does
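A HOOK in MMdetection can be thought of as a trigger, or as an architectural convention of the training framework: it fixes the points in the training process at which operations run, and we can customize those operations by subclassing the Hook class and registering the resulting hook.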
First, the definition of the Hook base class (mmcv/runner/hooks/hook.py):
```python
# Copyright (c) Open-MMLab. All rights reserved.
from mmcv.utils import Registry

HOOKS = Registry('hook')


class Hook:

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)
```
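The base class declares the points in model training where we may want to do something; to add an operation, we subclass it and override the methods we need. Note that every method receives runner as its argument. I'll cover Runner in the next article; in short, Runner is the factory of model training, driving the whole pipeline: loading data, training, validation, the backward pass and so on. MMdetection gives the runner a rich set of attributes and defines a very clean training paradigm. Inside each of your hook methods, you can operate on runner however you need.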
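To make the "subclass and override" idea concrete: the default before_train_epoch and before_val_epoch simply forward to before_epoch, so a subclass that overrides only before_epoch runs in both phases. The sketch below re-creates just those base-class methods in plain Python (no mmcv import) and uses a SimpleNamespace as a stand-in runner; EpochRecorder and the fake runner attributes are hypothetical, for illustration only.

```python
from types import SimpleNamespace


class Hook:
    """Trimmed copy of the base-class logic above, so this runs without mmcv."""

    def before_epoch(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False


class EpochRecorder(Hook):
    """Hypothetical hook: overriding before_epoch alone covers train and val."""

    def __init__(self):
        self.calls = []

    def before_epoch(self, runner):
        self.calls.append((runner.mode, runner.epoch))


hook = EpochRecorder()
# SimpleNamespace stands in for a Runner; only the attributes the hook reads exist.
hook.before_train_epoch(SimpleNamespace(mode='train', epoch=0))
hook.before_val_epoch(SimpleNamespace(mode='val', epoch=0))
print(hook.calls)  # [('train', 0), ('val', 0)]

# runner.iter counts from 0, so n=50 fires on iterations 49, 99, 149, 199.
fired = [i for i in range(200) if hook.every_n_iters(SimpleNamespace(iter=i), 50)]
print(fired)  # [49, 99, 149, 199]
```

The "+ 1" in the every_n_* helpers is what turns a 0-based counter into "the 50th, 100th, ... iteration".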
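So how are hooks plugged into the runner? Runner holds a list of hooks, and each element of the list is an instantiated Hook object. There are two ways to register a hook: register_hook takes an instantiated Hook object and inserts it into the list, while register_hook_from_cfg takes a config dict, instantiates the Hook object from it and inserts that into the list. The second method relies on mmcv.build_from_cfg, a basic building block of the MMLab open-source ecosystem: MMdetection and the other MMLab algorithm frameworks all follow this MMCV convention of instantiating objects from configs. MMCV provides the foundational functionality that serves every algorithm framework, which is a big part of why MMLab's code is of such high quality. It is not just a reproduction of algorithms but an embodiment of architecture and programming paradigms; truly, code as poetry.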
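The config-driven path boils down to "look up the class named by type in a registry, then call it with the remaining keys". The toy version below illustrates that pattern only; ToyRegistry, toy_build_from_cfg and PrintLossHook are hypothetical names, and the real mmcv Registry and build_from_cfg handle more cases than this sketch.

```python
class ToyRegistry:
    """Toy stand-in for mmcv.utils.Registry: a name -> class lookup table."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # Returns a decorator, mirroring the @HOOKS.register_module() usage style.
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def get(self, key):
        return self._module_dict[key]


def toy_build_from_cfg(cfg, registry):
    """Look up cfg['type'] in the registry and pass the other keys as kwargs."""
    args = dict(cfg)
    obj_cls = registry.get(args.pop('type'))
    return obj_cls(**args)


HOOKS = ToyRegistry('hook')


@HOOKS.register_module()
class PrintLossHook:  # hypothetical hook used only for this demo
    def __init__(self, interval=10):
        self.interval = interval


hook_cfg = dict(type='PrintLossHook', interval=50, priority='LOW')
cfg = hook_cfg.copy()
priority = cfg.pop('priority', 'NORMAL')  # same first step as register_hook_from_cfg
hook = toy_build_from_cfg(cfg, HOOKS)
print(type(hook).__name__, hook.interval, priority)  # PrintLossHook 50 LOW
```

Popping priority before building is why, as the docstring of register_hook_from_cfg notes, a hook class should not take 'type' or 'priority' as constructor arguments.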
```python
# BaseRunner methods (mmcv/runner/base_runner.py)
def register_hook(self, hook, priority='NORMAL'):
    """Register a hook into the hook list.

    The hook will be inserted into a priority queue, with the specified
    priority (See :class:`Priority` for details of priorities).
    For hooks with the same priority, they will be triggered in the same
    order as they are registered.

    Args:
        hook (:obj:`Hook`): The hook to be registered.
        priority (int or str or :obj:`Priority`): Hook priority.
            Lower value means higher priority.
    """
    assert isinstance(hook, Hook)
    if hasattr(hook, 'priority'):
        raise ValueError('"priority" is a reserved attribute for hooks')
    priority = get_priority(priority)
    hook.priority = priority
    # insert the hook to a sorted list
    inserted = False
    # Hooks are inserted into the list by priority: different hooks in
    # MMdetection have different priorities. Why? That is explained below,
    # where we look at how hooks are called.
    for i in range(len(self._hooks) - 1, -1, -1):
        if priority >= self._hooks[i].priority:
            self._hooks.insert(i + 1, hook)
            inserted = True
            break
    if not inserted:
        self._hooks.insert(0, hook)


def register_hook_from_cfg(self, hook_cfg):
    """Register a hook from its cfg.

    Args:
        hook_cfg (dict): Hook config. It should have at least keys 'type'
          and 'priority' indicating its type and priority.

    Notes:
        The specific hook class to register should not use 'type' and
        'priority' arguments during initialization.
    """
    hook_cfg = hook_cfg.copy()
    priority = hook_cfg.pop('priority', 'NORMAL')
    hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
    self.register_hook(hook, priority=priority)
```
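The backward scan in register_hook is easy to misread, so here it is replayed on plain (name, value) tuples. The numeric values are those of the named levels in the priority table later in this article; insert_by_priority and the hook names are illustrative only.

```python
# Numeric values of the named priorities (lower value = higher priority).
PRIORITY = {'HIGHEST': 0, 'VERY_HIGH': 10, 'HIGH': 30, 'NORMAL': 50,
            'LOW': 70, 'VERY_LOW': 90, 'LOWEST': 100}


def insert_by_priority(hooks, name, priority):
    """Same backward scan as register_hook: insert right after the last entry
    whose value is <= the new value, so equal priorities keep FIFO order."""
    value = PRIORITY[priority]
    for i in range(len(hooks) - 1, -1, -1):
        if value >= hooks[i][1]:
            hooks.insert(i + 1, (name, value))
            return
    hooks.insert(0, (name, value))  # smaller than everything: goes to the front


hooks = []
insert_by_priority(hooks, 'checkpoint', 'NORMAL')
insert_by_priority(hooks, 'logger', 'VERY_LOW')
insert_by_priority(hooks, 'ema', 'HIGHEST')
insert_by_priority(hooks, 'eval', 'NORMAL')
print([name for name, _ in hooks])  # ['ema', 'checkpoint', 'eval', 'logger']
```

'ema' lands at the front despite being registered third, and 'eval' goes after 'checkpoint' because they share NORMAL and 'checkpoint' was registered first.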
Calling the hooks (call_hook):
```python
def call_hook(self, fn_name):
    """Call all hooks.

    Args:
        fn_name (str): The function name in each hook to be called, such as
            "before_train_epoch".
    """
    for hook in self._hooks:
        getattr(hook, fn_name)(self)
```
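As you can see, call_hook simply iterates over the list and, for each hook, calls the method with the given name. This is why priorities matter: the higher a hook's priority, the closer it sits to the front of the list and the earlier it is called. If you want two things, A and B, to happen at before_train_epoch, the runner still calls self.call_hook('before_train_epoch') only once; whether A or B runs first is decided by the priorities of the hooks that implement them. For example, at evaluation time we want to test the model, but before testing the parameters should be replaced by their exponential moving average (EMA). Line 72 of the EMA hook (ema.py) states explicitly that the EMA weights must be loaded before evaluation: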
```python
def after_train_epoch(self, runner):
    """We load parameter values from ema backup to model before the
    EvalHook."""
    self._swap_ema_parameters()
```
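What the swap buys us: the EMA hook keeps a shadow copy of every parameter, updates it during training, and swaps it into the model at the end of a training epoch so the hooks that run afterwards see the averaged weights. The plain-Python sketch below uses scalar weights in dicts and assumes the update rule given in the EMAHook docstring of MMCV 1.x, Xema ← (1 − momentum)·Xema + momentum·X (momentum weights the new value); it ignores warm-up and the update interval, and ema_update and swap are hypothetical names.

```python
def ema_update(ema, params, momentum):
    """Update every shadow value: Xema <- (1 - momentum) * Xema + momentum * X."""
    for name, value in params.items():
        ema[name] = (1 - momentum) * ema[name] + momentum * value


def swap(params, ema):
    """Exchange live and EMA values, in the spirit of _swap_ema_parameters."""
    for name in params:
        params[name], ema[name] = ema[name], params[name]


params = {'w': 1.0}
ema = dict(params)          # the shadow copy starts equal to the weights
for new_w in (3.0, 5.0):    # pretend two training iterations changed w
    params['w'] = new_w
    ema_update(ema, params, momentum=0.5)

swap(params, ema)           # after_train_epoch: put EMA weights into the model
evaluated_w = params['w']   # what later hooks (evaluation, checkpoint) would see
swap(params, ema)           # next before_train_epoch: restore the live weights
print(evaluated_w, params['w'], ema['w'])  # 3.5 5.0 3.5
```

Because swapping twice is the identity, training resumes from the live weights while evaluation sees the smoothed ones; that only works if the EMA hook runs before the hooks that read the model.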
The hook in checkpoint.py (CheckpointHook) also defines after_train_epoch:
```python
@master_only
def after_train_epoch(self, runner):
    if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
        return

    runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
    if not self.out_dir:
        self.out_dir = runner.work_dir
    runner.save_checkpoint(
        self.out_dir, save_optimizer=self.save_optimizer, **self.args)

    # remove other checkpoints
    if self.max_keep_ckpts > 0:
        filename_tmpl = self.args.get('filename_tmpl', 'epoch_{}.pth')
        current_epoch = runner.epoch + 1
        for epoch in range(current_epoch - self.max_keep_ckpts, 0, -1):
            ckpt_path = os.path.join(self.out_dir,
                                     filename_tmpl.format(epoch))
            if os.path.exists(ckpt_path):
                os.remove(ckpt_path)
            else:
                break
```
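The cleanup loop at the end keeps only the newest max_keep_ckpts files: it walks backwards from the newest epoch that should no longer be kept and stops at the first missing file, since older ones were already deleted in earlier epochs. The simulation below replaces the filesystem with a set of filenames; prune is a hypothetical helper that mirrors the loop.

```python
def prune(existing, current_epoch, max_keep_ckpts, tmpl='epoch_{}.pth'):
    """Mirror of CheckpointHook's cleanup loop on a set of filenames."""
    removed = []
    for epoch in range(current_epoch - max_keep_ckpts, 0, -1):
        name = tmpl.format(epoch)
        if name in existing:       # stands in for os.path.exists
            existing.remove(name)  # stands in for os.remove
            removed.append(name)
        else:
            break                  # older files were pruned in earlier epochs
    return removed


files = set()
log = []
for epoch in range(1, 6):
    files.add(f'epoch_{epoch}.pth')  # stands in for runner.save_checkpoint
    log.append(prune(files, epoch, max_keep_ckpts=2))
print(sorted(files))  # ['epoch_4.pth', 'epoch_5.pth']
print(log)  # [[], [], ['epoch_1.pth'], ['epoch_2.pth'], ['epoch_3.pth']]
```

Thanks to the early break, each epoch normally deletes exactly one file instead of re-checking the whole history.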
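As the test code below shows, although both hooks override after_train_epoch, the one in ema.py is called before the one in checkpoint.py, because the EMA hook is registered with priority 'HIGHEST' while CheckpointHook uses the default 'NORMAL'.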
```python
# Excerpt from MMCV's hook unit tests; work_dir, demo_model and loader come
# from the surrounding test.
resume_ema_hook = EMAHook(
    momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
runner = _build_demo_runner()
runner.model = demo_model
# registered with HIGHEST priority
runner.register_hook(resume_ema_hook, priority='HIGHEST')
checkpointhook = CheckpointHook(interval=1, by_epoch=True)
runner.register_hook(checkpointhook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 2)
```
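The ordering effect can be reproduced without mmcv. In this sketch MiniRunner, FakeEMAHook and FakeCheckpointHook are hypothetical stand-ins that only log what they would do. Instead of the backward scan, registration appends and then runs a stable sort by priority, which yields the same order (equal priorities keep registration order); call_hook dispatches by method name exactly as in BaseRunner.

```python
PRIORITY = {'HIGHEST': 0, 'NORMAL': 50}  # subset of the priority table


class FakeEMAHook:
    def after_train_epoch(self, runner):
        runner.log.append('ema: swap in EMA weights')


class FakeCheckpointHook:
    def after_train_epoch(self, runner):
        runner.log.append('checkpoint: save weights')


class MiniRunner:
    def __init__(self):
        self._hooks = []
        self.log = []

    def register_hook(self, hook, priority='NORMAL'):
        hook.priority = PRIORITY[priority]
        self._hooks.append(hook)
        # list.sort is stable, so equal priorities keep registration order.
        self._hooks.sort(key=lambda h: h.priority)

    def call_hook(self, fn_name):
        # Same dispatch as call_hook above: look the method up by name.
        for hook in self._hooks:
            getattr(hook, fn_name)(self)


runner = MiniRunner()
runner.register_hook(FakeCheckpointHook())               # NORMAL, registered first
runner.register_hook(FakeEMAHook(), priority='HIGHEST')  # registered second
runner.call_hook('after_train_epoch')
print(runner.log)  # ['ema: swap in EMA weights', 'checkpoint: save weights']
```

Even though the checkpoint hook was registered first here, the EMA swap still happens before the save.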
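There are 7 predefined priority levels, listed below. They come from MMCV's Priority enum (mmcv/runner/priority.py); register_hook converts the given level to its number and stores it in the hook's priority attribute. A lower value means a higher priority.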
+------------+------------+
| Level | Value |
+============+============+
| HIGHEST | 0 |
+------------+------------+
| VERY_HIGH | 10 |
+------------+------------+
| HIGH | 30 |
+------------+------------+
| NORMAL | 50 |
+------------+------------+
| LOW | 70 |
+------------+------------+
| VERY_LOW | 90 |
+------------+------------+
| LOWEST | 100 |
+------------+------------+
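register_hook accepts the level as an int, a string or an enum member and normalizes it with get_priority. The sketch below re-creates that conversion with Python's enum module, using the values from the table; it is written from my reading of MMCV 1.x and is assumed rather than guaranteed to match every edge case of the real function.

```python
from enum import Enum


class Priority(Enum):
    """Levels from the table above; a lower value means the hook runs earlier."""
    HIGHEST = 0
    VERY_HIGH = 10
    HIGH = 30
    NORMAL = 50
    LOW = 70
    VERY_LOW = 90
    LOWEST = 100


def get_priority(priority):
    """Normalize an int, a level name, or a Priority member to an int."""
    if isinstance(priority, int):
        if not 0 <= priority <= 100:
            raise ValueError('priority must be between 0 and 100')
        return priority
    if isinstance(priority, Priority):
        return priority.value
    if isinstance(priority, str):
        return Priority[priority.upper()].value  # name lookup, case-insensitive
    raise TypeError('priority must be an int, a str or a Priority member')


print(get_priority('low'), get_priority(Priority.HIGH), get_priority(42))  # 70 30 42
```

Raw integers let a custom hook slot in between two named levels, e.g. 40 sits between HIGH and NORMAL.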
2. A simple example
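Lately I've decided to exercise more, live healthily and work hard; I want to become more self-disciplined. So I set myself a few rules. Every day I do 30 minutes of morning exercise before breakfast, because only after exercising do I feel full of energy. Before lunch I start an experiment, so that when I come back from lunch I can check the intermediate results; if they look fine, I take a 30-minute nap after lunch. In the evening, before leaving work, if I have nothing else to do I exercise for another 30 minutes. Following these rules, I define a HOOK to regulate my life.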
Define my HOOK (saved as MyHook.py):

```python
import sys


class HOOK:

    def before_breakfast(self, runner):
        print('{}: 30 minutes of morning exercise before breakfast'.format(sys._getframe().f_code.co_name))

    def after_breakfast(self, runner):
        print('{}: full of energy after the workout'.format(sys._getframe().f_code.co_name))

    def before_lunch(self, runner):
        print('{}: start an experiment before lunch'.format(sys._getframe().f_code.co_name))

    def after_lunch(self, runner):
        print('{}: 30-minute nap after lunch'.format(sys._getframe().f_code.co_name))

    def before_dinner(self, runner):
        print('{}: not decided yet'.format(sys._getframe().f_code.co_name))

    def after_dinner(self, runner):
        print('{}: not decided yet'.format(sys._getframe().f_code.co_name))

    def after_finish_work(self, runner, are_you_busy=False):
        if are_you_busy:
            print('{}: tons of work today, better work overtime'.format(sys._getframe().f_code.co_name))
        else:
            print('{}: nothing much today, go work out for 30 minutes'.format(sys._getframe().f_code.co_name))
```

Define my Runner (saved as MyRunner.py):

```python
class Runner(object):

    def __init__(self):
        self._hooks = []

    def register_hook(self, hook):
        # No priority handling here: the hook is simply inserted at the head of the list
        self._hooks.insert(0, hook)

    def call_hook(self, hook_name):
        for hook in self._hooks:
            getattr(hook, hook_name)(self)

    def run(self):
        print('Starting my day')
        self.call_hook('before_breakfast')
        self.call_hook('after_breakfast')
        self.call_hook('before_lunch')
        self.call_hook('after_lunch')
        self.call_hook('before_dinner')
        self.call_hook('after_dinner')
        self.call_hook('after_finish_work')
        print('~~sleep~~')
```

Run the main script: register the HOOK and call Runner.run() to start my day.

```python
from MyHook import HOOK
from MyRunner import Runner

runner = Runner()
hook = HOOK()
runner.register_hook(hook)
runner.run()
```

The output is:

```
Starting my day
before_breakfast: 30 minutes of morning exercise before breakfast
after_breakfast: full of energy after the workout
before_lunch: start an experiment before lunch
after_lunch: 30-minute nap after lunch
before_dinner: not decided yet
after_dinner: not decided yet
after_finish_work: nothing much today, go work out for 30 minutes
~~sleep~~
```
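With a single hook, the toy Runner's self._hooks.insert(0, hook) is harmless, but with several hooks it reverses registration order: the hook registered last runs first. MMCV's register_hook, by contrast, keeps registration order among hooks of equal priority. The sketch registers two hypothetical hooks both ways; the fifo flag and NamedHook are additions for comparison only.

```python
class Runner:
    """Toy runner from the example, plus a hypothetical FIFO variant."""

    def __init__(self, fifo=False):
        self._hooks = []
        self.fifo = fifo
        self.log = []

    def register_hook(self, hook):
        if self.fifo:
            self._hooks.append(hook)     # keep registration order
        else:
            self._hooks.insert(0, hook)  # toy behaviour: newest hook first

    def call_hook(self, hook_name):
        for hook in self._hooks:
            getattr(hook, hook_name)(self)


class NamedHook:
    """Hypothetical hook that only records its name at before_lunch."""

    def __init__(self, name):
        self.name = name

    def before_lunch(self, runner):
        runner.log.append(self.name)


results = {}
for fifo in (False, True):
    runner = Runner(fifo=fifo)
    runner.register_hook(NamedHook('hook_a'))
    runner.register_hook(NamedHook('hook_b'))
    runner.call_hook('before_lunch')
    results[fifo] = runner.log
print(results)  # {False: ['hook_b', 'hook_a'], True: ['hook_a', 'hook_b']}
```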
3. Summary
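The HOOK design in MMdetection is clever: it abstracts and decouples algorithm training and testing very well. Anyone who builds models on top of it should take a look. Thanks to MMLab for contributing such high-quality code; it has been an eye-opener for ordinary developers like me.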
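Beyond HOOK, the codebase contains many other good ideas. How does Runner manage to handle everything? How does the registry, the central management system, work? How are the pitfalls of multi-GPU training dealt with? And so on. I am still learning and digesting all of this. The road ahead is long, and I will keep searching high and low.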
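A small exercise: in my code every method prints its own function name, and this can be done very conveniently with a decorator. Decorators are used extensively across the MMLab projects; their fp16 support, for instance, is widely praised. When I have time, I'll take a proper look at Runner, Registry and decorators.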
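One possible answer to the exercise: each method returns only its message, and a decorator prepends the wrapped function's __name__, which removes the repeated sys._getframe() calls. print_name is a hypothetical name; functools.wraps keeps the original method name and docstring on the wrapper.

```python
import functools


def print_name(fn):
    """Decorator: print '<method name>: <message>' using the returned message."""
    @functools.wraps(fn)
    def wrapper(self, runner, *args, **kwargs):
        line = '{}: {}'.format(fn.__name__, fn(self, runner, *args, **kwargs))
        print(line)
        return line
    return wrapper


class HOOK:

    @print_name
    def before_breakfast(self, runner):
        return '30 minutes of morning exercise before breakfast'

    @print_name
    def after_finish_work(self, runner, are_you_busy=False):
        if are_you_busy:
            return 'tons of work today, better work overtime'
        return 'nothing much today, go work out for 30 minutes'


hook = HOOK()
hook.before_breakfast(None)  # runner is unused here, so None is enough
hook.after_finish_work(None, are_you_busy=True)
```

Returning the formatted line as well as printing it makes the decorated methods easy to check in tests.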

