Add variable-length loop Triton kernel benchmark for jagged_mean operator (#2356)
Summary:
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2356
Add a Triton kernel benchmark implementing a fast, variable-length loop on top of the simple fused `mean` kernels in D59146627. Rather than looping over the full range of the maximum sequence length, `MAX_SEQLEN`, as the simple fused kernel does, this kernel loops only from the start (`ragged_start`) to the end (`ragged_end`) of each variable-length subsection of the input nested tensor. This terminates the loop at each row's jagged length, eliminating the extra work of reading, reducing, and writing the zeros beyond it.
This diff also contains `sum_then_buffer` and `buffer_then_sum` variants, mirroring the simple fused kernels in D59146627.
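To illustrate the idea, here is a minimal Triton sketch of a variable-length loop using the `sum_then_buffer` ordering. It is not the kernel in this diff: the one-program-per-batch-element grid, the block sizes, and the names `BLOCK_R`, `BLOCK_M`, and `jagged_mean_variable_length` are assumptions made for the example, and the `buffer_then_sum` ordering is only described in a comment.

```python
# Minimal sketch, NOT the benchmark kernel: one program per (batch element, column block);
# the grid layout, block sizes, and names below are illustrative assumptions.
import torch
import triton
import triton.language as tl


@triton.jit
def _jagged_mean_variable_length_kernel(
    values_ptr,    # (total_length, M) jagged values, all batch elements' rows concatenated
    offsets_ptr,   # (B + 1,) row offsets delimiting each batch element's subsection
    output_ptr,    # (B, M) per-batch-element means
    M,             # size of the dense inner dimension
    BLOCK_R: tl.constexpr,  # rows loaded per loop iteration
    BLOCK_M: tl.constexpr,  # columns handled per program
):
    pid_b = tl.program_id(0)  # which batch element
    pid_m = tl.program_id(1)  # which block of columns

    ragged_start = tl.load(offsets_ptr + pid_b)
    ragged_end = tl.load(offsets_ptr + pid_b + 1)

    cols = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    col_mask = cols < M

    # sum_then_buffer: reduce each tile into a 1D running sum as we go.
    # (buffer_then_sum would instead accumulate a (BLOCK_R, BLOCK_M) buffer
    # inside the loop and call tl.sum once after it.)
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Variable-length loop: cover only ragged_start..ragged_end instead of
    # 0..MAX_SEQLEN, so no zero rows are read, reduced, or written.
    for r in range(ragged_start, ragged_end, BLOCK_R):
        rows = r + tl.arange(0, BLOCK_R)
        row_mask = rows < ragged_end
        ptrs = values_ptr + rows[:, None] * M + cols[None, :]
        tile = tl.load(ptrs, mask=row_mask[:, None] & col_mask[None, :], other=0.0)
        acc += tl.sum(tile.to(tl.float32), axis=0)

    mean = acc / (ragged_end - ragged_start).to(tl.float32)
    tl.store(output_ptr + pid_b * M + cols, mean, mask=col_mask)


def jagged_mean_variable_length(values: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Launch the sketch kernel: values is (total_length, M), offsets is (B + 1,)."""
    B, M = offsets.numel() - 1, values.shape[1]
    out = torch.empty((B, M), dtype=torch.float32, device=values.device)
    grid = (B, triton.cdiv(M, 64))
    _jagged_mean_variable_length_kernel[grid](values, offsets, out, M, BLOCK_R=32, BLOCK_M=64)
    return out
```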
This diff draws from a similar binary search implementation found [here](https://www.internalfb.com/code/fbsource/[d33668c8c8fe3cc7d75beb58b4b0dc51dc6e96a1][diffs]/fbcode/caffe2/torch/_inductor/runtime/triton_helpers.py?lines=195).
These Triton kernels are benchmarked against three PyTorch implementations (one using `torch.mean`, one using `torch.nanmean` with padding, and one using `torch.sum` with padding) and one Triton implementation (the simple fused kernel).
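For reference, a hedged sketch of what the `torch.sum`-with-padding baseline could look like (not the exact benchmark code; the function name and the explicit `offsets` argument are illustrative):

```python
import torch

def jagged_mean_sum_with_padding(nt: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Mean over the jagged dim via zero-padding: pad, sum, divide by true lengths."""
    lengths = (offsets[1:] - offsets[:-1]).unsqueeze(1)  # (B, 1) true jagged lengths
    padded = torch.nested.to_padded_tensor(nt, 0.0)      # (B, MAX_SEQLEN, M); zeros past each row
    return padded.sum(dim=1) / lengths                   # padded zeros don't affect the sum
```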
This diff follows the general framework found in the `jagged_sum` operator (D59026460, D59034792).
Reviewed By: jbschlosser
Differential Revision: D59175612
fbshipit-source-id: e782eca5c4ac2c2d789988d8089dfcb169cb45ee