Improve native_batch_norm_backward performance (CUDA) (#58240)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/38915
The original code uses a single fused kernel to do both the reduction and the elementwise backward calculations, whereas the `SyncBatchNorm` kernels split these into separate reduction and elementwise kernels, which makes them slightly slower in some cases. I use the fused kernel when it's beneficial, and otherwise choose the optimized channels-last split kernels. Eval mode also benefits: there the reduction is sometimes unnecessary, in which case the split kernels are a win even without channels last.
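As a rough illustration (not code from this PR), the snippet below exercises both paths discussed above: a channels-last training backward, where the reduction over the batch is required, and an eval-mode backward, where normalization uses the running statistics so part of the reduction can be skipped. The shape is one of the benchmark shapes, assuming `L = 3136` corresponds to a 56x56 spatial extent.

```python
import torch

# Channels-last input matching the N=1, C=256, L=3136 row (56*56 = 3136 assumed).
x = torch.randn(1, 256, 56, 56, device="cuda").to(
    memory_format=torch.channels_last).requires_grad_()
bn = torch.nn.BatchNorm2d(256).cuda()

# Training mode: the backward needs the batch reduction for grad_weight /
# grad_bias and the input-gradient correction terms.
bn.train()
bn(x).sum().backward()

# Eval mode: the forward normalizes with running_mean / running_var, so the
# backward can skip part of the reduction work.
x.grad = None
bn.eval()
bn(x).sum().backward()
```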
Benchmarks on my system show significant speedups for channels-last reductions and eval mode, with only a few minor slowdowns in training mode. The slowdowns are limited to the 2 x 2048 shape in training, which is a small channels-last input; for larger batches or channel counts, the channels-last kernels are much faster. In the table below, blank `N`/`C`/`L` cells repeat the shape from the row above, and `x` marks configurations without a cudnn measurement. A rough reproduction sketch follows the table.
|N |C |L |training|backward|old |new |cudnn |
|----|----|----|--------|--------|------|------|------|
|1 |256 |3136|TRUE |all |70.25 |64.93 |68.90 |
| | | |TRUE |self |69.77 |64.61 |69.42 |
| | | |FALSE |all |70.10 |51.12 |x |
| | | |FALSE |self |70.17 |51.17 |x |
|3136|256 | |TRUE |all |554.08|76.63 |549.88|
| | | |TRUE |self |553.34|78.19 |552.36|
| | | |FALSE |all |565.40|55.09 |x |
| | | |FALSE |self |565.71|54.84 |x |
|2 |8192|1 |TRUE |all |155.47|47.26 |202.26|
| | | |TRUE |self |155.46|48.36 |203.72|
| | | |FALSE |all |178.28|40.90 |x |
| | | |FALSE |self |178.21|40.69 |x |
|2 |2048|1 |TRUE |all |43.50 |48.21 |57.47 |
| | | |TRUE |self |43.63 |47.24 |55.22 |
| | | |FALSE |all |49.36 |39.27 |x |
| | | |FALSE |self |49.25 |42.02 |x |
|128 |8192|1 |TRUE |all |762.70|106.45|336.52|
| | | |TRUE |self |765.79|107.04|337.32|
| | | |FALSE |all |792.68|74.94 |x |
| | | |FALSE |self |793.86|74.83 |x |
|128 |2048|1 |TRUE |all |188.37|46.20 |85.02 |
| | | |TRUE |self |188.47|47.57 |85.04 |
| | | |FALSE |all |191.57|40.44 |x |
| | | |FALSE |self |190.13|41.55 |x |
|2 |8192| |TRUE |all |156.03|43.01 |155.19|
| | | |TRUE |self |156.24|46.59 |156.93|
| | | |FALSE |all |179.34|40.06 |x |
| | | |FALSE |self |179.20|41.85 |x |
|2 |2048| |TRUE |all |44.05 |50.15 |44.21 |
| | | |TRUE |self |44.10 |48.97 |44.11 |
| | | |FALSE |all |49.72 |40.95 |x |
| | | |FALSE |self |49.87 |43.43 |x |
|128 |8192| |TRUE |all |775.19|96.60 |777.64|
| | | |TRUE |self |776.20|96.85 |774.21|
| | | |FALSE |all |797.64|68.01 |x |
| | | |FALSE |self |806.25|68.05 |x |
|128 |2048| |TRUE |all |188.49|48.10 |188.97|
| | | |TRUE |self |188.07|46.97 |187.98|
| | | |FALSE |all |192.32|43.78 |x |
| | | |FALSE |self |193.72|40.82 |x |
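The sketch below shows how one of the rows above could be reproduced; the summary does not state the exact timing harness, so `torch.utils.benchmark` and the forward-plus-backward loop here are assumptions, not the author's script.

```python
import torch
from torch.utils import benchmark

# Assumed mapping of the N=1, C=256, L=3136 row onto a 4D channels-last input.
x = torch.randn(1, 256, 56, 56, device="cuda").to(
    memory_format=torch.channels_last).requires_grad_()
bn = torch.nn.BatchNorm2d(256).cuda()
grad = torch.randn_like(x)

def fwd_bwd():
    # Fresh forward each iteration so backward re-runs the batch-norm kernels.
    out = bn(x)
    out.backward(grad)
    x.grad = None

timer = benchmark.Timer(stmt="fwd_bwd()", globals={"fwd_bwd": fwd_bwd})
print(timer.timeit(100))
```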
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58240
Reviewed By: bdhirsh
Differential Revision: D28435158
Pulled By: ngimel
fbshipit-source-id: bf62a1ee1c5d95a2caf55bee6176ae5c965688ec