pytorch
1670ea94 - Remove overload of GPU max_pool3d with kernel_width; fix nan, inf in GPU {fractional,adaptive} max_pool{2,3}d (#39903)

Commit
4 years ago
Remove overload of GPU max_pool3d with kernel_width; fix nan, inf in GPU {fractional,adaptive} max_pool{2,3}d (#39903) Summary: Fix https://github.com/pytorch/pytorch/issues/39846. Fix https://github.com/pytorch/pytorch/issues/39044 The problem was that `max_pool3d_with_indices_single_out_frame` has an overload of kernel_width being a template argument. The two overloaded kernels were supposed to be identical, however, they were not. The general version https://github.com/pytorch/pytorch/blob/da3073e9b1db503f106842339f50f522d973be84/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu#L69-L73 The overloaded version https://github.com/pytorch/pytorch/blob/da3073e9b1db503f106842339f50f522d973be84/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu#L130-L134 While the max_pool3d being "switch-case"-ed to the overloaded version, the NaN value comparison is ignored. Also, maintaining two overloaded versions of such a complicated kernel would be hard. I'm not sure if the overloaded version would even give huge performance benefit. So I propose to remove the kernel_width overloaded version. Also, the current test of max_pool_XD_nan forgot the device kwarg. I added that. Edit: profiling before and after script: https://github.com/xwang233/code-snippet/blob/master/maxpool-3d-kw-template-arg/a.py plot: https://github.com/xwang233/code-snippet/blob/master/maxpool-3d-kw-template-arg/b.ipynb The performance difference is within +- 5%. Pull Request resolved: https://github.com/pytorch/pytorch/pull/39903 Differential Revision: D22080759 Pulled By: ngimel fbshipit-source-id: 4dacdd266a0522b3ff432eb9d58b131fa86821e9
Author
Parents
Loading