Add simple fused Triton kernel benchmark for jagged_mean operator (#2355)
Summary:
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2355
Add a Triton kernel benchmark implementing a simple fused `mean` for the `jagged_mean` operator. The Triton kernels perform a `mean` along the ragged dimension of a nested tensor of logical shape `(B, *, M)`, where `*` is the ragged dimension. They load blocks of the values tensor along its last dimension `M`, reduce each variable-length block along the ragged dimension `*`, and store each of the `B` reductions in an output tensor of shape `(B, M)`. The first kernel, `sum_then_buffer`, performs a `sum` on each block of input and then accumulates the result into a buffer. The second kernel, `buffer_then_sum`, is a faster implementation that accumulates blocks into a buffer and then performs a single `sum` over the buffer.
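The two reduction strategies can be sketched in plain NumPy (a hedged emulation of the kernel logic, not the actual Triton code; function names, the `block_r` parameter, and the `(values, offsets)` ragged layout are illustrative assumptions):

```python
import numpy as np

def jagged_mean_sum_then_buffer(values, offsets, block_r=4):
    # values: (total_rows, M) concatenated ragged rows; offsets: (B+1,) segment bounds
    B, M = len(offsets) - 1, values.shape[1]
    out = np.zeros((B, M), dtype=values.dtype)
    for b in range(B):
        start, end = offsets[b], offsets[b + 1]
        buf = np.zeros(M, dtype=values.dtype)
        # sum each block of up to block_r ragged rows, then accumulate into buf
        for r in range(start, end, block_r):
            buf += values[r:min(r + block_r, end)].sum(axis=0)
        out[b] = buf / max(end - start, 1)
    return out

def jagged_mean_buffer_then_sum(values, offsets, block_r=4):
    B, M = len(offsets) - 1, values.shape[1]
    out = np.zeros((B, M), dtype=values.dtype)
    for b in range(B):
        start, end = offsets[b], offsets[b + 1]
        buf = np.zeros((block_r, M), dtype=values.dtype)
        # accumulate blocks elementwise into a (block_r, M) buffer,
        # then perform a single final sum over the buffer
        for r in range(start, end, block_r):
            blk = values[r:min(r + block_r, end)]
            buf[:blk.shape[0]] += blk
        out[b] = buf.sum(axis=0) / max(end - start, 1)
    return out
```

Both functions produce the same `(B, M)` result; they differ only in where the per-block reduction happens, which is what the benchmark compares on GPU.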
This diff is particularly useful in emulating the loop in Inductor-generated (`torch.compile`) kernels and serves as a benchmark proxy for Inductor kernels.
Use the command-line argument `sum_then_buffer`, which defaults to `0` (since `buffer_then_sum` is faster, as shown below), to select which Triton kernel to benchmark.
These Triton kernels are benchmarked against two PyTorch implementations: one uses `torch.mean`, while the other combines `torch.div`, `torch.sum`, and tensor `shape` information.
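The `sum`/`div` baseline amounts to dividing per-segment sums by each segment's ragged length. A minimal NumPy emulation of that idea (the function name and the `(values, offsets)` layout are assumptions, not the PR's actual code):

```python
import numpy as np

def jagged_mean_reference(values, offsets):
    # Emulates the sum-then-divide baseline: per-segment sums divided by
    # ragged lengths, with lengths recovered from the offsets (the "shape" info).
    lengths = np.diff(offsets)                        # rows per batch element
    sums = np.add.reduceat(values, offsets[:-1], axis=0)  # one sum per segment
    return sums / lengths[:, None]
```

This reference is handy for checking the Triton kernels' numerical output on host-side test data.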
This diff follows the general framework of the `jagged_sum` operator (D58549297, D59034792).
Reviewed By: davidberard98
Differential Revision: D59146627
fbshipit-source-id: 3664b1b861bc368c64c8b34e4301e18da38c9a15