Refactor code for sum Triton kernels (#2303)
Summary:
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2303
Refactor code to improve readability and logical flow for cases which select the `sum` Triton kernel implementation to run. Create helper functions for the following cases:
- Reduce N-dimensional input to scalar output
- Reduce 2-dimensional input to 1-dimensional output
- Reduce 3-dimensional input along dimension 1 to 2-dimensional output
Add command line argument parsing for the `input_dim` parameter, which specifies the number of dimensions desired in kernel inputs.
Modify absolute tolerance to account for floating-point operation error.
Reviewed By: jbschlosser
Differential Revision: D58488137
fbshipit-source-id: 01e1f6104383cb0ec5338c4d0427b3a30c2bffd1