
PyTorch Source Code Walkthrough: torch.serialization & torch.hub


          2021-10-29 16:07


Author | 123456
Source | OpenMMLab
Editor | 极市平台

Overview

This walkthrough is based on PyTorch 1.7 and covers torch.serialization (including torch.save and torch.load) and torch.hub.

torch.serialization

torch.serialization implements binary serialization and deserialization of PyTorch object structures: serialization is done by torch.save, deserialization by torch.load.

torch.save

torch.save relies mainly on pickle for binary serialization:

def save(obj,  # the object to serialize
         f: Union[str, os.PathLike, BinaryIO],  # the file to write to
         pickle_module=pickle,  # serialize with pickle by default
         pickle_protocol=DEFAULT_PROTOCOL,  # pickle protocol 2 by default
         _use_new_zipfile_serialization=True) -> None:
    # Since PyTorch 1.6 the zipfile-based file format is the default; set this
    # to False for the old format. torch.load can read both formats.

    # If dill is used for serialization, its version must be > 0.3.1.
    _check_dill_version(pickle_module)

    with _open_file_like(f, 'wb') as opened_file:
        # zipfile-based storage format
        if _use_new_zipfile_serialization:
            with _open_zipfile_writer(opened_file) as opened_zipfile:
                _save(obj, opened_zipfile, pickle_module, pickle_protocol)
                return
        # write to the file as raw binary
        _legacy_save(obj, opened_file, pickle_module, pickle_protocol)

As you can see, the core functions are _save() and _legacy_save(). Let's look at each in turn, starting with _save():

def _save(obj, zip_file, pickle_module, pickle_protocol):
    serialized_storages = {}  # stash for the actual data payloads, indexed by key

    def persistent_id(obj):
        if torch.is_storage(obj):  # a data payload that needs to be stored
            storage_type = normalize_storage_type(type(obj))  # storage type: int, float, ...
            obj_key = str(obj._cdata)  # key for this payload; load() reads the data back by key
            location = location_tag(obj)  # cpu or cuda
            serialized_storages[obj_key] = obj  # the payload and its key

            # Note: no actual data here, only metadata describing the data
            return ('storage', storage_type, obj_key, location, obj.size())
        return None

    data_buf = io.BytesIO()  # allocate a buffer
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)  # the object's structure will be written into data_buf
    pickler.persistent_id = persistent_id  # structure goes to data_buf; payloads are stashed in serialized_storages
    pickler.dump(obj)  # write the object; persistent_id() is called during the dump
    data_value = data_buf.getvalue()  # take out the serialized structure
    # Write it into zip_file. Note that only the object's structure (identified
    # by 'data.pkl') is written at this point; the payloads are not yet written.
    zip_file.write_record('data.pkl', data_value, len(data_value))

    for key in sorted(serialized_storages.keys()):  # now write the payloads
        name = f'data/{key}'  # name of this payload
        storage = serialized_storages[key]  # the actual data
        if storage.device.type == 'cpu':  # data lives on cpu
            num_bytes = storage.size() * storage.element_size()  # bytes occupied
            zip_file.write_record(name, storage.data_ptr(), num_bytes)  # write the data
        else:  # data lives on cuda
            buf = io.BytesIO()  # allocate a buffer
            storage._write_file(buf, _should_read_directly(buf), False)  # copy the cuda data into host memory
            buf_value = buf.getvalue()  # read it back out
            zip_file.write_record(name, buf_value, len(buf_value))  # write the data

In short, when _save() serializes an object, it first writes the object's structure and only then the actual data payloads.
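This structure-first, payload-later split relies on pickle's standard persistent_id / persistent_load hooks, the same ones _save() installs. Below is a minimal, self-contained sketch of the mechanism with plain pickle; the Blob class and side_table are made up for illustration, with side_table playing the role of serialized_storages:

```python
import io
import pickle

class Blob:  # hypothetical stand-in for a torch storage
    def __init__(self, data):
        self.data = data

side_table = {}  # plays the role of serialized_storages

def persistent_id(obj):
    if isinstance(obj, Blob):          # "is this a storage?"
        key = str(id(obj))
        side_table[key] = obj.data     # stash the payload on the side
        return ('blob', key)           # only metadata goes into the pickle
    return None                        # everything else: pickle normally

buf = io.BytesIO()
pickler = pickle.Pickler(buf, protocol=2)
pickler.persistent_id = persistent_id  # same assignment style as _save()
pickler.dump({'weights': Blob(b'\x01\x02\x03')})

def persistent_load(saved_id):
    typename, key = saved_id
    assert typename == 'blob'
    return side_table[key]             # fetch the payload back by key

buf.seek(0)
unpickler = pickle.Unpickler(buf)
unpickler.persistent_load = persistent_load
restored = unpickler.load()
print(restored)  # {'weights': b'\x01\x02\x03'}
```

In the real _save() the side table is flushed into the zip archive after the dump, and in _load() persistent_load pulls each payload back out of the archive instead of a dict.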
Next, the _legacy_save() function:

def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        if isinstance(obj, type) and issubclass(obj, nn.Module):  # record the source code
            if obj in serialized_container_types:  # already recorded; no need to repeat
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)  # read the source code
                source = ''.join(source_lines)
            except Exception:  # if it can't be found, print a warning
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        elif torch.is_storage(obj):  # analogous to persistent_id() in _save() above
            view_metadata: Optional[Tuple[str, int, int]]
            obj = cast(Storage, obj)
            storage_type = normalize_storage_type(type(obj))
            offset = 0
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            serialized_storages[obj_key] = obj
            is_view = obj._cdata != obj._cdata  # always False; a vestige of storage views
            if is_view:
                view_metadata = (str(obj._cdata), offset, obj.size())
            else:
                view_metadata = None

            return ('storage', storage_type, obj_key, location, obj.size(),
                    view_metadata)
        return None

    # record some system information
    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)  # MAGIC_NUMBER: load() uses it to detect a corrupt file
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)  # pickle protocol: load() verifies it matches
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)  # system information
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)  # the object's structure will be written to the file
    pickler.persistent_id = persistent_id  # structure goes to f; payloads are stashed in serialized_storages
    pickler.dump(obj)  # write the object; persistent_id() is called during the dump

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)  # write the keys of the payloads
    f.flush()  # flush the buffer
    for key in serialized_storage_keys:
        serialized_storages[key]._write_file(f, _should_read_directly(f), True)  # write the payloads

As you can see, _legacy_save() follows the same overall pipeline as _save(); only the details of what gets written differ slightly.
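The legacy format is essentially several pickle payloads concatenated into one stream, read back later with the same sequence of load() calls. A minimal sketch of that framing with plain pickle (the two constants are illustrative stand-ins for torch's MAGIC_NUMBER and PROTOCOL_VERSION):

```python
import io
import pickle

MAGIC_NUMBER = 0x1950a86a20f9469cfc6c  # stand-in constant
PROTOCOL_VERSION = 1001                # stand-in constant

f = io.BytesIO()
pickle.dump(MAGIC_NUMBER, f, protocol=2)             # 1. integrity check
pickle.dump(PROTOCOL_VERSION, f, protocol=2)         # 2. format version
pickle.dump({'little_endian': True}, f, protocol=2)  # 3. sys_info
pickle.dump({'model': 'state'}, f, protocol=2)       # 4. the object itself
pickle.dump(['key0', 'key1'], f, protocol=2)         # 5. storage keys
# 6. raw storage bytes would follow here, outside of pickle framing

f.seek(0)                                            # mirrors _legacy_load()
assert pickle.load(f) == MAGIC_NUMBER
assert pickle.load(f) == PROTOCOL_VERSION
sys_info = pickle.load(f)
obj = pickle.load(f)
keys = pickle.load(f)
print(obj, keys)  # {'model': 'state'} ['key0', 'key1']
```

Each dump in _legacy_save() corresponds one-to-one with a load in _legacy_load(), which is why the reader below can consume the stream without any length prefixes.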

torch.load

torch.load relies mainly on pickle for binary deserialization.

def load(f,  # the file to deserialize
         map_location=None,  # put the objects on cpu or cuda; defaults to the location recorded in the file
         pickle_module=pickle,  # deserialize with pickle by default
         **pickle_load_args):

    _check_dill_version(pickle_module)

    if 'encoding' not in pickle_load_args.keys():  # decode as utf-8 by default
        pickle_load_args['encoding'] = 'utf-8'

    with _open_file_like(f, 'rb') as opened_file:
        if _is_zipfile(opened_file):  # zipfile-based storage format
            orig_position = opened_file.tell()
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                # TorchScript archives go through torch.jit.load(); everything else through _load()
                if _is_torchscript_zip(opened_zipfile):
                    warnings.warn(
                        "'torch.load' received a zip file that looks like a TorchScript archive"
                        " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                        " silence this warning)", UserWarning)
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file)
                return _load(opened_zipfile, map_location, pickle_module,
                             **pickle_load_args)
        # raw binary files are deserialized with _legacy_load()
        return _legacy_load(opened_file, map_location, pickle_module,
                            **pickle_load_args)

As you can see, the core functions are _load() and _legacy_load(). Let's look at each in turn, starting with _load():

def _load(zip_file,
          map_location,
          pickle_module,
          pickle_file='data.pkl',  # note: 'data.pkl' matches the name used in _save()
          **pickle_load_args):

    restore_location = _get_restore_location(map_location)  # build restore_location from map_location; it decides cpu vs cuda placement

    loaded_storages = {}

    def load_tensor(data_type, size, key, location):
        name = f'data/{key}'  # the payload's key, used to locate the data
        dtype = data_type(0).dtype  # data type: int, float, ...

        storage = zip_file.get_storage_from_record(name, size, dtype).storage()  # find the data in the file
        loaded_storages[key] = restore_location(storage, location)  # place it on cpu or cuda

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)  # saved_id = ('storage', storage_type, obj_key, location, obj.size())
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        data_type, key, location, size = data  # i.e. storage_type, obj_key, location, obj.size()
        if key not in loaded_storages:
            load_tensor(data_type, size, key, _maybe_decode_ascii(location))
        storage = loaded_storages[key]
        return storage

    data_file = io.BytesIO(zip_file.get_record(pickle_file))  # read 'data.pkl', which holds the object's structure
    unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load  # persistent_load reads the actual payloads
    result = unpickler.load()  # perform the read

    torch._utils._validate_loaded_sparse_tensors()

    return result

In short, while _load() deserializes an object, the data payloads are loaded at the same time the object's structure is being reconstructed. _legacy_load() differs: it first rebuilds the object's structure and only then loads the payloads.

def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    deserialized_objects: Dict[int, Any] = {}

    restore_location = _get_restore_location(map_location)  # build restore_location from map_location; it decides cpu vs cuda placement

    def legacy_load(f):
        deserialized_objects: Dict[int, Any] = {}

        # tarfile.open() raises for files that aren't tar archives, so the
        # body below is normally never executed
        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:
            ...

    deserialized_objects = {}

    def persistent_load(saved_id):
        # saved_id = ('storage', storage_type, obj_key, location, obj.size(), view_metadata)
        # or saved_id = ('module', obj, source_file, source)
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)  # check the source code still matches
            return data[0]
        elif typename == 'storage':  # note: no payload is loaded here; only the object's structure is restored
            data_type, root_key, location, size, view_metadata = data
            location = _maybe_decode_ascii(location)
            if root_key not in deserialized_objects:
                obj = data_type(size)
                obj._torch_load_uninitialized = True
                deserialized_objects[root_key] = restore_location(
                    obj, location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = storage[offset:offset +
                                                             view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    # check the file supports seek() and tell(): seek() jumps to an arbitrary
    # position, tell() returns the current position in the file
    _check_seekable(f)
    # whether the file is directly readable as binary; False for e.g. zip
    # files, but since the file passed in here is not zip-formatted, this is
    # usually True
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        try:
            return legacy_load(f)  # not a tar archive, so this raises
        except tarfile.TarError:
            if _is_zipfile(f):  # normally not executed
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)"
                ) from None
            f.seek(0)  # back to the start of the file

    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
            "functionality.")

    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:  # verify the MAGIC_NUMBER matches
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:  # verify the pickle protocol matches
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    _sys_info = pickle_module.load(f, **pickle_load_args)  # read the system information
    unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()  # persistent_load() restores the object's structure; no payloads are read yet

    # read the payload keys; notice that each pickle_module.load() above
    # corresponds one-to-one with a dump in _legacy_save()
    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:  # read the payloads
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset,
                                                 f_should_read_directly)
        if offset is not None:
            offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result

Both _load() and _legacy_load() call _get_restore_location() to build the restore_location(obj, location) function, which decides whether a loaded object (obj) is placed on CPU or CUDA (location). Let's look at _get_restore_location():

def _cpu_deserialize(obj, location):  # place the object on cpu; note this may return None
    if location == 'cpu':
        return obj


def _cuda_deserialize(obj, location):  # place the object on the given cuda device; note this may return None
    if location.startswith('cuda'):
        device = validate_cuda_device(location)  # check a GPU exists and the device id does not exceed the machine's GPU count
        if getattr(obj, "_torch_load_uninitialized", False):
            storage_type = getattr(torch.cuda, type(obj).__name__)
            with torch.cuda.device(device):
                return storage_type(obj.size())
        else:
            return obj.cuda(device)

_package_registry = []

def register_package(priority, tagger, deserializer):
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()

register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)


def default_restore_location(storage, location):  # try cpu first, then cuda, and place the data accordingly
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of " +
                       torch.typename(storage) + " (tagged with " + location +
                       ")")


def _get_restore_location(map_location):
    if map_location is None:
        restore_location = default_restore_location  # map_location = None: use the cpu/cuda recorded in location
    elif isinstance(map_location, dict):

        def restore_location(storage, location):
            # map_location = {'cpu': 'cuda:0'}: if location is 'cpu', place the
            # object on 'cuda:0'; otherwise keep the original location
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    elif isinstance(map_location, _string_classes):  # map_location = 'cuda:0': place on 'cuda:0' regardless of location

        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
    elif isinstance(map_location, torch.device):

        def restore_location(storage, location):  # map_location = torch.device('cpu'): place on 'cpu' regardless of location
            return default_restore_location(storage, str(map_location))
    else:

        def restore_location(storage, location):
            # map_location may be a callable that replaces default_restore_location;
            # e.g. map_location = lambda storage, location: storage.cuda(1)
            # places everything on 'cuda:1' regardless of location
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result

    return restore_location
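The priority-sorted registry with first-non-None dispatch can be sketched in plain Python. Names below mirror the source, but the "storages" and "devices" are just strings and the taggers are omitted, so this is a simplified illustration, not torch's implementation:

```python
_package_registry = []

def register_package(priority, deserializer):
    _package_registry.append((priority, deserializer))
    _package_registry.sort(key=lambda e: e[0])  # lowest priority number tried first

def _cpu_deserialize(obj, location):
    if location == 'cpu':
        return ('on-cpu', obj)            # stand-in for returning a CPU storage

def _cuda_deserialize(obj, location):
    if location.startswith('cuda'):
        return ('on-' + location, obj)    # stand-in for returning a CUDA storage

register_package(10, _cpu_deserialize)
register_package(20, _cuda_deserialize)

def default_restore_location(storage, location):
    for _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:            # None means "not my location tag"
            return result
    raise RuntimeError("don't know how to restore data location " + location)

print(default_restore_location('blob', 'cuda:0'))  # ('on-cuda:0', 'blob')
print(default_restore_location('blob', 'cpu'))     # ('on-cpu', 'blob')
```

Returning None from a deserializer is what lets the loop fall through to the next handler; that is why _cpu_deserialize and _cuda_deserialize above may return None.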

That concludes the source analysis of torch.serialization. It consists mainly of torch.save() and torch.load(): torch.save() is implemented via _save() or _legacy_save(), and torch.load() via _load() or _legacy_load(). The map_location argument of torch.load() determines, through _get_restore_location(), whether objects are deserialized onto CPU or CUDA.

torch.hub

torch.hub provides a collection of pretrained models for easy reuse. Using https://github.com/pytorch/vision as an example, let's see how to call torchvision models through the interfaces torch.hub provides.

torch.hub exposes three main interfaces: torch.hub.list(), torch.hub.help(), and torch.hub.load(). We introduce them in turn.

torch.hub.list() looks for hubconf.py in the given GitHub repo (this file imports every model the repo provides) and returns a list of the exported model names. The hubconf.py of https://github.com/pytorch/vision looks like this:

# Optional list of dependencies required by the package
dependencies = ['torch']

# classification
from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \
    resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.mobilenetv2 import mobilenet_v2
from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
    mnasnet1_3

# segmentation
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
    deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, lraspp_mobilenet_v3_large

Now let's analyze the torch.hub.list() code:

def list(github,  # repo name, e.g. 'pytorch/vision'; note there is no
                  # 'https://github.com/' prefix -- it is hard-coded internally
         force_reload=False):  # whether to re-download the repo

    # download the repo and return its local path; torch.hub.get_dir() returns
    # the download root, which can be set in advance with torch.hub.set_dir(path)
    repo_dir = _get_cache_or_reload(github, force_reload, True)

    sys.path.insert(0, repo_dir)  # put the local repo path first on the import path

    # MODULE_HUBCONF = 'hubconf.py': find 'hubconf.py' in the local repo and
    # import it to obtain every module it provides
    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)

    sys.path.remove(repo_dir)  # remove the local repo path from the import path

    # We take functions starting with '_' as internal helper functions, so
    # names beginning with '_' are filtered out
    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]

    return entrypoints  # list of strings: the model names the repo provides

# An example:
print(torch.hub.list('pytorch/vision', True))

# print info:
'''
['alexnet', 'deeplabv3_mobilenet_v3_large', 'deeplabv3_resnet101', 'deeplabv3_resnet50', 'densenet121', 'densenet161',
'densenet169', 'densenet201', 'fcn_resnet101', 'fcn_resnet50', 'googlenet', 'inception_v3', 'lraspp_mobilenet_v3_large',
'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small',
'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x8d', 'resnext50_32x4d', 'shufflenet_v2_x0_5',
'shufflenet_v2_x1_0', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19', 'vgg19_bn', 'wide_resnet101_2', 'wide_resnet50_2']
'''
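The entrypoint filter at the end of list() is plain introspection: keep public callables, drop anything whose name starts with '_' and anything that isn't callable. A small standalone demonstration on a synthetic module (all names below are made up):

```python
import types

# a synthetic stand-in for an imported hubconf.py module
hub_module = types.ModuleType('hubconf')
hub_module.resnet18 = lambda pretrained=False: 'a model'  # public callable -> kept
hub_module._internal_helper = lambda: None                # leading '_' -> dropped
hub_module.dependencies = ['torch']                       # not callable -> dropped

entrypoints = [f for f in dir(hub_module)
               if callable(getattr(hub_module, f)) and not f.startswith('_')]
print(entrypoints)  # ['resnet18']
```

This is also why the `dependencies` list in hubconf.py never shows up in torch.hub.list() output: it is data, not a callable.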

torch.hub.help() returns the documentation of a given module in a given repo:

def help(github,
         model,  # module name, e.g. 'resnet50'
         force_reload=False):

    repo_dir = _get_cache_or_reload(github, force_reload, True)

    sys.path.insert(0, repo_dir)

    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)

    sys.path.remove(repo_dir)

    entry = _load_entry_from_hubconf(hub_module, model)  # find the given module (model) among all modules (hub_module)

    return entry.__doc__  # return the docstring of `model`

# An example:
print(torch.hub.help('pytorch/vision', 'resnet18', True))

# print info:
'''
ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" `_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
'''
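So help() boils down to looking the entrypoint up in hubconf's namespace and returning its __doc__. A toy sketch of that lookup; the function, namespace, and error text are illustrative stand-ins, not torch's exact implementation:

```python
def resnet18(pretrained=False, progress=True):
    """ResNet-18 (toy stand-in).

    Args:
        pretrained (bool): If True, returns a pre-trained model
    """

hubconf_namespace = {'resnet18': resnet18}  # plays the role of hub_module

def hub_help(name):
    entry = hubconf_namespace.get(name)     # _load_entry_from_hubconf analogue
    if entry is None or not callable(entry):
        raise RuntimeError(f'Cannot find callable {name} in hubconf')
    return entry.__doc__                    # what torch.hub.help() returns

doc = hub_help('resnet18')
print(doc.splitlines()[0])  # ResNet-18 (toy stand-in).
```

Because the return value is just the docstring, repos that want useful torch.hub.help() output only need to write good docstrings on their entrypoints.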

torch.hub.load() returns the instantiated module:

def load(repo_or_dir,  # a local path, or a repo name on GitHub
         model,
         *args,  # used to instantiate the module
         **kwargs):  # used to instantiate the module

    source = kwargs.pop('source', 'github').lower()  # look for the repo on github by default
    force_reload = kwargs.pop('force_reload', False)
    verbose = kwargs.pop('verbose', True)  # if True, print some logs

    if source not in ('github', 'local'):  # the repo comes either from github or from a local path
        raise ValueError(
            f'Unknown source: "{source}". Allowed values: "github" | "local".')

    if source == 'github':
        repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose)

    model = _load_local(repo_or_dir, model, *args, **kwargs)
    return model


def _load_local(hubconf_dir, model, *args, **kwargs):
    sys.path.insert(0, hubconf_dir)

    hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
    hub_module = import_module(MODULE_HUBCONF, hubconf_path)

    entry = _load_entry_from_hubconf(hub_module, model)  # find the given module
    model = entry(*args, **kwargs)  # instantiate it

    sys.path.remove(hubconf_dir)

    return model


# An example:
resnet18 = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)  # load pretrained weights

The above used pytorch/vision to illustrate torch.hub. In fact, any GitHub repo that contains a hubconf.py file can be used through the torch.hub interfaces.
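A minimal hubconf.py for your own repo might look like the sketch below. Everything here is hypothetical (the repo, the entrypoint name, and its return value); a real entrypoint would typically build and return an nn.Module, often loading pretrained weights:

```python
# hubconf.py -- place this file at the root of your repo

dependencies = []  # pip package names torch.hub should check for, e.g. ['torch']

def my_model(scale=1.0):
    """A toy entrypoint.

    Args:
        scale (float): hypothetical model width multiplier
    """
    # a real repo would construct and return a (pretrained) nn.Module here
    return {'name': 'my_model', 'scale': scale}
```

With this file in place, torch.hub.list('user/repo') would report ['my_model'], and torch.hub.load('user/repo', 'my_model', scale=2.0) would call my_model(scale=2.0) and return its result.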

Original article: https://zhuanlan.zhihu.com/p/364239544
