adcefcb3 - insert to dtype for fused mem copy scheduler node (#101042)

Fixes https://github.com/pytorch/pytorch/issues/100830.

For an in-place node, a `copy_` is generated, and that `copy_` is realized as a scheduler buffer since it is a mutation. This scheduler buffer is a plain memory copy, but after it is fused with the previous buffer, the fused node is no longer a memory-copy-only buffer. This PR solves the issue by removing `load_bf16_as_fp32` and `store_bf16_from_fp32` and instead enabling fp32/bf16 vectorized conversion in `to_dtype`, so we always store bf16.

Repro:

```python
import torch
import torch.nn as nn

torch.manual_seed(420)
from torch._inductor import config

x = torch.randn(1, 18, dtype=torch.bfloat16)

class ExampleModel(nn.Module):
    def __init__(self):
        super(ExampleModel, self).__init__()
        self.relu = nn.ReLU(inplace=True)  # nn.ReLU(inplace=False)

    def forward(self, input1):
        out = self.relu(input1)  # input1.copy_(out)
        return out

func = ExampleModel()

with torch.no_grad():
    func.train(False)

    res1 = func(x)  # without jit
    print(res1)

    jit_func = torch.compile(func)
    res2 = jit_func(x)
    print(res2)
```

Generated code without this PR (the `tmp3` store is wrong: `tmp3` is `float` while `out_ptr1` is `bf16`):

```
auto tmp0 = load_bf16_as_float(out_ptr1 + static_cast<long>(i0));
auto tmp1 = (tmp0);
auto tmp2 = at::vec::clamp_min(tmp1, decltype(tmp1)(0));
auto tmp3 = (tmp2);
store_float_as_bf16(out_ptr0 + static_cast<long>(i0), tmp3);
tmp3.store(out_ptr1 + static_cast<long>(i0), 16);
```

Generated code with this PR:

```
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(out_ptr1 + static_cast<long>(i0), 16);
auto tmp1 = cvt_bf16_to_fp32(tmp0);
auto tmp2 = at::vec::clamp_min(tmp1, decltype(tmp1)(0));
auto tmp3 = cvt_fp32_to_bf16(tmp2);
tmp3.store(out_ptr0 + static_cast<long>(i0), 16);
tmp3.store(out_ptr1 + static_cast<long>(i0), 16);
```

This PR also fixes the data type propagation for `masked_subblock`. Previously the `masked_subblock`'s dtype was propagated from its inputs, which is wrong:

```
opcode       name              target            args            kwargs
-----------  ----------------  ----------------  --------------  --------
call_module  masked_subblock1  masked_subblock1  (and__2, -inf)
```

Now we propagate it from the subblock with the same name:

```
# graph for body.subblocks['masked_subblock1']
opcode       name       target     args                        kwargs
-----------  ---------  ---------  --------------------------  --------
placeholder  ops        ops        ()                          {}
call_module  get_index  get_index  ('index2',)                 {}
call_method  load       load       (ops, 'arg0_1', get_index)  {}
call_method  to_dtype   to_dtype   (ops, load, torch.float32)  {}
output       output     output     (to_dtype,)                 {}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101042
Approved by: https://github.com/jgong5, https://github.com/jansel
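
For reference, here is a minimal scalar C++ sketch of what the bf16/fp32 conversions in the generated code above boil down to. This is illustrative only and not part of the PR: the actual `cvt_bf16_to_fp32`/`cvt_fp32_to_bf16` helpers operate on `at::vec::Vectorized` registers, and the real fp32-to-bf16 path rounds rather than simply truncating.

```cpp
// Illustrative scalar bf16 <-> fp32 conversion (not PyTorch's implementation).
// bf16 keeps the sign bit, the 8 exponent bits, and the top 7 mantissa bits of
// fp32, i.e. the upper 16 bits of the fp32 bit pattern.
#include <cstdint>
#include <cstdio>
#include <cstring>

static float bf16_to_fp32(uint16_t x) {
    uint32_t bits = static_cast<uint32_t>(x) << 16;  // widen: low mantissa bits become 0
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

static uint16_t fp32_to_bf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);  // truncate; real kernels round to nearest even
}

int main() {
    float x = 0.15625f;  // exactly representable in bf16, so the round trip is lossless
    uint16_t b = fp32_to_bf16(x);
    std::printf("%f -> 0x%04x -> %f\n", x, b, bf16_to_fp32(b));
    return 0;
}
```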