c75d1a51 - Add simple fused Triton kernel benchmark for jagged_softmax

Summary: Add a Triton kernel benchmark implementing a simple fused `softmax` for the `jagged_softmax` operator. The Triton kernel performs a `softmax` along the ragged dimension of a nested tensor of logical dimensions `(B, *, M)`, where `*` is the ragged dimension.

The kernel implements `softmax` in four phases:

1. Compute the maximum values along the ragged dimension of the input, using a buffer that the loop updates: for each `(BLOCK_SIZE_RAGGED, BLOCK_SIZE_M)` block of input, find the maximum values along the ragged dimension and update `buffer_max` where necessary.
2. Looping over the same blocks of input, accumulate numerically stable exponentials into a new buffer: subtract the ragged-dimension maximum from each input block, exponentiate the shifted values, and add them to the buffer.
3. Compute the `sum` of the exponentiated buffer, which is the denominator of the `softmax` function.
4. Divide each exponentiated input block, the numerator of the `softmax` function, by that `sum`.

As with previous jagged operators, the kernel is benchmarked against a padded PyTorch implementation, and its accuracy is verified against a baseline PyTorch `unbind` implementation.

This implementation uses the `buffer_then_sum` method: it accumulates all exponentiated input into a buffer, then takes the sum once. This method proved faster in previous jagged operators (such as `jagged_sum` and `jagged_mean`), suggesting that storing to and loading from many registers is cheaper than performing multiple `sum`s in Triton.

Other approaches I tried:
- To minimize repeated work, specifically in the calculation of the exponentiated stable input, I cached (`tl.store`) the result of that expression and loaded (`tl.load`) it back in the second phase (the iterative loop).
However, latency nearly doubled with this implementation, implying that a single `tl.store` operation is more expensive than a single `tl.exp` operation.
- I also tried `tl.div_rn` to divide the exponentiated input by the buffer sum, which roughly tripled or quadrupled latency.

Notes:
- The kernel currently has three `for` loops, which may be less efficient. This is probably unavoidable if we want to keep the buffer implementation for collecting the exponentiated `sum`. As mentioned above, I did try caching, which was slower than simply recomputing the required values.
- This implementation works for nested tensors whose variable-length tensors have `seqlen = 0`.
- A previous version of this kernel took the maximum value of the input from the `operator`, which neither measured the time the kernel spends computing the maximum nor computed the maximum value accurately. This version includes that computation in the kernel (the [`softmax` source code](https://www.internalfb.com/code/fbsource/[6ebd07f65500ddfb0a2599e96dc33d91c3d88bf0]/fbcode/caffe2/torch/_refs/__init__.py?lines=3890) helped with understanding this!).

Reviewed By: davidberard98

Differential Revision: D59299726

fbshipit-source-id: 461d94d29388519bd5aa73c9327b3f5ffd3d007f
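The four phases described above can be sketched as a host-side NumPy model (a hypothetical reference, not the Triton kernel itself; the `values`/`offsets` jagged representation and the name `jagged_softmax_ref` are illustrative assumptions):

```python
import numpy as np

def jagged_softmax_ref(values: np.ndarray, offsets: np.ndarray) -> np.ndarray:
    """Softmax along the ragged dimension of a jagged (B, *, M) tensor.

    `values` stacks all rows: shape (total_rows, M).
    `offsets` has shape (B + 1); batch b owns rows offsets[b]:offsets[b+1].
    Mirrors the kernel's four phases per batch (illustrative model only).
    """
    out = np.empty_like(values, dtype=np.float64)
    for b in range(len(offsets) - 1):
        start, end = offsets[b], offsets[b + 1]
        if start == end:  # seqlen == 0: nothing to write for this batch
            continue
        block = values[start:end].astype(np.float64)  # (seqlen, M)
        # Phase 1: maximum along the ragged dimension.
        buffer_max = block.max(axis=0, keepdims=True)
        # Phase 2: accumulate numerically stable exponentials.
        exp_buffer = np.exp(block - buffer_max)
        # Phase 3: sum of the exponentiated buffer -> softmax denominator.
        denom = exp_buffer.sum(axis=0, keepdims=True)
        # Phase 4: divide the numerator by the denominator.
        out[start:end] = exp_buffer / denom
    return out
```

The model also reflects the `seqlen = 0` case: empty batches own no rows, so the loop simply skips them.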
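The padded baseline mentioned above can likewise be modeled (hypothetically, in NumPy rather than PyTorch): pad each batch to the maximum sequence length with `-inf`, so padding contributes `exp(-inf) = 0` to the denominator, then run a dense softmax and gather the valid rows back. The helper name and layout are assumptions for illustration; non-empty batches are assumed to avoid `-inf - -inf` in the max subtraction.

```python
import numpy as np

def jagged_softmax_padded(values: np.ndarray, offsets: np.ndarray) -> np.ndarray:
    """Padded-dense model of the baseline: pad to (B, max_len, M) with -inf,
    softmax over the padded ragged axis, then slice the valid rows back out."""
    B = len(offsets) - 1
    M = values.shape[1]
    lens = np.diff(offsets)
    padded = np.full((B, int(lens.max()), M), -np.inf)
    for b in range(B):
        padded[b, : lens[b]] = values[offsets[b] : offsets[b + 1]]
    # Dense, numerically stable softmax along the padded ragged axis.
    shifted = padded - padded.max(axis=1, keepdims=True)
    exp = np.exp(shifted)  # padding rows become exp(-inf) = 0
    dense = exp / exp.sum(axis=1, keepdims=True)
    # Gather only the valid rows back into jagged layout.
    return np.concatenate([dense[b, : lens[b]] for b in range(B)], axis=0)
```

In the benchmark this padded path is the performance comparison point; accuracy in the commit is checked against the `unbind` baseline instead, which this sketch does not model.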