PyTorch 源碼解讀之 torch.serialization & torch.hub
導(dǎo)讀
本文解讀基于 PyTorch 1.7 版本,對(duì) torch.serialization、torch.save 和 torch.hub 展開介紹。
torch.serialization
torch.serialization 實(shí)現(xiàn)對(duì) PyTorch 對(duì)象結(jié)構(gòu)的二進(jìn)制序列化和反序列化,其中序列化由 torch.save 實(shí)現(xiàn),反序列化由 torch.load 實(shí)現(xiàn)。
torch.save
torch.save 主要使用 pickle 來進(jìn)行二進(jìn)制序列化:
def save(obj,  # object to be serialized
         f: Union[str, os.PathLike, BinaryIO],  # file (path or file-like object) to write to
         pickle_module=pickle,  # pickle is used for serialization by default
         pickle_protocol=DEFAULT_PROTOCOL,  # defaults to pickle protocol version 2
         _use_new_zipfile_serialization=True) -> None:  # since PyTorch 1.6 the zipfile-based format is
                                                        # the default; set to False for the legacy format.
                                                        # torch.load can read both formats.
    """Serialize ``obj`` to ``f`` using ``pickle_module``."""
    # If dill is used instead of pickle, its version must be > 0.3.1.
    _check_dill_version(pickle_module)
    with _open_file_like(f, 'wb') as opened_file:
        # zipfile-based storage format
        if _use_new_zipfile_serialization:
            with _open_zipfile_writer(opened_file) as opened_zipfile:
                _save(obj, opened_zipfile, pickle_module, pickle_protocol)
                return
        # legacy path: write raw binary records to the file
        _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
可以看到核心函數(shù)是 _save() 和 _legacy_save(),接下來分別介紹。我們首先介紹 _save() 函數(shù):
def _save(obj, zip_file, pickle_module, pickle_protocol):
    """Write ``obj`` into ``zip_file``: first the pickled object structure
    (record 'data.pkl'), then each referenced storage under 'data/<key>'."""
    serialized_storages = {}  # key -> storage whose raw bytes still need writing

    def persistent_id(obj):
        # Called by the pickler for every sub-object; intercepts storages so
        # their raw bytes are written separately from the pickled structure.
        if torch.is_storage(obj):  # a data storage that must be persisted
            storage_type = normalize_storage_type(type(obj))  # storage type: int, float, ...
            obj_key = str(obj._cdata)  # key used at load time to locate the data record
            location = location_tag(obj)  # 'cpu' or a cuda device tag
            serialized_storages[obj_key] = obj  # remember the storage for the write loop below
            # note: no actual data here — only metadata describing the storage
            return ('storage', storage_type, obj_key, location, obj.size())
        return None

    data_buf = io.BytesIO()  # buffer receiving the pickled structure
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)  # structure will go into data_buf
    pickler.persistent_id = persistent_id  # structure goes to data_buf; storages are collected aside
    pickler.dump(obj)  # pickling calls persistent_id for every sub-object
    data_value = data_buf.getvalue()  # the serialized structure bytes
    zip_file.write_record('data.pkl', data_value, len(data_value))  # write the structure record first
                                                                    # ('data.pkl'); the actual data
                                                                    # bytes are written below
    for key in sorted(serialized_storages.keys()):  # now write the actual data
        name = f'data/{key}'  # record name for this storage
        storage = serialized_storages[key]  # the actual data
        if storage.device.type == 'cpu':  # data already resides in host memory
            num_bytes = storage.size() * storage.element_size()  # total byte count
            zip_file.write_record(name, storage.data_ptr(), num_bytes)  # write straight from the pointer
        else:  # data lives on a CUDA device
            buf = io.BytesIO()  # temporary host buffer
            storage._write_file(buf, _should_read_directly(buf), False)  # copy device data to host
            buf_value = buf.getvalue()  # the copied bytes
            zip_file.write_record(name, buf_value, len(buf_value))  # write the data record
總的來說?_save()?函數(shù)在將對(duì)象二進(jìn)制序列化的過程中,首先寫入對(duì)象的結(jié)構(gòu)信息,之后再寫入具體的數(shù)據(jù)內(nèi)容。
接下來介紹_legacy_save()函數(shù):
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    """Legacy (pre-1.6) serialization: magic number, protocol version,
    system info, pickled structure, storage keys, then raw storage bytes."""
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        if isinstance(obj, type) and issubclass(obj, nn.Module):  # record the module's source code
            if obj in serialized_container_types:  # already recorded; no need to repeat
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)  # read the source code
                source = ''.join(source_lines)
            except Exception:  # source unavailable: warn; it just won't be checked at load time
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)
        elif torch.is_storage(obj):  # analogous to persistent_id() in _save() above
            view_metadata: Optional[Tuple[str, int, int]]
            obj = cast(Storage, obj)
            storage_type = normalize_storage_type(type(obj))
            offset = 0
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            serialized_storages[obj_key] = obj
            # NOTE(review): this comparison is always False (it appears to be a
            # leftover from removed storage-view support in upstream PyTorch),
            # so view_metadata is always None here.
            is_view = obj._cdata != obj._cdata
            if is_view:
                view_metadata = (str(obj._cdata), offset, obj.size())
            else:
                view_metadata = None
            return ('storage', storage_type, obj_key, location, obj.size(),
                    view_metadata)
        return None

    # record some system information, validated again at load time
    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )
    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)  # lets load() detect corrupt files
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)  # lets load() verify protocol match
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)  # system info recorded above
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)  # structure is pickled straight into f
    pickler.persistent_id = persistent_id  # storages are collected aside in serialized_storages
    pickler.dump(obj)  # pickling calls persistent_id() for every sub-object
    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)  # keys of the data that follows
    f.flush()  # flush buffered writes before appending the raw storage bytes
    for key in serialized_storage_keys:
        serialized_storages[key]._write_file(f, _should_read_directly(f), True)  # raw data bytes
可以看到 _legacy_save() 和 _save() 在序列化的過程中,整體的 pipeline 是類似的,只是寫入的內(nèi)容有輕微差別。
torch.load
torch.load 主要使用 pickle 來進(jìn)行二進(jìn)制反序列化。
def load(f,  # file to deserialize from
         map_location=None,  # where to place objects (cpu/cuda); default: the location saved in the file
         pickle_module=pickle,  # pickle is used for deserialization by default
         **pickle_load_args):
    """Deserialize a torch object from ``f``, dispatching on the file format."""
    _check_dill_version(pickle_module)
    if 'encoding' not in pickle_load_args.keys():  # decode with utf-8 by default
        pickle_load_args['encoding'] = 'utf-8'
    with _open_file_like(f, 'rb') as opened_file:
        if _is_zipfile(opened_file):  # zipfile-based storage format
            orig_position = opened_file.tell()
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                # TorchScript archives go through torch.jit.load(); otherwise _load() is used
                if _is_torchscript_zip(opened_zipfile):
                    warnings.warn(
                        "'torch.load' received a zip file that looks like a TorchScript archive"
                        " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                        " silence this warning)", UserWarning)
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file)
                return _load(opened_zipfile, map_location, pickle_module,
                             **pickle_load_args)
        # plain binary (legacy) file: deserialize with _legacy_load()
        return _legacy_load(opened_file, map_location, pickle_module,
                            **pickle_load_args)
可以看到核心函數(shù)是_load(),_legacy_load(),接下來分別介紹,我們首先介紹_load()函數(shù):
def _load(zip_file,
          map_location,
          pickle_module,
          pickle_file='data.pkl',  # note: 'data.pkl' matches the record name written by _save()
          **pickle_load_args):
    """Deserialize an object from a zipfile-based archive; storages are
    loaded eagerly while the pickled structure is being rebuilt."""
    # build restore_location from map_location; it decides cpu vs cuda placement
    restore_location = _get_restore_location(map_location)
    loaded_storages = {}  # cache: key -> already-restored storage

    def load_tensor(data_type, size, key, location):
        name = f'data/{key}'  # record name used to locate the data in the archive
        dtype = data_type(0).dtype  # element dtype, e.g. int, float, ...
        storage = zip_file.get_storage_from_record(name, size, dtype).storage()  # read the bytes
        loaded_storages[key] = restore_location(storage, location)  # place on cpu or cuda

    def persistent_load(saved_id):
        # saved_id = ('storage', storage_type, obj_key, location, obj.size())
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]
        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        # unpack: storage_type, obj_key, location, obj.size()
        data_type, key, location, size = data
        if key not in loaded_storages:
            load_tensor(data_type, size, key, _maybe_decode_ascii(location))
        storage = loaded_storages[key]
        return storage

    # read 'data.pkl', the record holding the object's pickled structure
    data_file = io.BytesIO(zip_file.get_record(pickle_file))
    unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load  # hook that loads the actual data
    result = unpickler.load()  # rebuild the object (loads storages along the way)
    torch._utils._validate_loaded_sparse_tensors()
    return result
總的來說?_load()?函數(shù)在將對(duì)象二進(jìn)制反序列化的過程中,在構(gòu)建對(duì)象結(jié)構(gòu)信息的同時(shí),就已經(jīng)將具體的數(shù)據(jù)內(nèi)容加載進(jìn)來了。_legacy_load()函數(shù)與它不同,_legacy_load()是先構(gòu)建對(duì)象結(jié)構(gòu)信息,再加載具體的數(shù)據(jù)。
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    """Deserialize an object from the legacy (non-zipfile) format: the object
    structure is rebuilt first, then the raw storage bytes are read."""
    deserialized_objects: Dict[int, Any] = {}
    # build restore_location from map_location; it decides cpu vs cuda placement
    restore_location = _get_restore_location(map_location)

    def legacy_load(f):
        deserialized_objects: Dict[int, Any] = {}
        # f is not a tar archive here, so tarfile.open() raises TarError and the
        # body below never runs; the caller catches it and falls through
        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:
            ...

    deserialized_objects = {}

    def persistent_load(saved_id):
        # saved_id = ('storage', storage_type, obj_key, location, obj.size(), view_metadata)
        # or saved_id = ('module', obj, source_file, source)
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]
        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)  # verify the saved source code still matches
            return data[0]
        elif typename == 'storage':  # only the structure is restored here, not the data bytes
            data_type, root_key, location, size, view_metadata = data
            location = _maybe_decode_ascii(location)
            if root_key not in deserialized_objects:
                obj = data_type(size)
                obj._torch_load_uninitialized = True
                deserialized_objects[root_key] = restore_location(
                    obj, location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = storage[offset:offset +
                                                             view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    # f must support seek() (jump to an arbitrary position) and tell() (current position)
    _check_seekable(f)
    # whether f is directly readable binary; False e.g. for zip members. Since
    # legacy files are not zip archives, this is normally True.
    f_should_read_directly = _should_read_directly(f)
    if f_should_read_directly and f.tell() == 0:
        try:
            return legacy_load(f)  # raises TarError because f is not a tar archive
        except tarfile.TarError:
            if _is_zipfile(f):  # normally not taken
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)"
                ) from None
            f.seek(0)  # rewind to the start of the file
    # FIX: the transcription dropped the '<' operator — the check is a chained
    # comparison rejecting file-like objects without readinto() on Python
    # versions in [3.8.0, 3.8.2).
    if not hasattr(f,
                   'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
            "functionality.")
    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:  # verify the file is not corrupt
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:  # verify the pickle protocol matches
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)
    _sys_info = pickle_module.load(f, **pickle_load_args)  # read the recorded system info
    unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    # rebuild the object structure via persistent_load(); data bytes not read yet
    result = unpickler.load()
    # read the storage keys; note that each pickle_module.load() here mirrors,
    # in order, a dump() performed in _legacy_save() above
    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:  # now read the actual data bytes
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset,
                                                 f_should_read_directly)
        if offset is not None:
            offset = f.tell()
    torch._utils._validate_loaded_sparse_tensors()
    return result
在 _load() 和 _legacy_load() 中都有 _get_restore_location() 函數(shù)生成 restore_location(obj, location) 函數(shù),它決定將讀取的對(duì)象(obj)放到 CPU or CUDA(location)上。接下來我們介紹 _get_restore_location():
def?_cpu_deserialize(obj,?location):?#?將對(duì)象放到cpu上,注意可能返回None
????if?location?==?'cpu':
????????return?obj
def _cuda_deserialize(obj, location):  # place obj on the given cuda device; note: may return None
    """Return ``obj`` placed on the requested cuda device, or None when
    ``location`` is not a cuda location (so the next deserializer is tried)."""
    if location.startswith('cuda'):
        device = validate_cuda_device(location)  # checks a GPU exists and the device id is in range
        if getattr(obj, "_torch_load_uninitialized", False):
            # storage was created empty at load time: allocate it directly on the device
            storage_type = getattr(torch.cuda, type(obj).__name__)
            with torch.cuda.device(device):
                return storage_type(obj.size())
        else:
            return obj.cuda(device)  # copy an existing storage onto the device
_package_registry = []  # (priority, tagger, deserializer) tuples, kept sorted by priority

def register_package(priority, tagger, deserializer):
    # Register a location tagger/deserializer pair; lower priority is tried first.
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()

# cpu (priority 10) is tried before cuda (priority 20)
register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
def default_restore_location(storage, location):
    """Try each registered deserializer in priority order (cpu first, then
    cuda) until one accepts ``storage``; raise if none does."""
    for _, _, deserializer in _package_registry:
        restored = deserializer(storage, location)
        if restored is not None:
            return restored
    raise RuntimeError("don't know how to restore data location of " +
                       torch.typename(storage) + " (tagged with " + location +
                       ")")
def _get_restore_location(map_location):
    """Build the ``restore_location(storage, location)`` function used during
    load to decide whether data is placed on CPU or CUDA."""
    if map_location is None:
        # map_location = None: keep the cpu/cuda location recorded in the file
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        # map_location = {'cpu': 'cuda:0'}: storages tagged 'cpu' go to 'cuda:0';
        # locations not present in the dict keep their recorded placement
        def restore_location(storage, location):
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    elif isinstance(map_location, _string_classes):
        # map_location = 'cuda:0': everything goes to 'cuda:0' regardless of its tag
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
    elif isinstance(map_location, torch.device):
        # map_location = torch.device('cpu'): everything goes to 'cpu' regardless of its tag
        def restore_location(storage, location):
            return default_restore_location(storage, str(map_location))
    else:
        # map_location is a callable replacing default_restore_location, e.g.
        # map_location = lambda storage, location: storage.cuda(1) puts
        # everything on 'cuda:1'; a None result falls back to the default
        def restore_location(storage, location):
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result
    return restore_location
以上是torch.serialization的源碼分析,torch.serialization主要包含torch.save(),torch.load()函數(shù),其中torch.save()主要通過調(diào)用_save()or_legacy_save()實(shí)現(xiàn),torch.load()主要通過調(diào)用_load()or_legacy_load().torch.load()中的map_location參數(shù)通過_get_restore_location()函數(shù)決定將對(duì)象反序列化到 CPU 還是 CUDA 上。
torch.hub
torch.hub 提供了一系列 pretrained models 來方便大家使用,我們以 https://github.com/pytorch/vision 為例,介紹怎樣使用 torch.hub 提供的接口來調(diào)用 torchvision 里的 model。
torch.hub主要提供了三個(gè)接口torch.hub.list(),torch.hub.help(),torch.hub.load(),我們依次介紹。
torch.hub.list() 會(huì)從給定的 GitHub repo 中尋找 hubconf.py(此文件導(dǎo)入 repo 里提供的所有 models),然后返回一個(gè) list,里面包含了提供的 model 類名。https://github.com/pytorch/vision下的 hubconf.py 文件內(nèi)容如下:
#?Optional?list?of?dependencies?required?by?the?package
dependencies?=?['torch']
#?classification
from?torchvision.models.alexnet?import?alexnet
from?torchvision.models.densenet?import?densenet121,?densenet169,?densenet201,?densenet161
from?torchvision.models.inception?import?inception_v3
from?torchvision.models.resnet?import?resnet18,?resnet34,?resnet50,?resnet101,?resnet152,\
????resnext50_32x4d,?resnext101_32x8d,?wide_resnet50_2,?wide_resnet101_2
from?torchvision.models.squeezenet?import?squeezenet1_0,?squeezenet1_1
from?torchvision.models.vgg?import?vgg11,?vgg13,?vgg16,?vgg19,?vgg11_bn,?vgg13_bn,?vgg16_bn,?vgg19_bn
from?torchvision.models.googlenet?import?googlenet
from?torchvision.models.shufflenetv2?import?shufflenet_v2_x0_5,?shufflenet_v2_x1_0
from?torchvision.models.mobilenetv2?import?mobilenet_v2
from?torchvision.models.mobilenetv3?import?mobilenet_v3_large,?mobilenet_v3_small
from?torchvision.models.mnasnet?import?mnasnet0_5,?mnasnet0_75,?mnasnet1_0,?\
????mnasnet1_3
#?segmentation
from?torchvision.models.segmentation?import?fcn_resnet50,?fcn_resnet101,?\
????deeplabv3_resnet50,?deeplabv3_resnet101,?deeplabv3_mobilenet_v3_large,?lraspp_mobilenet_v3_large
接下來我們分析torch.hub.list()代碼:
def list(github,  # repo name such as 'pytorch/vision'; the 'https://github.com/' prefix is
                  # omitted because it is hard-coded internally
         force_reload=False):  # whether to re-download the repo
    """Return the entry-point (model) names exposed by a GitHub repo's hubconf.py."""
    # download the repo and return its local path; the cache root is
    # torch.hub.get_dir() and can be changed beforehand via torch.hub.set_dir(string)
    repo_dir = _get_cache_or_reload(github, force_reload, True)
    sys.path.insert(0, repo_dir)  # put the local repo first on the import path
    # MODULE_HUBCONF = 'hubconf.py': import it from the local repo; it in turn
    # imports every module the repo provides
    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
    sys.path.remove(repo_dir)  # remove the local repo path again
    # We take functions starts with '_' as internal helper functions,
    # so names beginning with '_' are filtered out here
    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
    return entrypoints  # list of strings: model entry points provided by the repo
# An example:
print(torch.hub.list('pytorch/vision', True))
# print info:
'''
['alexnet', 'deeplabv3_mobilenet_v3_large', 'deeplabv3_resnet101', 'deeplabv3_resnet50', 'densenet121', 'densenet161',
'densenet169', 'densenet201', 'fcn_resnet101', 'fcn_resnet50', 'googlenet', 'inception_v3', 'lraspp_mobilenet_v3_large',
'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small',
'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x8d', 'resnext50_32x4d', 'shufflenet_v2_x0_5',
'shufflenet_v2_x1_0', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19', 'vgg19_bn', 'wide_resnet101_2', 'wide_resnet50_2']
'''
torch.hub.help()會(huì)返回給定 repo 下給定 module 的文檔:
def help(github,
         model,  # entry-point name, e.g. 'resnet50'
         force_reload=False):
    """Return the docstring of entry point ``model`` from the repo's hubconf.py."""
    repo_dir = _get_cache_or_reload(github, force_reload, True)
    sys.path.insert(0, repo_dir)
    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
    sys.path.remove(repo_dir)
    entry = _load_entry_from_hubconf(hub_module, model)  # find the requested entry point among all modules
    return entry.__doc__  # the docstring of `model`
# An example:
print(torch.hub.help('pytorch/vision', 'resnet18', True))
# print info:
'''
ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" `_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
'''
torch.hub.load()會(huì)返回實(shí)例化后的 module:
def load(repo_or_dir,  # local directory, or a GitHub repo name
         model,
         *args,  # positional args used to instantiate the module
         **kwargs):  # keyword args used to instantiate the module
    """Fetch a repo (from GitHub or locally) and return the instantiated model."""
    source = kwargs.pop('source', 'github').lower()  # repos are fetched from github by default
    force_reload = kwargs.pop('force_reload', False)
    verbose = kwargs.pop('verbose', True)  # if True, print some log messages
    if source not in ('github', 'local'):  # the repo is found either on github or locally
        raise ValueError(
            f'Unknown source: "{source}". Allowed values: "github" | "local".')
    if source == 'github':
        repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose)
    model = _load_local(repo_or_dir, model, *args, **kwargs)
    return model
def _load_local(hubconf_dir, model, *args, **kwargs):
    """Import hubconf.py from ``hubconf_dir``, look up entry point ``model``,
    and return it instantiated with ``*args``/``**kwargs``."""
    sys.path.insert(0, hubconf_dir)  # make the repo importable with top priority
    # FIX: wrap in try/finally — the original removed hubconf_dir from sys.path
    # only on success, leaking the path entry if instantiation raised.
    try:
        hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
        hub_module = import_module(MODULE_HUBCONF, hubconf_path)
        entry = _load_entry_from_hubconf(hub_module, model)  # find the requested entry point
        model = entry(*args, **kwargs)  # instantiate the module
    finally:
        sys.path.remove(hubconf_dir)  # always restore sys.path, even on error
    return model
# An example:
resnet18 = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)  # load pretrained weights
以上以 pytorch/vision 為例介紹了 torch.hub 的使用。實(shí)際上只要一個(gè) GitHub repo 里有 hubconf.py 文件,都可以使用 torch.hub 提供的接口。
原文鏈接:https://zhuanlan.zhihu.com/p/364239544
如果覺得有用,就請(qǐng)分享到朋友圈吧!
