1e79c04b - Add simple fused Triton kernel benchmark for jagged_mean operator (#2355)

Summary: Pull Request resolved: https://github.com/pytorch/benchmark/pull/2355

Add a Triton kernel benchmark implementing a simple fused `mean` for the `jagged_mean` operator. The Triton kernels perform a `mean` along the ragged dimension of a nested tensor of logical dimensions `(B, *, M)`, where `*` is the ragged dimension. They load blocks of the values tensor along its last dimension `M`, reduce each variable-length block along its first dimension `*`, and store each of the `B` reductions in an output tensor of shape `(B, M)`.

The first kernel, `sum_then_buffer`, performs a `sum` on each block of input, then accumulates the result into a buffer. The second kernel, `buffer_then_sum`, is a faster implementation that accumulates blocks into a buffer, then performs a cumulative `sum`.

This diff is particularly useful for emulating the loop in Inductor-generated (`torch.compile`) kernels and serves as a benchmark proxy for Inductor kernels. Use the command-line argument `sum_then_buffer`, which defaults to `0` (since `buffer_then_sum` is faster, as shown below), to decide which Triton kernel to benchmark.

These Triton kernels are benchmarked against two PyTorch implementations, one of which uses `torch.mean`, and the other `torch.div`, `torch.sum`, and `shape`. This diff follows the general framework found in the `jagged_sum` operator (D58549297, D59034792).

Reviewed By: davidberard98

Differential Revision: D59146627

fbshipit-source-id: 3664b1b861bc368c64c8b34e4301e18da38c9a15
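The blockwise jagged-mean reduction described above can be sketched in NumPy to show the difference between the two accumulation orders. This is a minimal illustrative sketch, not the actual Triton code: the function name `jagged_mean`, the `(values, offsets)` layout for the nested tensor, and the `block_size` parameter are assumptions made here for clarity.

```python
import numpy as np

def jagged_mean(values, offsets, block_size=4, sum_then_buffer=True):
    """Mean along the ragged dimension of a nested tensor (B, *, M).

    values:  (total_rows, M) flattened values of all batches
    offsets: (B+1,) row boundaries; batch b spans offsets[b]:offsets[b+1]
    """
    B, M = len(offsets) - 1, values.shape[1]
    out = np.zeros((B, M), dtype=values.dtype)
    for b in range(B):
        start, end = int(offsets[b]), int(offsets[b + 1])
        if sum_then_buffer:
            # Sum each block first, then accumulate into a scalar-per-column buffer.
            acc = np.zeros(M, dtype=values.dtype)
            for blk in range(start, end, block_size):
                acc += values[blk:min(blk + block_size, end)].sum(axis=0)
        else:
            # Accumulate raw blocks into a (block_size, M) buffer, then sum once.
            buf = np.zeros((block_size, M), dtype=values.dtype)
            for blk in range(start, end, block_size):
                chunk = values[blk:min(blk + block_size, end)]
                buf[:len(chunk)] += chunk
            acc = buf.sum(axis=0)
        # Guard against empty ragged segments (length 0).
        out[b] = acc / max(end - start, 1)
    return out
```

Both orderings produce the same `(B, M)` result; in the Triton kernels the choice matters because it changes how the reduction loop is scheduled, which is why the two variants benchmark differently.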