11ebabb2 - Add variable-length loop Triton kernel benchmark for jagged_softmax

Summary: Add a Triton kernel benchmark implementing a variable-length loop `softmax` for the `jagged_softmax` operator. This Triton kernel performs a `softmax` operation along the ragged dimension of a nested tensor of logical shape `(B, *, M)`, where `*` is the ragged dimension.

The kernel implements `softmax` in two phases:

1. As with previous implementations of the variable-length loop kernel (e.g. D59175612), the first phase loops through the maximum sequence length `MAX_SEQLEN` and loads the input, filling any blocks outside of the ragged bounds with negative infinity (since `e^-inf = 0`). It subtracts the maximum value of the input tensor from each input block, exponentiates, and stores the exponentiated block into a buffer. After the first phase, the kernel takes the `sum` of the buffer along the ragged dimension, yielding a `sum` over all exponentiated input.
2. The second phase iterates over `MAX_SEQLEN` again to load the input, divides it by the `sum` of the buffer, and stores the result in the output. For division, I used `tl.fdiv`, which performs slightly faster than the regular division operator `/`.

As with previous jagged operators, the kernel is benchmarked against a padded PyTorch implementation and the simple fused Triton kernel, and its accuracy is verified against a baseline PyTorch `unbind` implementation.

This implementation uses the `buffer_then_sum` method: it accumulates all exponentiated input in a buffer, then takes the `sum` once. This method proved faster in previous jagged operators (like `jagged_sum` and `jagged_mean`), suggesting that storing to and loading from many registers is faster than taking multiple sums in Triton. For more information on other approaches I tried and some notes on the implementation, see D59299726.
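The two-phase scheme above can be sketched with a plain NumPy reference (not the Triton kernel itself; the jagged layout via `values`/`offsets` and the function name are illustrative assumptions):

```python
import numpy as np

def jagged_softmax_ref(values, offsets):
    """Reference softmax along the ragged dimension of a jagged tensor of
    logical shape (B, *, M), stored as a dense (total_rows, M) `values`
    array plus `offsets` of length B + 1 delimiting each sequence.
    """
    out = np.empty_like(values, dtype=np.float64)
    for b in range(len(offsets) - 1):
        start, end = offsets[b], offsets[b + 1]
        seq = values[start:end].astype(np.float64)  # (seqlen_b, M)
        # Phase 1: subtract the per-column max and exponentiate, storing
        # into a buffer. (In the kernel, blocks outside the ragged bounds
        # are filled with -inf, contributing e^-inf = 0.) Then sum the
        # buffer along the ragged dimension.
        buf = np.exp(seq - seq.max(axis=0, keepdims=True))
        denom = buf.sum(axis=0, keepdims=True)
        # Phase 2: divide by the sum and store the result. (The kernel
        # re-loads the input here; the sketch reuses the buffer.)
        out[start:end] = buf / denom
    return out
```

Each column of every ragged segment of the output sums to 1, which is what the accuracy check against the `unbind` baseline verifies per sequence.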
Reviewed By: davidberard98 Differential Revision: D59309772 fbshipit-source-id: c618f567c5ba9827cfe5763b3c82c2e304c1e6af