[SyncBatchNorm] Support running with low precision parameters (#98332)
This PR fixes https://github.com/pytorch/pytorch/issues/96203.
**Details**
When using `nn.SyncBatchNorm` with the model converted to FP16, there is a dtype discrepancy in `SyncBatchNorm.forward()` that causes an error like:
```
File "/.../pytorch/torch/nn/modules/_functions.py", line 91, in forward
mean, invstd = torch.batch_norm_gather_stats_with_counts(
RuntimeError: Expected counts to have type Half but got Float
```
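For context, here is a minimal repro sketch (hypothetical, not taken from the issue): it assumes two GPUs and launching via `torchrun --nproc_per_node=2`, since `SyncBatchNorm` only takes the collective path when the world size is greater than 1.
```python
# Hypothetical repro sketch; run under e.g. `torchrun --nproc_per_node=2`.
# SyncBatchNorm only takes the collective path when world_size > 1.
import torch
import torch.distributed as dist
import torch.nn as nn

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).cuda(rank)
# Convert BatchNorm layers to SyncBatchNorm, then cast the model to FP16,
# which also casts the running_mean/running_var buffers to FP16.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model).half()

x = torch.randn(4, 3, 16, 16, device=rank, dtype=torch.float16)
out = model(x)  # previously raised "Expected counts to have type Half but got Float"
out.sum().backward()
```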
[`torch.batch_norm_gather_stats_with_counts()`](https://github.com/pytorch/pytorch/blob/fe9da29842a07a1f44d6b8c2a4c75053da9e84d0/torch/nn/modules/_functions.py#L88-L97) requires `running_mean`, `running_var`, and `counts` to have the same dtype. However, when the model has been converted to FP16, only `running_mean` and `running_var` are in FP16, while `counts` is in FP32 because [`mean` is computed in FP32](https://github.com/pytorch/pytorch/blob/fe9da29842a07a1f44d6b8c2a4c75053da9e84d0/torch/nn/modules/_functions.py#L25-L30). This PR resolves the mismatch by casting `counts` from FP32 to FP16, rather than the alternative of casting `mean` and `invstd` down from FP32 to FP16.
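A minimal sketch of the forward-path idea (illustrative only, not the exact diff; argument names follow `torch/nn/modules/_functions.py`):
```python
# Illustrative sketch of the forward-path fix, not the exact diff.
# `count_all` is built from the FP32 statistics, while the running stats may
# be FP16 after `model.half()`; cast the counts to match before the fused call.
if running_mean is not None and count_all.dtype != running_mean.dtype:
    count_all = count_all.to(running_mean.dtype)

mean, invstd = torch.batch_norm_gather_stats_with_counts(
    input,
    mean_all,
    invstd_all,
    running_mean,
    running_var,
    momentum,
    eps,
    count_all.view(-1),
)
```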
Similarly, in the backward pass, this PR casts `weight` from FP16 to FP32 to match the dtype of `mean` and `invstd`, as required by `torch.batch_norm_backward_elemt()`, rather than casting `mean` and `invstd` down from FP32 to FP16. A corresponding sketch for the backward path follows below.
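This backward-path sketch is again illustrative, using the argument names from `_functions.py`:
```python
# Illustrative sketch of the backward-path fix, not the exact diff.
# `mean` and `invstd` stay in FP32, so cast `weight` up to match them
# instead of casting the statistics down to FP16.
if weight is not None and weight.dtype != mean.dtype:
    weight = weight.to(mean.dtype)

grad_input = torch.batch_norm_backward_elemt(
    grad_output,
    saved_input,
    mean,
    invstd,
    weight,
    sum_dy,
    sum_dy_xmu,
    count_tensor,
)
```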
**Test Plan**
I dug up this run command from 2021:
For `world_size` in `{1,2}` and `backend` in `{nccl, gloo}`:
```
WORLD_SIZE=world_size BACKEND=backend python -m pytest test/distributed/test_distributed_spawn.py -k test_DistributedDataParallel_SyncBatchNorm_half -vs
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98332
Approved by: https://github.com/rohan-varma