Improve BatchNorm1d training performance (CPU) (#57033)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57033
CPU part of gh-38915
BatchNorm1d is implemented by looping over the channels, selecting one channel
at a time and running a cpu_serial_kernel loop per channel. For an (N, C)
contiguous layout this produces a sub-optimal strided memory access pattern,
guaranteeing that consecutive accesses never fall in the same cache line.
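As a rough illustration of why the per-channel access is strided (shapes here are made up, not taken from the PR), selecting a single channel of a contiguous (N, C) tensor yields a view whose stride is C elements, so each step of the per-channel loop jumps C floats through memory:

```python
import torch

N, C = 10000, 256
x = torch.randn(N, C)          # contiguous (N, C): strides are (C, 1)

# A single channel is a strided view: consecutive elements visited by the
# per-channel loop are C floats apart, so for large C no two consecutive
# accesses share a cache line.
channel = x.select(1, 0)
print(x.stride())        # (256, 1)
print(channel.stride())  # (256,)
```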
I fix this by passing the entire input into one `TensorIterator` and letting
it decide which dimensions to iterate over and how to divide work among threads.
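The change itself lives in the C++ kernel, but the access-pattern difference can be sketched in Python (illustrative shapes; the broadcasted expression stands in for the single `TensorIterator` pass, it is not the actual implementation):

```python
import torch

def bn_loop(x, mean, invstd, weight, bias):
    # Old pattern: one strided pass over the (N, C) buffer per channel.
    out = torch.empty_like(x)
    for c in range(x.size(1)):
        out[:, c] = (x[:, c] - mean[c]) * invstd[c] * weight[c] + bias[c]
    return out

def bn_fused(x, mean, invstd, weight, bias):
    # New pattern: a single pass over the whole buffer; broadcasting over
    # dim 0 plays the role TensorIterator plays in the C++ kernel, choosing
    # the loop order and splitting work across threads.
    return (x - mean) * invstd * weight + bias

N, C = 8, 3
x = torch.randn(N, C)
mean = x.mean(0)
invstd = (x.var(0, unbiased=False) + 1e-5).rsqrt()
weight, bias = torch.ones(C), torch.zeros(C)
assert torch.allclose(bn_loop(x, mean, invstd, weight, bias),
                      bn_fused(x, mean, invstd, weight, bias))
```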
For the statistics updates and the backward function, I use `at::mean` and `at::sum`
instead of the existing ad-hoc reductions. Not only does this allow better memory
access patterns, it also enables vectorization, so performance improves for
BatchNorm2d as well. Unfortunately, `at::var` and `at::var_mean` don't perform
as well, so I've left the other reductions as they were.
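In Python terms, the statistics these reductions compute are just per-channel reductions over every dimension except the channel dimension; a small sketch with hypothetical shapes (the variance line is shown via a sum of squared deviations purely for illustration, since the kernel keeps its own variance reduction):

```python
import torch

N, C, L = 32, 64, 100
x = torch.randn(N, C, L)

# Per-channel mean as one reduction over all dims except C, which is the
# shape of reduction at::mean / at::sum now perform in the kernel.
mean = x.mean(dim=(0, 2))
# Variance via a sum of squared deviations -- illustration only; the kernel
# keeps its existing reduction because at::var / at::var_mean were slower.
var = (x - mean.view(1, C, 1)).pow(2).sum(dim=(0, 2)) / (N * L)

# Agrees with flattening each channel and reducing it separately.
ref = x.transpose(0, 1).reshape(C, -1)
assert torch.allclose(mean, ref.mean(1), atol=1e-5)
assert torch.allclose(var, ref.var(1, unbiased=False), atol=1e-5)
```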
Overall, on my machine this takes the 1d example from 24 ms down to 4 ms and
the 2d example from 2.5 ms down to 2 ms.
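A minimal way to measure the training forward pass, assuming shapes of roughly this size (the actual 1d and 2d examples are the ones in gh-38915 and are not reproduced here):

```python
import torch
from torch.utils import benchmark

# Hypothetical shapes -- not the exact benchmark inputs from gh-38915.
cases = [
    ("BatchNorm1d", torch.nn.BatchNorm1d(256), torch.randn(65536, 256)),
    ("BatchNorm2d", torch.nn.BatchNorm2d(64), torch.randn(32, 64, 56, 56)),
]

for label, mod, inp in cases:
    # Modules are in training mode by default, so the running statistics
    # updates exercised by this PR are included in the measurement.
    t = benchmark.Timer(stmt="mod(inp)", globals={"mod": mod, "inp": inp},
                        label=label)
    print(t.blocked_autorange())
```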
Test Plan: Imported from OSS
Reviewed By: mruberry
Differential Revision: D28142333
Pulled By: ngimel
fbshipit-source-id: 066fe4f37f29b6458005e513e85faa398eeb9e2d