Insert `to_dtype` for fused mem copy scheduler node (#101042)
Fix https://github.com/pytorch/pytorch/issues/100830.
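For an in-place node, a `copy_` is generated, and because it is a mutation, the `copy_` is `realized` as a `scheduler buffer`. On its own, this `scheduler buffer` is a pure memory copy, but once it is fused with the previous buffer, the fused node is no longer memory-copy-only. As the generated code below shows, the store that belongs to the copy then writes an fp32 value directly into a bf16 buffer.
This PR solves the issue by removing `load_bf16_as_fp32` and `store_bf16_from_fp32` (they appear as `load_bf16_as_float` and `store_float_as_bf16` in the generated code below). Instead, the fp32/bf16 vectorized conversion is enabled in `to_dtype`, so a store to a bf16 buffer always writes bf16 values.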
```python
import torch
import torch.nn as nn
from torch._inductor import config

torch.manual_seed(420)

x = torch.randn(1, 18, dtype=torch.bfloat16)


class ExampleModel(nn.Module):
    def __init__(self):
        super(ExampleModel, self).__init__()
        self.relu = nn.ReLU(inplace=True)  # nn.ReLU(inplace=False)

    def forward(self, input1):
        out = self.relu(input1)
        # input1.copy_(out)
        return out


func = ExampleModel()

with torch.no_grad():
    func.train(False)
    res1 = func(x)  # without jit
    print(res1)

    jit_func = torch.compile(func)
    res2 = jit_func(x)
    print(res2)
```
Generated code without this PR (the final `tmp3` store is wrong: `tmp3` is `float` while `out_ptr1` is `bf16`):
```
auto tmp0 = load_bf16_as_float(out_ptr1 + static_cast<long>(i0));
auto tmp1 = (tmp0);
auto tmp2 = at::vec::clamp_min(tmp1, decltype(tmp1)(0));
auto tmp3 = (tmp2);
store_float_as_bf16(out_ptr0 + static_cast<long>(i0), tmp3);
tmp3.store(out_ptr1 + static_cast<long>(i0), 16);
```
Generated code with this PR:
```
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(out_ptr1 + static_cast<long>(i0), 16);
auto tmp1 = cvt_bf16_to_fp32(tmp0);
auto tmp2 = at::vec::clamp_min(tmp1, decltype(tmp1)(0));
auto tmp3 = cvt_fp32_to_bf16(tmp2);
tmp3.store(out_ptr0 + static_cast<long>(i0), 16);
tmp3.store(out_ptr1 + static_cast<long>(i0), 16);
```
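To make the load, convert, compute, convert, store sequence above concrete, here is a minimal scalar sketch in plain Python. It is not Inductor code: the helper names `bf16_bits_to_fp32`, `fp32_to_bf16_bits`, and `relu_bf16` are hypothetical, the bf16 buffer is modelled as a list of 16-bit integers, and round-to-nearest-even is assumed for the narrowing step (NaN handling is left out).

```python
import struct
from typing import List


def bf16_bits_to_fp32(bits: int) -> float:
    """Widen a 16-bit bf16 pattern to fp32 (bf16 is the upper half of fp32)."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


def fp32_to_bf16_bits(value: float) -> int:
    """Narrow fp32 to a bf16 bit pattern, rounding to nearest even.

    Assumption for this sketch: NaN inputs are not handled specially.
    """
    u = struct.unpack("<I", struct.pack("<f", value))[0]
    rounding_bias = 0x7FFF + ((u >> 16) & 1)
    return ((u + rounding_bias) >> 16) & 0xFFFF


def relu_bf16(buffer: List[int]) -> List[int]:
    """Scalar analogue of the fused kernel: every stored value is bf16."""
    out = []
    for bits in buffer:
        x = bf16_bits_to_fp32(bits)       # like tmp1 = cvt_bf16_to_fp32(tmp0)
        y = max(x, 0.0)                   # like tmp2 = clamp_min(tmp1, 0)
        out.append(fp32_to_bf16_bits(y))  # like tmp3 = cvt_fp32_to_bf16(tmp2)
    return out


# bf16 bit patterns for 1.0, -2.0 and 0.5
result = relu_bf16([0x3F80, 0xC000, 0x3F00])
```

Because the narrowing happens once, before any store, both destinations receive the same bf16 value, which mirrors the two `tmp3.store(...)` calls in the fixed kernel.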
This PR also fixes the data type propagation for `masked_subblock`.
Previously, the dtype of a `masked_subblock` node was propagated from its inputs, which is wrong:
```
opcode       name              target            args            kwargs
-----------  ----------------  ----------------  --------------  --------
call_module  masked_subblock1  masked_subblock1  (and__2, -inf)
```
Now the dtype is propagated from the subblock with the same name:
```
# graph for body.subblocks['masked_subblock1']
opcode       name       target     args                         kwargs
-----------  ---------  ---------  ---------------------------  --------
placeholder  ops        ops        ()                           {}
call_module  get_index  get_index  ('index2',)                  {}
call_method  load       load       (ops, 'arg0_1', get_index)   {}
call_method  to_dtype   to_dtype   (ops, load, torch.float32)   {}
output       output     output     (to_dtype,)                  {}
```
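The idea of taking the dtype from the subblock rather than from the caller's inputs can be sketched with a toy graph model. This is an illustration only, not Inductor's data type propagation pass: the `Node` class and the `output_dtype` and `propagate_masked_subblock` helpers are hypothetical, dtypes are plain strings, and the `bfloat16` dtype given to `load` is an assumption made for the example.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Node:
    opcode: str
    name: str
    target: str
    args: tuple = ()
    dtype: Optional[str] = None  # toy stand-in for a torch dtype


def output_dtype(subblock: List[Node]) -> Optional[str]:
    """Return the dtype of the node that feeds the subblock's output."""
    by_name = {node.name: node for node in subblock}
    out = next(node for node in subblock if node.opcode == "output")
    return by_name[out.args[0]].dtype


def propagate_masked_subblock(node: Node, subblocks: Dict[str, List[Node]]) -> None:
    """Set the caller's dtype from the subblock that shares its target name."""
    node.dtype = output_dtype(subblocks[node.target])


# Toy version of body.subblocks['masked_subblock1'] from the graph above.
subblocks = {
    "masked_subblock1": [
        Node("placeholder", "ops", "ops"),
        Node("call_module", "get_index", "get_index", ("index2",)),
        Node("call_method", "load", "load", ("ops", "arg0_1", "get_index"),
             dtype="bfloat16"),  # assumed input dtype
        Node("call_method", "to_dtype", "to_dtype", ("ops", "load", "float32"),
             dtype="float32"),
        Node("output", "output", "output", ("to_dtype",)),
    ]
}

caller = Node("call_module", "masked_subblock1", "masked_subblock1",
              ("and__2", "-inf"))
propagate_masked_subblock(caller, subblocks)
```

In this toy model the caller ends up with `float32`, the dtype produced by the subblock's final `to_dtype`, regardless of the dtypes of `and__2` or `-inf`.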
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101042
Approved by: https://github.com/jgong5, https://github.com/jansel