cf24fa11 - Add variable-length loop Triton kernel benchmark for jagged_mean operator (#2356)

Summary:
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2356

Add a Triton kernel benchmark implementing a fast, variable-length loop on top of the simple fused `mean` kernels in D59146627. This diff loops from the beginning (`ragged_start`) to the end (`ragged_end`) of each variable-length subsection of the input nested tensor, which eliminates the extra work the simple fused kernel does by looping over the entire range of the maximum sequence length, `MAX_SEQLEN`. Specifically, the loop now terminates at the nested tensor's jagged length instead of reading, reducing, and writing extraneous zeros beyond it.

This diff also contains implementations of both `sum_then_buffer` and `buffer_then_sum`, as in the simple fused kernels in D59146627, and draws from a similar binary search implementation found [here](https://www.internalfb.com/code/fbsource/[d33668c8c8fe3cc7d75beb58b4b0dc51dc6e96a1][diffs]/fbcode/caffe2/torch/_inductor/runtime/triton_helpers.py?lines=195).

These Triton kernels are benchmarked against three PyTorch implementations (one using `torch.mean`, another `torch.nanmean` with padding, and the last `torch.sum` with padding) and one Triton implementation (simple fused). This diff follows the general framework of the `jagged_sum` operator (D59026460, D59034792).

Reviewed By: jbschlosser

Differential Revision: D59175612

fbshipit-source-id: e782eca5c4ac2c2d789988d8089dfcb169cb45ee
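For illustration, here is a minimal sketch of the variable-length-loop technique the commit describes, written against the public Triton API. This is not the committed kernel: the kernel name `jagged_mean_varlen_kernel`, the launch helper, the `(values, offsets)` layout, and the block sizes are all assumptions for this sketch; the real kernels live in `kernels.py`.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def jagged_mean_varlen_kernel(
    values_ptr,   # float values, flattened to (total_rows, M), contiguous
    offsets_ptr,  # int offsets of shape (B + 1,); row b spans offsets[b]:offsets[b+1]
    out_ptr,      # output means of shape (B, M)
    M,            # size of the trailing dense dimension
    BLOCK_RAGGED: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_b = tl.program_id(0)  # one program per jagged row
    pid_m = tl.program_id(1)  # one program per block of dense columns

    ragged_start = tl.load(offsets_ptr + pid_b)
    ragged_end = tl.load(offsets_ptr + pid_b + 1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    mask_m = offs_m < M

    # "sum_then_buffer": reduce each loaded block into a register
    # accumulator immediately, then write the buffer once at the end.
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Variable-length loop: iterate only over [ragged_start, ragged_end),
    # never over the padded range up to MAX_SEQLEN.
    for block_start in range(ragged_start, ragged_end, BLOCK_RAGGED):
        offs_r = block_start + tl.arange(0, BLOCK_RAGGED)
        mask_r = offs_r < ragged_end
        idx = offs_r[:, None] * M + offs_m[None, :]
        vals = tl.load(
            values_ptr + idx,
            mask=mask_r[:, None] & mask_m[None, :],
            other=0.0,
        )
        acc += tl.sum(vals, axis=0)

    # Guard against empty jagged rows to avoid a division by zero.
    length = tl.maximum(ragged_end - ragged_start, 1)
    tl.store(out_ptr + pid_b * M + offs_m, acc / length.to(tl.float32), mask=mask_m)


def jagged_mean(values: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Launch helper: values is (total_rows, M) float32, offsets is (B + 1,) int."""
    B, M = offsets.numel() - 1, values.shape[1]
    out = torch.empty((B, M), device=values.device, dtype=torch.float32)
    grid = (B, triton.cdiv(M, 64))
    jagged_mean_varlen_kernel[grid](values, offsets, out, M,
                                    BLOCK_RAGGED=32, BLOCK_M=64)
    return out
```

As I read the naming, the `buffer_then_sum` variant differs only in the reduction order: it accumulates whole `(BLOCK_RAGGED, BLOCK_M)` blocks into a 2D buffer across loop iterations and performs a single reduction after the loop, rather than reducing each block as it is loaded.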
Files changed
  • torchbenchmark/operators/jagged_mean/kernels.py
  • torchbenchmark/operators/jagged_mean/operator.py
  • torchbenchmark/util/jagged_utils.py
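For context, a rough sketch of the padding-based eager baselines mentioned in the summary (the exact implementations live in `operator.py`; the function names here are hypothetical, and a nested tensor input with jagged dimension 1 is assumed):

```python
import torch


def mean_via_nanmean_padding(nt: torch.Tensor) -> torch.Tensor:
    # Pad ragged rows with NaN so nanmean ignores the padding entries.
    padded = torch.nested.to_padded_tensor(nt, padding=float("nan"))
    return padded.nanmean(dim=1)


def mean_via_sum_padding(nt: torch.Tensor) -> torch.Tensor:
    # Pad with zeros, then divide each row's sum by its true jagged length.
    padded = torch.nested.to_padded_tensor(nt, padding=0.0)
    lengths = torch.tensor(
        [t.shape[0] for t in nt.unbind()],
        device=padded.device, dtype=padded.dtype,
    )
    return padded.sum(dim=1) / lengths.unsqueeze(-1)
```

Both baselines still pay the cost the variable-length kernel avoids: they materialize and traverse every row out to `MAX_SEQLEN` before discounting the padding.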