9c867eae - nnc: fix Store if value is fp32 while buf is bf16 (#86788)

nnc: fix Store if value is fp32 while buf is bf16 (#86788)

Fixes https://github.com/pytorch/pytorch/issues/86533.

For the below graph:

```bash
[DUMP kernel.cpp:1690] TensorExprKernel graph:
[DUMP kernel.cpp:1690] graph(%x.1 : BFloat16(10, strides=[1], requires_grad=0, device=cpu)):
[DUMP kernel.cpp:1690]   %1 : int = prim::Constant[value=0]()
[DUMP kernel.cpp:1690]   %2 : BFloat16(10, strides=[1], requires_grad=0, device=cpu) = aten::pow(%x.1, %1) # test/test_tensorexpr.py:1330:29
[DUMP kernel.cpp:1690]   %3 : BFloat16(10, strides=[1], requires_grad=0, device=cpu) = aten::sin(%2) # test/test_tensorexpr.py:1330:19
[DUMP kernel.cpp:1690]   return (%3)
```

**Loop stmt before the fix:**

The store value `0.8414709568023682f` is float, while the scalar_type of the store buf `aten_sin` is bf16.

```bash
[DEBUG llvm_codegen.cpp:489] After HalfRewriter {
[DEBUG llvm_codegen.cpp:489]   aten_sin[Ramp(0ll, 1ll, 8)] = Broadcast(0.8414709568023682f, 8);
[DEBUG llvm_codegen.cpp:489]   for (int64_t i_1_tail_tail = 0ll; i_1_tail_tail < 2ll; i_1_tail_tail++) {
[DEBUG llvm_codegen.cpp:489]     aten_sin[i_1_tail_tail + 8ll] = 0.8414709568023682f;
[DEBUG llvm_codegen.cpp:489]   }
[DEBUG llvm_codegen.cpp:489] }
```

**Loop stmt after the fix:**

```bash
[DEBUG llvm_codegen.cpp:489] After HalfRewriter {
[DEBUG llvm_codegen.cpp:489]   aten_sin[Ramp(0ll, 1ll, 8)] = bfloat16(Broadcast(0.8414709568023682f, 8));
[DEBUG llvm_codegen.cpp:489]   for (int64_t i_1_tail_tail = 0ll; i_1_tail_tail < 2ll; i_1_tail_tail++) {
[DEBUG llvm_codegen.cpp:489]     aten_sin[i_1_tail_tail + 8ll] = bfloat16(0.8414709568023682f);
[DEBUG llvm_codegen.cpp:489]   }
[DEBUG llvm_codegen.cpp:489] }
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86788
Approved by: https://github.com/EikanWang, https://github.com/kit1980
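For context on the constant in the logs: `pow(x, 0)` is 1 for every element, so the kernel constant-folds to `sin(1)`, whose fp32 value is `0.8414709568023682f`. The sketch below (plain Python, not part of the commit) emulates the fp32→bf16 store cast that the fix inserts, assuming the standard bf16 layout of keeping the top 16 bits of the fp32 encoding with round-to-nearest-even:

```python
import math
import struct

def to_fp32(x: float) -> float:
    # Round a Python float (fp64) to fp32 precision via a bit-level round trip.
    return struct.unpack(">f", struct.pack(">f", x))[0]

def to_bfloat16(x: float) -> float:
    # bf16 keeps the top 16 bits of the fp32 encoding; round-to-nearest-even
    # on the 16 dropped mantissa bits (same layout as torch.bfloat16).
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)   # tie goes to the even kept LSB
    bits = (bits + bias) & 0xFFFF0000    # drop the low 16 mantissa bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]

fp32_const = to_fp32(math.sin(1.0))
print(fp32_const)               # 0.8414709568023682 — the constant in the log
print(to_bfloat16(fp32_const))  # 0.83984375 — what the bf16 buf actually holds
```

Storing the raw fp32 bit pattern into a bf16 buffer without this narrowing cast writes a value the bf16 type cannot represent, which is why the `bfloat16(...)` wrapper in the fixed loop stmt matters.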