Fix SyncBatchNorm for empty inputs (#74944)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74944
fixes #36530
Prior to this commit, SyncBatchNorm crashes with the following
error message.
```
File "..../torch/nn/modules/_functions.py", line 17, in forward
mean, invstd = torch.batch_norm_stats(input, eps)
RuntimeError: cannot reshape tensor of 0 elements into shape [0, 3, -1] because the unspecified dimension size -1 can be any value and is ambiguous
```
This PR adds a dedicated branch to handle empty inputs. When a process
recieves empty inputs, it will set its local `mean`, `invstd`, and `count`
to zero, and participate in the `all_gather` collective communications in
the forward pass. Then `mean` and `invstd` with zero count will be
filtered out before computing global mean and invstd. In the backward
pass, it also participate in the `all_reduce` communication with zero
tensors to unblock its peers.
Differential Revision:
D35273409
D35273409
Test Plan: Imported from OSS
Reviewed By: datumbox
Pulled By: mrshenli
fbshipit-source-id: 1cee51eea866773c329b3fbf5da2be8a5fee6f0f
(cherry picked from commit f8e2a2357240ebe7b7a058047d376a5300bdeda9)