pytorch
b049493e - Legalize BFloat16 in NNC. (#83988)

Commit
2 years ago
Legalize BFloat16 in NNC. (#83988) Regarding BF16 support in NNC, we always convert the BF16 to FP32 and then compute with FP32. After the FP32 computation, we convert the FP32 result to BF16. This logic has been supported in [half_support.h](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/half_support.h). But the B16/FP32 conversion has not been supported by LLVM. Currently, LLVM only supports the BF16 in its front end but still cannot generate the assembly code. So we implement this feature in [llvm_codegen](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/llvm_codegen.cpp) like aten implementation. In this PR, it contains three points - Take BF16 as uint16, Convert BF16 to FP32 and Convert FP32 to BF16. - Take BF16 as uint16 - [PR Change](https://github.com/pytorch/pytorch/pull/83988/files#diff-9d09ca2fce1c689ab43cd71795ab9b8b63447de950cf98ae8a18114e18d87e79R544-R546) We cannot naively convert map the BF16 to LLVM BF16 as the LLVM backend still does not support this data type as I mentioned. Meanwhile, the BF16 in PyTorch is a [structure](https://github.com/pytorch/pytorch/blob/master/c10/util/BFloat16.h#L73). Its realdata is uint16. So we also bitcast the BF16 tensor to uint16 - BF16 to FP32 conversion - [PR Change](https://github.com/pytorch/pytorch/pull/83988/files#diff-9d09ca2fce1c689ab43cd71795ab9b8b63447de950cf98ae8a18114e18d87e79R1057-R1061) we just need to shift the BF16 value left by 16bits and then bit cast the shifted value to FP32 - FP32 to BF16 conversion - [PR Change](https://github.com/pytorch/pytorch/pull/83988/files#diff-9d09ca2fce1c689ab43cd71795ab9b8b63447de950cf98ae8a18114e18d87e79R1066-R1084) The conversion is kind of complex. Because we use RNE to implement it. The RNR to convert the FP32 to BF16 is as follows. ```C++ uint16_t round_to_nearest_even(float src) { if (std::isnan(src)) { return UINT16_C(0x7FC0); } else { union { uint32_t U32; float F32; }; F32 = src; uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); return static_cast<uint16_t>((U32 + rounding_bias) >> 16); } } ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/83988 Approved by: https://github.com/ZolotukhinM, https://github.com/frank-wei
Author
Committer
Parents
Loading