Improve BatchNorm1d performance (CUDA) (#57034)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57034
Resolves gh-38915
For the example given in the issue, BatchNorm1d on cuDNN is around 12x slower
than BatchNorm2d. Internally, cuDNN expects at least a 4d tensor (N, C, H, W),
so these two modules actually call the same cuDNN code. My assumption is that
cuDNN just isn't optimized for the H=W=1 case.
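A rough way to reproduce this kind of comparison (a sketch only; the shapes and
iteration counts below are illustrative and not the exact script from the issue):
```
import torch
import torch.nn as nn
from torch.utils.benchmark import Timer

N, C = 65536, 256
# BatchNorm1d on an (N, C) input: cuDNN internally treats this as (N, C, 1, 1).
x1d = torch.randn(N, C, device="cuda")
# Same number of elements, but with non-trivial spatial dimensions.
x2d = torch.randn(N // 256, C, 16, 16, device="cuda")

bn1d = nn.BatchNorm1d(C).cuda()
bn2d = nn.BatchNorm2d(C).cuda()

for label, mod, inp in [("BatchNorm1d", bn1d, x1d), ("BatchNorm2d", bn2d, x2d)]:
    t = Timer(stmt="mod(inp)", globals={"mod": mod, "inp": inp})
    print(label, t.timeit(100))
```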
Instead, this change disables cuDNN for 2d batch_norm inputs and improves the CUDA
implementation of `native_batch_norm` to be competitive with cuDNN. For the
example in the issue, `BatchNorm1d` now takes 335 us compared to 6.3 ms before,
an ~18x speedup.
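The per-kernel numbers below came from nvprof; a similar kernel-level breakdown
can also be obtained with `torch.profiler` (a sketch with illustrative shapes,
not the exact measurement script used here):
```
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

bn = nn.BatchNorm1d(256).cuda()
x = torch.randn(65536, 256, device="cuda")

# Warm up so one-time setup is excluded from the profile.
for _ in range(10):
    bn(x)
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    for _ in range(100):
        bn(x)
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```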
Before this change, nvprof shows:
```
Type Time(%) Time Calls Avg Min Max Name
GPU activities: 99.64% 630.95ms 100 6.3095ms 5.6427ms 8.8800ms void cudnn::bn_fw_tr_1C11_kernel_NCHW<float, float, int=512, bool=0, int=2>(cudnnTensorStruct, float const *, cudnn::bn_fw_tr_1C11_kernel_NCHW<float, float, int=512, bool=0, int=2>, cudnnTensorStruct*, float const *, float const , cudnnTensorStruct*, cudnnTensorStruct*, cudnnTensorStruct**, float const *, float const *, float const *, cudnnTensorStruct*, cudnnTensorStruct*)
```
But after, it shows:
```
Type Time(%) Time Calls Avg Min Max Name
GPU activities: 54.76% 14.352ms 100 143.52us 123.52us 756.28us _ZN2at6native27unrolled_elementwise_kernelIZZZNS0_72_GLOBAL__N__48_tmpxft_001e82d0_00000000_7_Normalization_cpp1_ii_db66e07022batch_norm_elementwiseERKNS_6TensorES5_RKN3c108optionalIS3_EESA_S5_S5_ENKUlvE_clEvENKUlvE2_clEvEUlfffffE_NS_6detail5ArrayIPcLi6EEE16OffsetCalculatorILi5EjESI_ILi1EjENS0_6memory15LoadWithoutCastENSL_16StoreWithoutCastEEEviT_T0_T1_T2_T3_T4_
35.09% 9.1951ms 100 91.950us 84.415us 362.17us void at::native::reduce_kernel<int=256, int=2, at::native::ReduceOp<float, at::native::WelfordOps<float, float, int, float, thrust::pair<float, float>>, unsigned int, float, int=2>>(float)
0.71% 186.14us 100 1.8610us 1.8240us 1.9840us _ZN2at6native72_GLOBAL__N__48_tmpxft_001e82d0_00000000_7_Normalization_cpp1_ii_db66e07045unrolled_elementwise_kernel_for_multi_outputsILi3EZZZNS1_34batch_norm_update_stats_and_invertERKNS_6TensorES5_S5_S5_ddlENKUlvE_clEvENKUlvE2_clEvEUlffffE_NS_6detail5ArrayIPcLi7EEE23TrivialOffsetCalculatorILi4EjESD_ILi3EjEEEviT0_T1_T2_T3_
0.59% 153.37us 100 1.5330us 1.4720us 2.6240us void at::native::vectorized_elementwise_kernel<int=4, at::native::BUnaryFunctor<at::native::AddFunctor<long>>, at::detail::Array<char*, int=2>>(int, long, at::native::AddFunctor<long>)
```
I think there is similar scope to improve the backward implementation.
Test Plan: Imported from OSS
Reviewed By: anjali411
Differential Revision: D28142447
Pulled By: ngimel
fbshipit-source-id: c70109780e206fa85e50a31e90a1cb4c533199da