inductor: fix bf16 legalization issue for fp32 load with to bf16 case (#103080)
Giving following ir:
```
def body(self, ops):
get_index = self.get_index('index0')
index_expr = ops.index_expr(get_index, torch.int32)
constant = ops.constant(4, torch.int32)
lt = ops.lt(index_expr, constant)
masked_subblock1 = self.masked_subblock1(lt, 0.0)
get_index_1 = self.get_index('index3')
load = ops.load('arg2_1', get_index_1)
to_dtype = ops.to_dtype(load, torch.bfloat16)
where = ops.where(lt, masked_subblock1, to_dtype)
get_index_2 = self.get_index('index3')
store = ops.store('buf0', get_index_2, where, None)
return store
def masked_subblock2(self, ops):
get_index = self.get_index('index2')
load = ops.load('arg1_1', get_index)
return load
def masked_subblock1(self, ops):
get_index = self.get_index('index1')
index_expr = ops.index_expr(get_index, torch.int32)
constant = ops.constant(1, torch.int32)
ge = ops.ge(index_expr, constant)
get_index_1 = self.get_index('index1')
index_expr_1 = ops.index_expr(get_index_1, torch.int32)
constant_1 = ops.constant(3, torch.int32)
lt = ops.lt(index_expr_1, constant_1)
and_ = ops.and_(ge, lt)
masked_subblock2 = self.masked_subblock2(and_, 0.0)
get_index_2 = self.get_index('index3')
load = ops.load('arg2_1', get_index_2)
to_dtype = ops.to_dtype(load, torch.bfloat16)
where = ops.where(and_, masked_subblock2, to_dtype)
return where
```
before this PR, the ```masked_subblock2``` will legalize as ```load_bf16+to_fp32```, and the ```masked_subblock2```'s output type is ```fp32```, but for ```load = ops.load('arg2_1', get_index_2), to_dtype = ops.to_dtype(load, torch.bfloat16)```, we didn't convert ```to_bf16``` as ```to_fp32```, which the ```op.where``` has mixed type computation, and will has compiler error: ```error: operands to ?: have different types ‘float’ and ‘c10::BFloat16’```.
This PR will always convert ```to_bf16``` as ```to_fp32``` to fix such an issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103080
Approved by: https://github.com/jgong5, https://github.com/desertfire