pytorch
8e5b7ce5 - inductor: fix bf16 legalization issue for fp32 load with to bf16 case (#103080)

Commit
1 year ago
inductor: fix bf16 legalization issue for fp32 load with to bf16 case (#103080) Giving following ir: ``` def body(self, ops): get_index = self.get_index('index0') index_expr = ops.index_expr(get_index, torch.int32) constant = ops.constant(4, torch.int32) lt = ops.lt(index_expr, constant) masked_subblock1 = self.masked_subblock1(lt, 0.0) get_index_1 = self.get_index('index3') load = ops.load('arg2_1', get_index_1) to_dtype = ops.to_dtype(load, torch.bfloat16) where = ops.where(lt, masked_subblock1, to_dtype) get_index_2 = self.get_index('index3') store = ops.store('buf0', get_index_2, where, None) return store def masked_subblock2(self, ops): get_index = self.get_index('index2') load = ops.load('arg1_1', get_index) return load def masked_subblock1(self, ops): get_index = self.get_index('index1') index_expr = ops.index_expr(get_index, torch.int32) constant = ops.constant(1, torch.int32) ge = ops.ge(index_expr, constant) get_index_1 = self.get_index('index1') index_expr_1 = ops.index_expr(get_index_1, torch.int32) constant_1 = ops.constant(3, torch.int32) lt = ops.lt(index_expr_1, constant_1) and_ = ops.and_(ge, lt) masked_subblock2 = self.masked_subblock2(and_, 0.0) get_index_2 = self.get_index('index3') load = ops.load('arg2_1', get_index_2) to_dtype = ops.to_dtype(load, torch.bfloat16) where = ops.where(and_, masked_subblock2, to_dtype) return where ``` before this PR, the ```masked_subblock2``` will legalize as ```load_bf16+to_fp32```, and the ```masked_subblock2```'s output type is ```fp32```, but for ```load = ops.load('arg2_1', get_index_2), to_dtype = ops.to_dtype(load, torch.bfloat16)```, we didn't convert ```to_bf16``` as ```to_fp32```, which the ```op.where``` has mixed type computation, and will has compiler error: ```error: operands to ?: have different types ‘float’ and ‘c10::BFloat16’```. This PR will always convert ```to_bf16``` as ```to_fp32``` to fix such an issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/103080 Approved by: https://github.com/jgong5, https://github.com/desertfire
Author
Committer
Parents
Loading