speed up SyncBatchNorm by batching distributed communication (#38246)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38246
Speeds up SyncBatchNorm by batching the distributed communication.
Initial benchmarks show a ~15% speed improvement on MobileNetV2 and
EfficientNetB3 on a single machine with 8 GPUs; the improvement over
the baseline grows as the number of GPUs increases.
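The core idea can be sketched as follows. This is a minimal, hypothetical illustration (the function names, packing layout, and the pure-Python `simulated_all_reduce` are illustrative, not the actual PyTorch implementation): SyncBatchNorm needs globally reduced per-feature statistics, so instead of issuing one collective per statistic, the local sums, squared sums, and sample count are packed into a single flat buffer and reduced with one collective call.

```python
# Hypothetical sketch of batching distributed communication for SyncBatchNorm.
# All names here are illustrative; the real code uses torch.distributed
# collectives on CUDA tensors.

def pack(local_sum, local_sqsum, count):
    # Concatenate all per-feature statistics plus the sample count into one
    # flat buffer so a single collective can reduce everything at once.
    return local_sum + local_sqsum + [float(count)]

def unpack(buf, num_features):
    # Split the reduced buffer back into its components.
    s = buf[:num_features]
    sq = buf[num_features:2 * num_features]
    n = buf[-1]
    return s, sq, n

def simulated_all_reduce(worker_buffers):
    # Stand-in for a single all_reduce(op=SUM) across workers.
    return [sum(vals) for vals in zip(*worker_buffers)]

def global_mean_var(worker_stats, num_features):
    # One batched collective replaces one collective per statistic.
    reduced = simulated_all_reduce([pack(*w) for w in worker_stats])
    s, sq, n = unpack(reduced, num_features)
    mean = [si / n for si in s]
    var = [sqi / n - m * m for sqi, m in zip(sq, mean)]
    return mean, var
```

Batching matters because each collective pays a fixed latency cost per call, which is why the benefit grows with the number of GPUs.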
Test Plan:
verified that intermediate values in the forward/backward pass are equivalent before and after this change (via `torch.allclose`)
benchmark runner:
https://gist.github.com/vkuzo/7b1ce1b1b051ee6d46877d0f18ab9b1f
results (1 forward pass + 1 backward pass, 1 machine, 8x Tesla-P100, batch_size=20 per node):
```
model            gpus  before_ms  after_ms  speedup
efficientnet-b3     2        660       654  0.00909
efficientnet-b3     4        777       710  0.08623
efficientnet-b3     8        988       838  0.15182
mobilenet-v2        2        267       266  0.00375
mobilenet-v2        4        328       289  0.1189
mobilenet-v2        8        453       373  0.1766
```
Imported from OSS
Differential Revision: D21505905
fbshipit-source-id: 3e796343fce8329a2e17671d60ae66c0387924e7