Migrate apex.parallel.SyncBatchNorm channels_last to pytorch (#46906)
Summary:
Per title. This PR does the following:
- Migrate the channels_last path of `apex.parallel.SyncBatchNorm` to PyTorch's `torch.nn.SyncBatchNorm`
- Fix a TODO by fusing the `sum` and `div` kernels into the backward elementwise kernel (see the sketch after this list):
https://github.com/pytorch/pytorch/blob/b167402e2e66a663cd9913885552929b4c045ffa/torch/nn/modules/_functions.py#L76-L95
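Roughly, the fusion means the backward no longer launches separate `sum`/`div` kernels to form `mean_dy`/`mean_dy_xmu`; the per-rank counts are handed to the elementwise kernel, which divides internally. A runnable sketch (requires a CUDA device; shapes are arbitrary, and this mirrors but is not the code in `torch/nn/modules/_functions.py`):
```python
import torch

x = torch.randn(8, 4, 16, 16, device="cuda").to(memory_format=torch.channels_last)
grad_out = torch.randn_like(x)
weight = torch.ones(4, device="cuda")

# Per-channel statistics and the partial reductions used by SyncBatchNorm backward.
mean, invstd = torch.batch_norm_stats(x, 1e-5)
sum_dy, sum_dy_xmu, _, _ = torch.batch_norm_backward_reduce(
    grad_out, x, mean, invstd, weight, True, False, False)

# Per-rank element counts; with a process group these would be all-gathered.
count = torch.full((1,), x.numel() // x.size(1), dtype=torch.int32, device="cuda")

# Previously the caller reduced the counts and divided first, launching extra
# `sum`/`div` kernels:
#   divisor = count.sum()
#   mean_dy, mean_dy_xmu = sum_dy / divisor, sum_dy_xmu / divisor
# Now the counts go straight into the fused elementwise backward kernel:
grad_input = torch.batch_norm_backward_elemt(
    grad_out, x, mean, invstd, weight, sum_dy, sum_dy_xmu, count)
```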
Todo
- [x] Discuss the regression introduced in https://github.com/pytorch/pytorch/pull/37133#discussion_r512530389, i.e. the synchronized copy here:
https://github.com/pytorch/pytorch/blob/b167402e2e66a663cd9913885552929b4c045ffa/torch/nn/modules/_functions.py#L32-L34
**Comment**: this PR uses the apex version of the size check (a sketch of this appears after this list). Tests pass and I haven't seen anything wrong so far.
- [x] The restriction on when the channels_last kernel is used looks like this:
```cpp
inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) {
  return self.is_contiguous(at::MemoryFormat::ChannelsLast) || self.ndimension() == 2;
}
```
Could we relax this for channels_last_3d as well?
**Comment**: we don't have a benchmark for this yet; we will evaluate it and add the functionality later if needed (see the memory-format sketch after this list).
- [x] Add test
- [x] Add benchmark
Detailed benchmarks are at https://github.com/xwang233/code-snippet/tree/master/syncbn-channels-last (a minimal timing sketch also appears below).
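On the size-check item above, a minimal sketch of a metadata-only check; this is an assumption about the apex-style check referenced in the comment, not the verbatim code. The point is that `numel()` and `size()` read host-side tensor metadata, so the check avoids the synchronizing device-to-host copy:
```python
import torch

# Hypothetical helper for illustration; the real check in this PR may be
# placed and spelled differently.
def check_input_size(input: torch.Tensor) -> None:
    # numel()/size() are host-side metadata reads: no GPU -> CPU copy needed.
    size = input.numel() // input.size(1)
    if size == 1:
        raise ValueError(
            f"Expected more than 1 value per channel when training, got input size {size}"
        )
```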
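On the channels_last_3d question above, the Python-side memory-format checks look like this (a small illustrative snippet; shapes are arbitrary):
```python
import torch

# 4-D inputs are covered by channels_last and the C++ helper above.
x4d = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
print(x4d.is_contiguous(memory_format=torch.channels_last))  # True

# 5-D (volumetric) inputs use a separate format, channels_last_3d, which the
# helper above does not currently accept.
x5d = torch.randn(2, 3, 4, 5, 6).to(memory_format=torch.channels_last_3d)
print(x5d.is_contiguous(memory_format=torch.channels_last_3d))  # True
```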
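For the test/benchmark items, a minimal timing sketch; it assumes a multi-GPU host and a `torchrun --nproc_per_node=N` launch, the sizes and iteration counts are arbitrary, and this is not the harness from the linked benchmark repo:
```python
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# convert_sync_batchnorm swaps every BatchNorm*d layer for SyncBatchNorm.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
    torch.nn.Sequential(
        torch.nn.Conv2d(64, 64, 3, padding=1, bias=False),
        torch.nn.BatchNorm2d(64),
    )
).cuda().to(memory_format=torch.channels_last)

x = torch.randn(32, 64, 56, 56, device="cuda").to(memory_format=torch.channels_last)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for _ in range(10):  # warmup
    model(x).sum().backward()
torch.cuda.synchronize()
start.record()
for _ in range(100):
    model(x).sum().backward()
end.record()
torch.cuda.synchronize()
if rank == 0:
    print(f"fwd+bwd: {start.elapsed_time(end) / 100:.3f} ms/iter")
dist.destroy_process_group()
```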
Closes https://github.com/pytorch/pytorch/issues/50781
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46906
Reviewed By: albanD
Differential Revision: D26771437
Pulled By: malfet
fbshipit-source-id: d00387044e9d43ac7e6c0e32a2db22c63d1504de