在OneFlow實現(xiàn)數(shù)據(jù)類型自動提升
問題引入
我們先簡單看下在pytorch下的這幾段代碼,讀者可以猜下最后輸出的類型是什么:
x_tensor?=?torch.ones((3,?),?dtype=torch.int8)
y1_tensor?=?torch.tensor(1,?dtype=torch.float64)
out1?=?torch.mul(x_tensor,?y1_tensor)
y2_tensor?=?torch.tensor(1,?dtype=torch.int64)
out2?=?torch.mul(x_tensor,?y2_tensor)
out3?=?torch.mul(x_tensor,?1.0)
out4?=?torch.mul(x_tensor,?2^63-1(the?max?value?of?int64))
接下來揭曉答案:
out1.dtype:?torch.float64
out2.dtype:?torch.int8
out3.dtype:?torch.float32
out4.dtype:?torch.int8
可以觀察到同樣是multiply運算,有些結(jié)果的數(shù)據(jù)類型被提升到更高的一級,有些并沒有被提升,還維持著int8類型。這其實是一種類型提升系統(tǒng),系統(tǒng)內(nèi)會自定義一些類型提升的規(guī)則,根據(jù)輸入的數(shù)據(jù)類型來推導(dǎo)最終結(jié)果的數(shù)據(jù)類型。
Python Array API標準
參考鏈接:https://data-apis.org/array-api/latest/API_specification/type_promotion.html
在這里我們可以了解到Python Array的類型提升規(guī)則

從上圖可以看到:
不同數(shù)據(jù)類型的提升遵循這個連接的規(guī)則 虛線表示python標量在溢出的時候未定義 bool int float之間沒有連線,表示這種混合類型的提升未定義
關(guān)于第一條,我們可以看int8和uint8,兩者最終指向了int16,表示兩者運算后最終類型提升到了int16
而根據(jù)這一個規(guī)則,我們可以列出一個類型提升表格(這個表格很重要,后續(xù)看Pytorch源碼也會用到)
以unsigned int系列和signed int系列為例,列出的表格為:
更多類型提升規(guī)則表格可參考前面提到的鏈接
橫坐標和縱坐標分別代表輸入的數(shù)據(jù)類型,表格的值代表類型提升后的數(shù)據(jù)類型,其中:
i1 : 8-bit signed integer (i.e., int8 ) i2 : 16-bit signed integer (i.e., int16 ) i4 : 32-bit signed integer (i.e., int32 ) i8 : 64-bit signed integer (i.e., int64 )
同理于unsigned int
Python Array 和 Scalar 的類型提升
上述這些都是array與array之間運算的類型提升規(guī)則,而array與scalar(就是單獨一個int,float數(shù)值)的類型提升規(guī)則則不一樣。
如果兩者同屬于一個數(shù)據(jù)類型系列(比如都是int系列,包含int8, int32, int64),則最終數(shù)據(jù)類型遵循數(shù)組的數(shù)據(jù)類型 如果兩者同不屬于一個數(shù)據(jù)類型系列(比如一個是int32,一個是float),則進行類型提升
我們可以看下簡單的兩個例子:
x_tensor?=?torch.ones((3,?),?dtype=torch.int16)
out1?=?x_tensor?+?2?#?out.dtype?=?torch.int16
out2?=?x_tensor?+?2.0?#?out.dtype?=?torch.float32
需要注意的是,Array與Scalar的行為會和Array與0d Array的行為保持一致。
我們可以再測試前面兩個例子,不同之處是我們將scalar改成一個0d Array
x_tensor?=?torch.ones((3,?),?dtype=torch.int16)
y1_tensor?=?torch.tensor(2)
y2_tensor?=?torch.tensor(2.0)
out1?=?x_tensor?+?y1_tensor?#?out.dtype?=?torch.int16
out2?=?x_tensor?+?y2_tensor?#?out.dtype?=?torch.float32
關(guān)于與Scalar運算的行為,Pytorch是和Python Array API標準一致的,但是Numpy則不同,他會根據(jù)scalar的數(shù)據(jù)范圍做一個合理的類型提升:
import?numpy?as?np
x?=?np.ones((3,?3),?dtype=np.int32)
out?=?x?+?(2**31-1)?#?dtype:?np.int32
out?=?x?+?(2**31)?#?dtype:?np.int64
我個人更傾向于在類型提升中,Scalar是單獨一種行為,而Scalar Tensor和Tensor的行為一致
其他情況
除了前面提到的規(guī)則,Pytorch還存在以下兩種情況:
要求兩個輸入的數(shù)據(jù)類型完全一致,如 torch.dot
RuntimeError:?dot?:?expected?both?vectors?to?have?same?dtype,?but?found?Short?and?Float
輸入存在一個最低數(shù)據(jù)類型,比如 torch.sum,傳任意int系列數(shù)據(jù)類型,最終輸出結(jié)果均為torch.int64。
以上就簡單介紹了Pytorch的類型提升規(guī)則,還想要更多的例子可以參考官方文檔:https://pytorch.org/docs/master/tensor_attributes.html#torch.torch.dtype
Pytorch是怎么做類型提升的?
實際運算的Kernel,輸入和輸出的數(shù)據(jù)類型都是相同的模板參數(shù),不存在特化一個輸入為int32,輸出為float32或其他類型的函數(shù)。
因此Pytorch內(nèi)部會先推斷出一個合理的dtype,然后插入一個to這個op,將輸入tensor進行類型提升,再進入到Kernel進行實際的運算。下面我們會根據(jù)Pytorch的源碼進行講解:
涉及到的代碼:https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/TensorIterator.cpp
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TypeProperties.cpp
ScalarType.h
在這個頭文件里定義了相關(guān)的數(shù)據(jù)類型,并且定義了一個類型提升的二維矩陣,這樣我們就可以輸入兩個數(shù)據(jù)類型,根據(jù)索引拿到提升后的數(shù)據(jù)類型。

Activation.cpp
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp#L24
我們以其中一個激活函數(shù)threshold為例子
TORCH_META_FUNC(threshold)(const?Tensor&?self,?const?Scalar&?threshold,?const?Scalar&?value)?{
??const?Tensor&?result?=?maybe_get_output();
??build(TensorIteratorConfig()
????...
????.promote_inputs_to_common_dtype(true)
}
這里調(diào)用了一個build函數(shù),函數(shù)接受一個TensorIteratorConfig,這個Config類是用于配制各種屬性,可以看到這里調(diào)用promote_inputs_to_common_dtype并設(shè)為true。
TensorIterator.cpp
build函數(shù)定義在:
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/TensorIterator.cpp#L1321
在1340行,build函數(shù)內(nèi)部調(diào)用了compute_type函數(shù)
...
compute_types(config);
...
而該函數(shù)在260行開始,進行一系列類型推導(dǎo)
其中TensorIterator是一個容器類(Numpy里也有一個類似的容器NpyIter),用于存儲輸出,輸入tensor,里面用了多個for循環(huán)來推導(dǎo)得到一個common_dtype。
并在最后進行條件判斷:promote_inputs_to_common_dtype_為true,當(dāng)前Tensor不是輸出Tensor,且輸入的dtype不等于推導(dǎo)得到的common_dtype,則做一個類型提升:
??????//?Promotes?inputs?by?creating?temporaries?of?the?correct?dtype
??????if?(config.promote_inputs_to_common_dtype_?&&?!op.is_output?&&?op.current_dtype?!=?common_dtype_)?{
????????op.original_tensor?=?op.tensor;
????????op.tensor?=?c10::MaybeOwned::owned(op.tensor->to(common_dtype_));
????????op.current_dtype?=?common_dtype_;
????????op.target_dtype?=?common_dtype_;
??????}
OneFlow的做法
相關(guān)PR:https://github.com/Oneflow-Inc/oneflow/pull/6380
OneFlow則將類型提升的邏輯放在c++中functional前端部分,類似的我們設(shè)計了一個TensorProcessor類,接口設(shè)計如下:
class?TensorProcessor?final?{
?public:
??TensorProcessor()
??????:?common_dtype_(DType::InvalidDataType()),?promote_inputs_to_common_dtype_(false){};
??TensorProcessor&?AddInputs(const?TensorTuple&?init_list);
??TensorProcessor&?AddInputs(const?TensorTuple&?init_list,?Symbol?tensor_lowest_dtype) ;
??Maybe<void>?Apply();
??TensorProcessor&?PromoteInputsToCommonDtype(bool?is_promote);
??Maybe?GetInputs()? {?return?tensor_tuple_;?};
?private:
??TensorTuple?tensor_tuple_;
??Symbol?common_dtype_;
??std::vector>?inputs_lowest_dtype_vec_;
??bool?promote_inputs_to_common_dtype_;
};
以二元操作Functor基類為例,在實際調(diào)用的時候,我們可以這樣:
class?BinaryFunctor{
?public:
??Maybe?operator()(const?std::shared_ptr&?x,
???????????????????????????const?std::shared_ptr&?y) ?const? {
????TensorProcessor?tensor_processor;
????JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x,?y}).Apply());
????TensorTuple?input_tuple?=?JUST(tensor_processor.GetInputs());
????return?OpInterpUtil::Dispatch(*op_,?input_tuple);
??...
??}
??...
};?
PromoteInputsToCommonDtype 用于設(shè)置相關(guān)屬性 AddInputs函數(shù)將需要參與類型提升的Tensor添加到容器中 Apply函數(shù)執(zhí)行實際的類型提升等邏輯
tensor_processor.cpp還有其他幾個函數(shù),這里簡單介紹下功能:
CheckHasDifferentInputDType 遍歷輸入Tensor,檢查輸入Tensor是否有不同的dtype ComputeCommonDType 根據(jù)輸入dtype推導(dǎo)一個合理的提升過的dtype CastToSameType 給輸入Tensor插入一個Cast操作
Maybe<void>?CastToSameType(TensorTuple&?tensor_tuple,?const?Symbol&?common_dtype) ?{
??for?(auto&?tensor_ptr?:?tensor_tuple)?{
????if?(tensor_ptr->dtype()?!=?common_dtype)?{
??????tensor_ptr?=?JUST(functional::Cast(tensor_ptr,?common_dtype));
????}
??}
??return?Maybe<void>::Ok();
}
Apply函數(shù)邏輯如下:
Maybe<void>?TensorProcessor::Apply()?{
??if?(promote_inputs_to_common_dtype_)?{
????bool?has_different_input_dtype?=?CheckHasDifferentInputDType(tensor_tuple_);
????if?(has_different_input_dtype)?{
??????common_dtype_?=?ComputeCommonDType(tensor_tuple_);
??????JUST(CastToSameType(tensor_tuple_,?common_dtype_));
????}
??}?else?{
????for?(int?i?=?0;?i???????//?Cast?all?the?inputs?to?it's?attribute?`lowest_dtype`?if?the?input?tensor?dtype?is?lower
??????//?than?attribute?`lowest_dtype`.
??????Symbol?base_dtype?=?inputs_lowest_dtype_vec_.at(i);
??????if?(base_dtype->data_type()
??????????&&?DType::priority_order[base_dtype->data_type()]
?????????????????>?DType::priority_order[tensor_tuple_.at(i)->dtype()->data_type()])?{
????????tensor_tuple_.at(i)?=?JUST(one::functional::Cast(tensor_tuple_.at(i),?base_dtype));
??????}
????}
??}
??return?Maybe<void>::Ok();
}
if內(nèi)執(zhí)行的是類型提升,而else內(nèi)邏輯則是對應(yīng)前面提到的其他情況中的第二條,將Tensor類型提升到設(shè)定好的一個最低數(shù)據(jù)類型。還是sum算子,我們設(shè)定最低數(shù)據(jù)類型為int64是這么做的:
class?ReduceSumFunctor{
public:?
??Maybe?operator()(const?std::shared_ptr&?x,?const?std::vector<int32_t>&?axis,
???????????????????????????const?bool&?keepdims) ?const? {
????...
????TensorProcessor?tensor_processor;
????JUST(tensor_processor.AddInputs({x},?/*lowest_dtype=*/DType::Int64()).Apply());
????TensorTuple?input_tuple?=?JUST(tensor_processor.GetInputs());
??}
??...
};?
總結(jié)
類型提升是一個我們不經(jīng)意間會使用的一個操作,如果沒有正確處理輸出的數(shù)據(jù)類型,則可能導(dǎo)致結(jié)果溢出,出現(xiàn)錯誤的結(jié)果??此坪芎唵危珜嶋H調(diào)研+推敲細節(jié)也搞了兩三周,最后感謝同事在我完成這個功能的期間提供的許多幫助!
點擊下方原文鏈接直達OneFlow倉庫,歡迎關(guān)注。
