Fix SyncBatchNorm usage without stats tracking (#50126)
Summary:
In `batch_norm_gather_stats_with_counts_cuda`, use `input.scalar_type()` when `running_mean` is not defined
In the `SyncBatchNorm` forward function, create the count tensor with dtype `torch.float32` when `running_mean` is None
Fix a few typos
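A minimal sketch of the dtype selection described above, runnable on CPU. The helper name `make_count_tensor` is hypothetical and used only for illustration; it is not the code in this PR, and the choice of `torch.full` and of `input.numel() // input.size(1)` as the per-channel count are assumptions about how the count tensor is built.

```python
import torch


def make_count_tensor(input, running_mean=None):
    """Hypothetical helper: build a one-element per-channel count tensor.

    Assumption: the dtype follows running_mean when stats are tracked,
    and falls back to torch.float32 when running_mean is None.
    """
    dtype = running_mean.dtype if running_mean is not None else torch.float32
    # Number of elements that contribute to each channel's statistics.
    per_channel = input.numel() // input.size(1)
    return torch.full((1,), per_channel, dtype=dtype, device=input.device)


x = torch.randn(4, 3, 8, 8)

# No stats tracking: running_mean is None, so the count is float32.
count_no_stats = make_count_tensor(x)

# Stats tracking: the count takes the dtype of running_mean.
count_with_stats = make_count_tensor(
    x, running_mean=torch.zeros(3, dtype=torch.float64)
)
```

Before this change, the `running_mean is None` case had no dtype to read from, which is the situation the fallback covers.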
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50126
Test Plan:
```
python -c "import torch; print(torch.batch_norm_gather_stats_with_counts(torch.randn(1, 3, 3, 3, device='cuda'), mean=torch.ones(2, 3, device='cuda'), invstd=torch.ones(2, 3, device='cuda'), running_mean=None, running_var=None, momentum=0.1, eps=1e-5, counts=torch.ones(2, device='cuda')))"
```
Fixes https://github.com/pytorch/pytorch/issues/49730
Reviewed By: ngimel
Differential Revision: D25797930
Pulled By: malfet
fbshipit-source-id: 22a91e3969b5e9bbb7969d9cc70b45013a42fe83