5 分鐘掌握 Python 中的 Hook 鉤子函數

1. 什么是Hook
what is hook?鉤子hook,顧名思義,可以理解是一個掛鉤,作用是有需要的時候掛一個東西上去。具體的解釋是:鉤子函數是把我們自己實現的hook函數在某一時刻掛接到目標掛載點上。hook函數的作用:舉個例子,hook的概念在windows桌面軟件開發很常見,特別是各種事件觸發的機制;比如C++的MFC程序中,要監聽鼠標左鍵按下的事件,MFC提供了一個OnLButtonDown的鉤子函數。很顯然,MFC框架并沒有為我們實現OnLButtonDown具體的操作,只是為我們提供一個鉤子,當我們需要處理的時候,只要去重寫這個函數,把我們需要的操作掛載在這個鉤子里;如果我們不掛載,MFC事件觸發機制中執行的就是空的操作。
hook函數是程序中預定義好的函數,這個函數處于原有程序流程當中(暴露一個鉤子出來)。我們需要在原有流程中鉤子定義的函數塊中實現某個具體的細節,需要把我們的實現掛接或者注冊(register)到鉤子里,使得hook函數對目標可用。hook 是一種編程機制,和具體的語言沒有直接的關系。如果從設計模式上看,hook模式是模板方法的擴展。鉤子只有注冊的時候才會使用,所以原有程序的流程中,沒有注冊或掛載時,執行的是空(即沒有執行任何操作)。

2. hook實現例子
場景:需要在插入隊列前對數據進行篩選。流程分兩步:1. 篩選(input_filter_fn);2. 插入隊列(insert_queue)。
class ContentStash(object):
    """Content stash for online operation.

    Pipeline:
    1. input_filter: an optional, user-registered hook that decides whether a
       content dict is of no use and should be dropped.
    2. insert_queue (redis or another broker; a plain list here): store the
       remaining, useful content.
    """

    def __init__(self):
        # Hook slot: stays None until register_input_filter_hook() is called.
        # While it is empty, the filter step of the pipeline is skipped.
        self.input_filter_fn = None
        # In-memory stand-in for a real message broker.
        self.broker = []

    def register_input_filter_hook(self, input_filter_fn):
        """Register (mount) the input filter hook.

        Args:
            input_filter_fn: callable taking the content dict. A truthy return
                value means "filter this content out"; a falsy return value
                (e.g. None) lets the content through to the queue.
        """
        self.input_filter_fn = input_filter_fn

    def insert_queue(self, content):
        """Append content to the queue.

        Args:
            content: dict
        """
        self.broker.append(content)

    def input_pipeline(self, content, use=False):
        """Run content through the filter hook and, if kept, into the queue.

        Args:
            content: dict
            use: whether the pipeline is enabled; default False, in which case
                the call does nothing.
        """
        if not use:
            return
        # Default verdict when no hook is mounted: keep the content.
        # Without this initialisation the name would be unbound below
        # (UnboundLocalError) whenever no filter hook was registered.
        filtered_out = False
        # input filter
        if self.input_filter_fn:
            filtered_out = self.input_filter_fn(content)
        # insert to queue
        if not filtered_out:
            self.insert_queue(content)
# test
## Example hook implementation: drop content that carries a 'time' value,
## let everything else through to the queue.
def input_filter_hook(content):
    """Example input filter hook for ContentStash.

    Args:
        content: dict

    Returns:
        The content dict itself (non-empty, hence truthy, so the pipeline
        filters it out) when ``content.get('time')`` is not None; otherwise
        None, which lets the content into the queue.
    """
    has_time = content.get('time') is not None
    return content if has_time else None
# Original program flow.
content = {'filename': 'test.jpg', 'b64_file': "#test", 'data': {"result": "cat", "probility": 0.9}}
# ContentStash.__init__ takes no arguments; passing ('audit', work_dir='')
# raised TypeError.
content_stash = ContentStash()
# Mount the hook. Any implementation may be plugged in, as long as it keeps the
# input/output contract the pipeline expects: one content dict in, a verdict out.
content_stash.register_input_filter_hook(input_filter_hook)
# Run the pipeline. use=True is required: with the default use=False,
# input_pipeline returns immediately and the hook never runs.
content_stash.input_pipeline(content, use=True)
3. hook在開源框架中的應用
3.1 keras
Keras 訓練流程中的關鍵節點有:開始訓練、訓練一個epoch前、訓練一個batch前、訓練一個batch后、訓練一個epoch后、評估驗證集、結束訓練。
在這些節點上我們常常需要插入自定義操作,比如訓練一個epoch后保存下訓練的模型,在結束訓練時用最好的模型執行下測試集看看效果等等。Keras 通過 Callback 類把這些節點作為鉤子暴露出來:
@keras_export('keras.callbacks.Callback')
class Callback(object):
  """Abstract base class used to build new callbacks.

  Every `on_*` method below is a hook. The base implementation does nothing:
  its body is only a docstring. The exceptions are `on_train_batch_begin` and
  `on_train_batch_end`, which forward to the legacy aliases
  `on_batch_begin` / `on_batch_end`. Subclasses override the hooks they need.

  Attributes:
      params: Dict. Training parameters
          (eg. verbosity, batch size, number of epochs...).
          Only exists after `set_params` is called; `__init__` does not set it.
      model: Instance of `keras.models.Model`.
          Reference of the model being trained. `None` until `set_model`
          is called.

  The `logs` dictionary that callback methods
  take as argument will contain keys for quantities relevant to
  the current batch or epoch (see method-specific docstrings).
  """

  def __init__(self):
    self.validation_data = None  # pylint: disable=g-missing-from-attributes
    self.model = None
    # Whether this Callback should only run on the chief worker in a
    # Multi-Worker setting.
    # TODO(omalleyt): Make this attr public once solution is stable.
    self._chief_worker_only = None
    # NOTE(review): presumably signals whether this callback can accept
    # `logs` as TF tensors rather than converted values -- confirm.
    self._supports_tf_logs = False

  def set_params(self, params):
    """Store `params` on the callback as `self.params`."""
    self.params = params

  def set_model(self, model):
    """Store `model` on the callback as `self.model`."""
    self.model = model

  # `generic_utils.default` appears to tag the base implementation so that
  # `_implements_train_batch_hooks` can check it with
  # `generic_utils.is_default` -- i.e. detect whether a subclass replaced it.
  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_begin(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_begin`.

    No-op in the base class; the base `on_train_batch_begin` calls it.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_end(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_end`.

    No-op in the base class; the base `on_train_batch_end` calls it.
    """

  @doc_controls.for_subclass_implementers
  def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Arguments:
        epoch: Integer, index of epoch.
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_epoch_end(self, epoch, logs=None):
    """Called at the end of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Arguments:
        epoch: Integer, index of epoch.
        logs: Dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_begin(self, batch, logs=None):
    """Called at the beginning of a training batch in `fit` methods.

    Subclasses should override for any actions to run. The base
    implementation only forwards to `on_batch_begin`.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.train_step`. Typically,
          the values of the `Model`'s metrics are returned.  Example:
          `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    # For backwards compatibility.
    self.on_batch_begin(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_end(self, batch, logs=None):
    """Called at the end of a training batch in `fit` methods.

    Subclasses should override for any actions to run. The base
    implementation only forwards to `on_batch_end`.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """
    # For backwards compatibility.
    self.on_batch_end(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_test_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `evaluate` methods.

    Also called at the beginning of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.test_step`. Typically,
          the values of the `Model`'s metrics are returned.  Example:
          `{'loss': 0.2, 'accuracy': 0.7}`.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_test_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `evaluate` methods.

    Also called at the end of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_predict_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.predict_step`,
          it typically returns a dict with a key 'outputs' containing
          the model's outputs.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_predict_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """

  @doc_controls.for_subclass_implementers
  def on_train_begin(self, logs=None):
    """Called at the beginning of training.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_train_end(self, logs=None):
    """Called at the end of training.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently the output of the last call to `on_epoch_end()`
          is passed to this argument for this method but that may change in
          the future.
    """

  @doc_controls.for_subclass_implementers
  def on_test_begin(self, logs=None):
    """Called at the beginning of evaluation or validation.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_test_end(self, logs=None):
    """Called at the end of evaluation or validation.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently the output of the last call to
          `on_test_batch_end()` is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_begin(self, logs=None):
    """Called at the beginning of prediction.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_end(self, logs=None):
    """Called at the end of prediction.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  def _implements_train_batch_hooks(self):
    """Determines if this Callback should be called for each train batch.

    Returns True when `generic_utils.is_default` is False for at least one of
    `on_batch_begin`, `on_batch_end`, `on_train_batch_begin` or
    `on_train_batch_end` (presumably: a subclass overrode one of them).
    """
    return (not generic_utils.is_default(self.on_batch_begin) or
            not generic_utils.is_default(self.on_batch_end) or
            not generic_utils.is_default(self.on_train_batch_begin) or
            not generic_utils.is_default(self.on_train_batch_end))
keras源碼位置:tensorflow\python\keras\engine\training.py(以下為 `Model.fit` 中調用這些鉤子的片段)
# Excerpt from `Model.fit`: the enclosing method is not shown, so the
# original indentation is kept. Lines marked "Hook point" are where the
# registered callbacks are invoked.
# Container that configures and calls `tf.keras.Callback`s.
      if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            model=self,
            verbose=verbose,
            epochs=epochs,
            steps=data_handler.inferred_steps)
      # Hook point: called once, before the first epoch.
      callbacks.on_train_begin()
      training_logs = None
      # Handle fault-tolerance for multi-worker.
      # TODO(omalleyt): Fix the ordering issues that mean this has to
      # happen after `callbacks.on_train_begin`.
      data_handler._initial_epoch = (  # pylint: disable=protected-access
          self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
      for epoch, iterator in data_handler.enumerate_epochs():
        self.reset_metrics()
        # Hook point: start of each epoch.
        callbacks.on_epoch_begin(epoch)
        with data_handler.catch_stop_iteration():
          for step in data_handler.steps():
            with trace.Trace(
                'TraceContext',
                graph_type='train',
                epoch_num=epoch,
                step_num=step,
                batch_size=batch_size):
              # Hook point: before each training step.
              callbacks.on_train_batch_begin(step)
              tmp_logs = train_function(iterator)
              if data_handler.should_sync:
                context.async_wait()
              logs = tmp_logs  # No error, now safe to assign to logs.
              end_step = step + data_handler.step_increment
              # Hook point: after each training step, with that step's logs.
              callbacks.on_train_batch_end(end_step, logs)
        epoch_logs = copy.copy(logs)
        # Run validation.
        # NOTE(review): the validation code appears to be elided from this
        # excerpt.
        # Hook point: end of each epoch. In this excerpt, epoch_logs is a
        # shallow copy of the last step's logs.
        callbacks.on_epoch_end(epoch, epoch_logs)
3.2 mmdetection
# Source (mmdetection): https://github.com/open-mmlab/mmdetection
# File: https://github.com/open-mmlab/mmdetection/blob/5d592154cca589c5113e8aadc8798bbc73630d98/mmdet/apis/train.py
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Build an `EpochBasedRunner` for `model` and register its hooks.

    Abridged excerpt: the code of the sections marked only by a comment
    (data loaders, GPU placement, fp16 setting) is elided.

    NOTE(review): `optimizer_config` and `val_dataloader` are used below but
    never defined in this excerpt; presumably they come from the elided
    fp16 / data-loader sections, so the excerpt does not run on its own.

    Args:
        model: model to train; passed to `build_optimizer` and the runner.
        dataset: not referenced by the visible code (presumably consumed by
            the elided data-loader section).
        cfg: config object; this excerpt reads `log_level`, `optimizer`,
            `work_dir`, `lr_config`, `checkpoint_config`, `log_config`, and
            the optional keys `momentum_config`, `evaluation`, `custom_hooks`.
        distributed: if True, registers `DistSamplerSeedHook` and uses
            `DistEvalHook` instead of `EvalHook` for evaluation.
        validate: if True, registers an evaluation hook configured from
            `cfg.evaluation` (empty dict when absent).
        timestamp: stored on the runner as `runner.timestamp`.
        meta: passed through to `EpochBasedRunner`.

    Raises:
        AssertionError: if `cfg.custom_hooks` is not a list, or one of its
            items is not a dict.
    """
    logger = get_root_logger(cfg.log_level)
    # prepare data loaders
    # put model on gpus
    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp
    # fp16 setting
    # Register the standard training hooks (lr, optimizer, checkpoint,
    # logging, and momentum when configured) -- the framework's hook points.
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        eval_cfg = cfg.get('evaluation', {})
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
    # user-defined hooks: each item is a dict config built via the HOOKS
    # registry, i.e. users "mount" their own hooks purely from the config.
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            # Work on a copy so popping 'priority' does not mutate the
            # caller's config, and build_from_cfg never sees that key.
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)
4. 總結
hook函數是流程中預定義好的一個步驟,默認沒有實現;掛載或者注冊后,流程執行時就會執行這個鉤子函數。回調函數和hook函數在功能上是一致的。hook設計方式帶來靈活性:如果流程中有一個步驟,你想讓調用方來實現,你可以用hook函數。
作者簡介:wedo實驗君,數據分析師;熱愛生活,熱愛寫作
贊 賞 作 者

推薦閱讀


點擊下方閱讀原文加入社區會員
點贊鼓勵一下

評論
圖片
表情
