<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          在OneFlow實現(xiàn)數(shù)據(jù)類型自動提升

          共 6287字,需瀏覽 13分鐘

           ·

          2021-10-09 20:59

          問題引入

          我們先簡單看下在pytorch下的這幾段代碼,讀者可以猜下最后輸出的類型是什么:

          x_tensor?=?torch.ones((3,?),?dtype=torch.int8)
          y1_tensor?=?torch.tensor(1,?dtype=torch.float64)
          out1?=?torch.mul(x_tensor,?y1_tensor)

          y2_tensor?=?torch.tensor(1,?dtype=torch.int64)
          out2?=?torch.mul(x_tensor,?y2_tensor)

          out3?=?torch.mul(x_tensor,?1.0)

          out4?=?torch.mul(x_tensor,?2^63-1(the?max?value?of?int64))

          接下來揭曉答案:

          out1.dtype:?torch.float64
          out2.dtype:?torch.int8
          out3.dtype:?torch.float32
          out4.dtype:?torch.int8

          可以觀察到同樣是multiply運算,有些結(jié)果的數(shù)據(jù)類型被提升到更高的一級,有些并沒有被提升,還維持著int8類型。這其實是一種類型提升系統(tǒng),系統(tǒng)內(nèi)會自定義一些類型提升的規(guī)則,根據(jù)輸入的數(shù)據(jù)類型來推導(dǎo)最終結(jié)果的數(shù)據(jù)類型。

          Python Array API標準

          參考鏈接:https://data-apis.org/array-api/latest/API_specification/type_promotion.html

          在這里我們可以了解到Python Array的類型提升規(guī)則

          類型提升

          從上圖可以看到:

          • 不同數(shù)據(jù)類型的提升遵循這個連接的規(guī)則
          • 虛線表示python標量在溢出的時候未定義
          • bool int float之間沒有連線,表示這種混合類型的提升未定義

          關(guān)于第一條,我們可以看int8uint8,兩者最終指向了int16,表示兩者運算后最終類型提升到了int16

          而根據(jù)這一個規(guī)則,我們可以列出一個類型提升表格(這個表格很重要,后續(xù)看Pytorch源碼也會用到)

          unsigned int系列和signed int系列為例,列出的表格為:

          更多類型提升規(guī)則表格可參考前面提到的鏈接

          橫坐標和縱坐標分別代表輸入的數(shù)據(jù)類型,表格的值代表類型提升后的數(shù)據(jù)類型,其中:

          • i1 : 8-bit signed integer (i.e., int8 )
          • i2 : 16-bit signed integer (i.e., int16 )
          • i4 : 32-bit signed integer (i.e., int32 )
          • i8 : 64-bit signed integer (i.e., int64 )

          同理于unsigned int

          Python Array 和 Scalar 的類型提升

          上述這些都是array與array之間運算的類型提升規(guī)則,而array與scalar(就是單獨一個int,float數(shù)值)的類型提升規(guī)則則不一樣。

          • 如果兩者同屬于一個數(shù)據(jù)類型系列(比如都是int系列,包含int8, int32, int64),則最終數(shù)據(jù)類型遵循數(shù)組的數(shù)據(jù)類型
          • 如果兩者同不屬于一個數(shù)據(jù)類型系列(比如一個是int32,一個是float),則進行類型提升

          我們可以看下簡單的兩個例子:

          x_tensor?=?torch.ones((3,?),?dtype=torch.int16)
          out1?=?x_tensor?+?2?#?out.dtype?=?torch.int16
          out2?=?x_tensor?+?2.0?#?out.dtype?=?torch.float32

          需要注意的是,Array與Scalar的行為會和Array與0d Array的行為保持一致。

          我們可以再測試前面兩個例子,不同之處是我們將scalar改成一個0d Array

          x_tensor?=?torch.ones((3,?),?dtype=torch.int16)
          y1_tensor?=?torch.tensor(2)
          y2_tensor?=?torch.tensor(2.0)

          out1?=?x_tensor?+?y1_tensor?#?out.dtype?=?torch.int16
          out2?=?x_tensor?+?y2_tensor?#?out.dtype?=?torch.float32

          關(guān)于與Scalar運算的行為,Pytorch是和Python Array API標準一致的,但是Numpy則不同,他會根據(jù)scalar的數(shù)據(jù)范圍做一個合理的類型提升

          import?numpy?as?np

          x?=?np.ones((3,?3),?dtype=np.int32)
          out?=?x?+?(2**31-1)?#?dtype:?np.int32
          out?=?x?+?(2**31)?#?dtype:?np.int64

          我個人更傾向于在類型提升中,Scalar是單獨一種行為,而Scalar Tensor和Tensor的行為一致

          其他情況

          除了前面提到的規(guī)則,Pytorch還存在以下兩種情況:

          1. 要求兩個輸入的數(shù)據(jù)類型完全一致,如torch.dot
          RuntimeError:?dot?:?expected?both?vectors?to?have?same?dtype,?but?found?Short?and?Float
          1. 輸入存在一個最低數(shù)據(jù)類型,比如torch.sum,傳任意int系列數(shù)據(jù)類型,最終輸出結(jié)果均為torch.int64。

          以上就簡單介紹了Pytorch的類型提升規(guī)則,還想要更多的例子可以參考官方文檔:https://pytorch.org/docs/master/tensor_attributes.html#torch.torch.dtype

          Pytorch是怎么做類型提升的?

          實際運算的Kernel,輸入和輸出的數(shù)據(jù)類型都是相同的模板參數(shù),不存在特化一個輸入為int32,輸出為float32或其他類型的函數(shù)。

          因此Pytorch內(nèi)部會先推斷出一個合理的dtype,然后插入一個to這個op,將輸入tensor進行類型提升,再進入到Kernel進行實際的運算。下面我們會根據(jù)Pytorch的源碼進行講解:

          涉及到的代碼:https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h

          https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp

          https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/TensorIterator.cpp

          https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TypeProperties.cpp

          ScalarType.h

          在這個頭文件里定義了相關(guān)的數(shù)據(jù)類型,并且定義了一個類型提升的二維矩陣,這樣我們就可以輸入兩個數(shù)據(jù)類型,根據(jù)索引拿到提升后的數(shù)據(jù)類型。

          類型提升矩陣

          Activation.cpp

          https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp#L24 我們以其中一個激活函數(shù)threshold為例子

          TORCH_META_FUNC(threshold)(const?Tensor&?self,?const?Scalar&?threshold,?const?Scalar&?value)?{
          ??const?Tensor&?result?=?maybe_get_output();
          ??build(TensorIteratorConfig()
          ????...
          ????.promote_inputs_to_common_dtype(true)
          }

          這里調(diào)用了一個build函數(shù),函數(shù)接受一個TensorIteratorConfig,這個Config類是用于配制各種屬性,可以看到這里調(diào)用promote_inputs_to_common_dtype并設(shè)為true。

          TensorIterator.cpp

          build函數(shù)定義在:

          https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/TensorIterator.cpp#L1321

          在1340行,build函數(shù)內(nèi)部調(diào)用了compute_type函數(shù)

          ...
          compute_types(config);
          ...

          而該函數(shù)在260行開始,進行一系列類型推導(dǎo)

          其中TensorIterator是一個容器類(Numpy里也有一個類似的容器NpyIter),用于存儲輸出,輸入tensor,里面用了多個for循環(huán)來推導(dǎo)得到一個common_dtype。

          并在最后進行條件判斷:promote_inputs_to_common_dtype_為true,當(dāng)前Tensor不是輸出Tensor,且輸入的dtype不等于推導(dǎo)得到的common_dtype,則做一個類型提升:

          ??????//?Promotes?inputs?by?creating?temporaries?of?the?correct?dtype
          ??????if?(config.promote_inputs_to_common_dtype_?&&?!op.is_output?&&?op.current_dtype?!=?common_dtype_)?{
          ????????op.original_tensor?=?op.tensor;
          ????????op.tensor?=?c10::MaybeOwned::owned(op.tensor->to(common_dtype_));
          ????????op.current_dtype?=?common_dtype_;
          ????????op.target_dtype?=?common_dtype_;
          ??????}

          OneFlow的做法

          相關(guān)PR:https://github.com/Oneflow-Inc/oneflow/pull/6380

          OneFlow則將類型提升的邏輯放在c++中functional前端部分,類似的我們設(shè)計了一個TensorProcessor類,接口設(shè)計如下:

          class?TensorProcessor?final?{
          ?public:
          ??TensorProcessor()
          ??????:?common_dtype_(DType::InvalidDataType()),?promote_inputs_to_common_dtype_(false){};
          ??TensorProcessor&?AddInputs(const?TensorTuple&?init_list);
          ??TensorProcessor&?AddInputs(const?TensorTuple&?init_list,?Symbol?tensor_lowest_dtype);

          ??Maybe<void>?Apply();
          ??TensorProcessor&?PromoteInputsToCommonDtype(bool?is_promote);
          ??Maybe?GetInputs()?{?return?tensor_tuple_;?};

          ?private:
          ??TensorTuple?tensor_tuple_;
          ??Symbol?common_dtype_;
          ??std::vector>?inputs_lowest_dtype_vec_;

          ??bool?promote_inputs_to_common_dtype_;
          };

          以二元操作Functor基類為例,在實際調(diào)用的時候,我們可以這樣:

          class?BinaryFunctor{
          ?public:
          ??Maybe?operator()(const?std::shared_ptr&?x,
          ???????????????????????????const?std::shared_ptr&?y)
          ?const?
          {
          ????TensorProcessor?tensor_processor;
          ????JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x,?y}).Apply());
          ????TensorTuple?input_tuple?=?JUST(tensor_processor.GetInputs());
          ????return?OpInterpUtil::Dispatch(*op_,?input_tuple);
          ??...
          ??}
          ??...
          };?
          • PromoteInputsToCommonDtype 用于設(shè)置相關(guān)屬性
          • AddInputs函數(shù)將需要參與類型提升的Tensor添加到容器中
          • Apply函數(shù)執(zhí)行實際的類型提升等邏輯

          tensor_processor.cpp還有其他幾個函數(shù),這里簡單介紹下功能:

          • CheckHasDifferentInputDType 遍歷輸入Tensor,檢查輸入Tensor是否有不同的dtype
          • ComputeCommonDType 根據(jù)輸入dtype推導(dǎo)一個合理的提升過的dtype
          • CastToSameType 給輸入Tensor插入一個Cast操作
          Maybe<void>?CastToSameType(TensorTuple&?tensor_tuple,?const?Symbol&?common_dtype)?{
          ??for?(auto&?tensor_ptr?:?tensor_tuple)?{
          ????if?(tensor_ptr->dtype()?!=?common_dtype)?{
          ??????tensor_ptr?=?JUST(functional::Cast(tensor_ptr,?common_dtype));
          ????}
          ??}
          ??return?Maybe<void>::Ok();
          }

          Apply函數(shù)邏輯如下:

          Maybe<void>?TensorProcessor::Apply()?{
          ??if?(promote_inputs_to_common_dtype_)?{
          ????bool?has_different_input_dtype?=?CheckHasDifferentInputDType(tensor_tuple_);
          ????if?(has_different_input_dtype)?{
          ??????common_dtype_?=?ComputeCommonDType(tensor_tuple_);
          ??????JUST(CastToSameType(tensor_tuple_,?common_dtype_));
          ????}
          ??}?else?{
          ????for?(int?i?=?0;?i???????//?Cast?all?the?inputs?to?it's?attribute?`lowest_dtype`?if?the?input?tensor?dtype?is?lower
          ??????//?than?attribute?`lowest_dtype`.
          ??????Symbol?base_dtype?=?inputs_lowest_dtype_vec_.at(i);
          ??????if?(base_dtype->data_type()
          ??????????&&?DType::priority_order[base_dtype->data_type()]
          ?????????????????>?DType::priority_order[tensor_tuple_.at(i)->dtype()->data_type()])?{
          ????????tensor_tuple_.at(i)?=?JUST(one::functional::Cast(tensor_tuple_.at(i),?base_dtype));
          ??????}
          ????}
          ??}
          ??return?Maybe<void>::Ok();
          }

          if內(nèi)執(zhí)行的是類型提升,而else內(nèi)邏輯則是對應(yīng)前面提到的其他情況中的第二條,將Tensor類型提升到設(shè)定好的一個最低數(shù)據(jù)類型。還是sum算子,我們設(shè)定最低數(shù)據(jù)類型為int64是這么做的:

          class?ReduceSumFunctor{
          public:?
          ??Maybe?operator()(const?std::shared_ptr&?x,?const?std::vector<int32_t>&?axis,
          ???????????????????????????const?bool&?keepdims)
          ?const?
          {
          ????...
          ????TensorProcessor?tensor_processor;
          ????JUST(tensor_processor.AddInputs({x},?/*lowest_dtype=*/DType::Int64()).Apply());
          ????TensorTuple?input_tuple?=?JUST(tensor_processor.GetInputs());
          ??}
          ??...
          };?

          總結(jié)

          類型提升是一個我們不經(jīng)意間會使用的一個操作,如果沒有正確處理輸出的數(shù)據(jù)類型,則可能導(dǎo)致結(jié)果溢出,出現(xiàn)錯誤的結(jié)果??此坪芎唵危珜嶋H調(diào)研+推敲細節(jié)也搞了兩三周,最后感謝同事在我完成這個功能的期間提供的許多幫助!


          點擊下方原文鏈接直達OneFlow倉庫,歡迎關(guān)注。

          瀏覽 52
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  亚洲无码成人片 | 狠狠躁日日躁夜夜躁A片无码视频 | 99精品成人免费毛片无码 | 韩国无码精品 | 青青草国产精品久久久久婷婷 |