conv-bn folding in low precision (#106576)
Batchnorm inference is done in fp32 if the inputs are in fp16/bf16, and the output is cast back down to the original precision. This causes the batchnorm weights to get constant folded to fp32, which prevented Conv-BN folding from firing.
```
def forward(self, arg0_1: bf16[32, 3, 3, 3], arg1_1: bf16[32], arg2_1: bf16[32], ...)
convolution: bf16[3, 32, 15, 15] = torch.ops.aten.convolution.default(arg6_1, arg0_1, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1); arg6_1 = arg0_1 = None
# weight upcasting
convert_element_type: f32[32] = torch.ops.prims.convert_element_type.default(arg3_1, torch.float32); arg3_1 = None
convert_element_type_1: f32[32] = torch.ops.prims.convert_element_type.default(arg4_1, torch.float32); arg4_1 = None
...
# end of batch norm
add_1: f32[3, 32, 15, 15] = torch.ops.aten.add.Tensor(mul_2, unsqueeze_7); mul_2 = unsqueeze_7 = None
# output downcast
convert_element_type_2: bf16[3, 32, 15, 15] = torch.ops.prims.convert_element_type.default(add_1, torch.bfloat16); add_1 = None
```
I mark the convolutions that are followed by binary foldable ops in a higher precision whose results then get converted back down to the original conv dtype. We fold the weights in fp32 because it gives slightly better accuracy, and at the end of the pass we convert the weights back to their original dtype.
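For reference, here is a minimal sketch of the dtype handling (not the Inductor pass itself): the BN scale and shift are computed and folded into the conv weights in fp32, then the folded parameters are cast back to the conv's original dtype. The helper name and the shapes below are illustrative only.
```
import torch

def fold_conv_bn_low_precision(conv_w, conv_b, bn_rm, bn_rv, bn_w, bn_b, eps=1e-5):
    # Compute the BN scale/shift in fp32, fold them into the conv weights,
    # then cast the folded parameters back to the conv's original dtype.
    orig_dtype = conv_w.dtype
    w = conv_w.to(torch.float32)
    b = (conv_b.to(torch.float32) if conv_b is not None
         else torch.zeros(conv_w.shape[0], dtype=torch.float32, device=conv_w.device))
    scale = bn_w.to(torch.float32) / torch.sqrt(bn_rv.to(torch.float32) + eps)
    folded_w = w * scale.reshape(-1, 1, 1, 1)
    folded_b = (b - bn_rm.to(torch.float32)) * scale + bn_b.to(torch.float32)
    return folded_w.to(orig_dtype), folded_b.to(orig_dtype)

# Usage sketch: fold an eval-mode BN into a bf16 conv and compare outputs.
conv = torch.nn.Conv2d(3, 32, 3, stride=2, bias=False).to(torch.bfloat16)
bn = torch.nn.BatchNorm2d(32).to(torch.bfloat16).eval()
x = torch.randn(3, 3, 32, 32, dtype=torch.bfloat16)
w, b = fold_conv_bn_low_precision(conv.weight, conv.bias, bn.running_mean,
                                  bn.running_var, bn.weight, bn.bias, bn.eps)
ref = bn(conv(x))
folded = torch.nn.functional.conv2d(x, w, b, stride=2)
print(torch.max((ref - folded).abs()))
```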
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106576
Approved by: https://github.com/XiaobingSuper, https://github.com/yanboliang
ghstack dependencies: #106471, #106575