pytorch
9c3346c8 - reduce max_num_threads for complex double ops in reduce_kernel (#61438)

Commit
3 years ago
reduce max_num_threads for complex double ops in reduce_kernel (#61438) Summary: reduce_kernel currently has a all-purpose MAX_NUM_THREADS of 512, which causes register spilling in various kernel instantiations for the various ops that use it as a template (ReduceLogicKernel, ReduceMinMaxKernel, ReduceMomentKernel, ReduceNormKernel, and ReduceSumProdKernel). This is a coarse first attempt at mitigating spillage by reducing max_num_threads to 256 for all complex double ops, which are by far the most common and egregious offenders, while keeping it 512 for the other normal ops, the large majority of which are fine. Besides complex double ops, the remaining kernels which exhibit lmem usage are ReduceMinMax double, long, and BFloat16; ReduceMomentKernel BFloat16, Half, float, and double; and ReduceNorm double. The proposed fix manages to eliminate lmem usage and massively improve runtime (by 3-5x) for complex double ops. All other ops are unaffected and have the same runtime; if they used lmem before, they still do now. We would still strongly recommend further testing of input shapes and ops as well as looking into if there's a cleaner approach to doing this. We tested the following ops for both complex double instantiations, as well as testing torch.max and torch.argmax with doubles to make sure they didn't break. We didn't include the double instantiations in the timing data, since they remain unchanged post-fix vs pre-fix. Timing data for the complex double ops below (all done on Nvidia Titan-V GPU): torch.mean: ![MeanTimingData](https://user-images.githubusercontent.com/22803332/125005623-0f424800-e011-11eb-864e-8419485a9c76.PNG) torch.linalg.norm: ![NormTimingData](https://user-images.githubusercontent.com/22803332/125005649-179a8300-e011-11eb-96e1-54e18c85a336.PNG) torch.sum: ![SumTimingData](https://user-images.githubusercontent.com/22803332/125005655-1b2e0a00-e011-11eb-928e-ee5941608fb2.PNG) Pull Request resolved: https://github.com/pytorch/pytorch/pull/61438 Reviewed By: mrshenli Differential Revision: D29756863 Pulled By: ngimel fbshipit-source-id: 4c4635df58af9313966ff1df1095f7e15a39bb07
Author
Parents
Loading