enable conv+bn folding for mixed-dtype when bn has post activation (#107142)
For conv+bn+relu6, the joint-graph pass removes one of the dtype conversions between bn and the activation, so the graph looks like this:
```
def forward(self, arg0_1: bf16[32, 3, 3, 3], arg1_1: bf16[32], arg2_1: bf16[32], ...)
convolution: bf16[3, 32, 15, 15] = torch.ops.aten.convolution.default(arg6_1, arg0_1, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1); arg6_1 = arg0_1 = None
# weight upcasting
convert_element_type: f32[32] = torch.ops.prims.convert_element_type.default(arg3_1, torch.float32); arg3_1 = None
convert_element_type_1: f32[32] = torch.ops.prims.convert_element_type.default(arg4_1, torch.float32); arg4_1 = None
...
# end of batch norm
add_1: f32[3, 32, 15, 15] = torch.ops.aten.add.Tensor(mul_2, unsqueeze_7); mul_2 = unsqueeze_7 = None
# leftover conversion (the bn-output downcast and the pre-clamp upcast were collapsed)
convert_element_type_2: f32[3, 32, 15, 15] = torch.ops.prims.convert_element_type.default(add_1, torch.float32); add_1 = None
clamp_min: f32[3, 32, 15, 15] = torch.ops.aten.clamp_min.default(convert_element_type_2, 0.0); convert_element_type_2 = None
clamp_max: f32[3, 32, 15, 15] = torch.ops.aten.clamp_max.default(clamp_min, 6.0); clamp_min = None
convert_element_type_3: bf16[3, 32, 15, 15] = torch.ops.prims.convert_element_type.default(clamp_max, torch.bfloat16); clamp_max = None
```
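The single surviving `convert_element_type_2` is consistent with the joint-graph pass collapsing a back-to-back downcast/upcast chain into one conversion. A minimal sketch of that rewrite (the `before_pass`/`after_pass` names are illustrative, not from the PR):
```python
import torch

def before_pass(add_1: torch.Tensor) -> torch.Tensor:
    # bn output (f32) is downcast to bf16, then immediately upcast back
    # to f32 before the clamp ops that implement relu6
    down = torch.ops.prims.convert_element_type.default(add_1, torch.bfloat16)
    return torch.ops.prims.convert_element_type.default(down, torch.float32)

def after_pass(add_1: torch.Tensor) -> torch.Tensor:
    # convert(convert(x, bf16), f32) -> convert(x, f32): only one
    # convert_element_type node remains between bn and clamp_min
    return torch.ops.prims.convert_element_type.default(add_1, torch.float32)
```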
With that conversion removed, the conv+bn folding pattern no longer matches and folding fails. This PR moves the joint-graph pass's dtype-conversion removal so that it runs after the conv+bn folding pass.
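A minimal repro sketch of the pattern this PR enables folding for, assuming bf16 parameters with Inductor freezing enabled (the `ConvBnRelu6` module name and the shapes are illustrative, chosen to match the graph above):
```python
import torch
import torch.nn as nn
import torch._inductor.config as inductor_config

inductor_config.freezing = True  # conv+bn folding runs during freezing

class ConvBnRelu6(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(32)
        self.act = nn.ReLU6()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

# bf16 params with the bn math upcast to f32 yields the mixed-dtype graph above
mod = ConvBnRelu6().eval().to(torch.bfloat16)
x = torch.randn(3, 3, 32, 32, dtype=torch.bfloat16)
with torch.no_grad():
    out = torch.compile(mod)(x)  # out: bf16[3, 32, 15, 15]
```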
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107142
Approved by: https://github.com/eellison