Remove overload of GPU max_pool3d with kernel_width; fix nan, inf in GPU {fractional,adaptive} max_pool{2,3}d (#39903)
Summary:
Fix https://github.com/pytorch/pytorch/issues/39846.
Fix https://github.com/pytorch/pytorch/issues/39044
The problem was that `max_pool3d_with_indices_single_out_frame` had a second overload that takes kernel_width as a template argument. The two kernels were supposed to behave identically, but they did not.
The general version
https://github.com/pytorch/pytorch/blob/da3073e9b1db503f106842339f50f522d973be84/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu#L69-L73
The overloaded version
https://github.com/pytorch/pytorch/blob/da3073e9b1db503f106842339f50f522d973be84/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu#L130-L134
When max_pool3d is dispatched through the "switch-case" to the overloaded version, the NaN check in the comparison is skipped. Maintaining two copies of such a complicated kernel is also hard, and I'm not sure the overloaded version gives any significant performance benefit. So I propose removing the kernel_width overload.
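To illustrate why skipping the NaN check matters, here is a minimal pure-Python sketch of the two comparison styles. It is not the CUDA kernel code; the function names `window_max_nan_aware` and `window_max_plain` are hypothetical, and the assumption is that the general kernel uses a comparison of the form "val > max or val is NaN" while the overloaded one only compares "val > max".

```python
import math

def window_max_nan_aware(values):
    """Running max that lets NaN win, assumed to mirror the general kernel."""
    best, best_idx = -math.inf, -1
    for i, val in enumerate(values):
        # Any comparison with NaN is False, so NaN needs an explicit check.
        if val > best or math.isnan(val):
            best, best_idx = val, i
    return best, best_idx

def window_max_plain(values):
    """Running max without the NaN check, assumed to mirror the overload."""
    best, best_idx = -math.inf, -1
    for i, val in enumerate(values):
        # NaN > best is always False, so a NaN in the window is silently dropped.
        if val > best:
            best, best_idx = val, i
    return best, best_idx

window = [1.0, float("nan"), 3.0]
print(window_max_nan_aware(window))  # NaN is kept, with index 1
print(window_max_plain(window))      # NaN is ignored; 3.0 at index 2 wins
```

Once the NaN-aware version has picked up a NaN, no later value can replace it, because every subsequent `val > best` comparison is False. That is the propagation behavior the plain comparison loses.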
Also, the existing max_pool_XD_nan test did not pass the device kwarg. I added it.
Edit: profiling before and after
script: https://github.com/xwang233/code-snippet/blob/master/maxpool-3d-kw-template-arg/a.py
plot: https://github.com/xwang233/code-snippet/blob/master/maxpool-3d-kw-template-arg/b.ipynb
The performance difference before and after is within ±5%.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39903
Differential Revision: D22080759
Pulled By: ngimel
fbshipit-source-id: 4dacdd266a0522b3ff432eb9d58b131fa86821e9