benchmark
339ccfde - Refactor code for sum Triton kernels (#2303)

Commit
1 year ago
Refactor code for sum Triton kernels (#2303) Summary: Pull Request resolved: https://github.com/pytorch/benchmark/pull/2303 Refactor code to improve readability and logical flow for cases which select the `sum` Triton kernel implementation to run. Create helper functions for the following cases: - Reduce N-dimensional input to scalar output - Reduce 2-dimensional input to 1-dimensional output - Reduce 3-dimensional input along dimension 1 to 2-dimensional output Add command line argument parsing for the `input_dim` parameter, which specifies the number of dimensions desired in kernel inputs. Modify absolute tolerance to account for floating-point operation error. Reviewed By: jbschlosser Differential Revision: D58488137 fbshipit-source-id: 01e1f6104383cb0ec5338c4d0427b3a30c2bffd1
Author
Parents
Loading