benchmark
f50c632f - Add torch.sum benchmark for jagged_softmax operator

Commit
1 year ago
Add torch.sum benchmark for jagged_softmax operator Summary: Add to TritonBench a `jagged_softmax` reduction operator benchmark for nested tensors using the PyTorch `torch.sum` function, [`torch.ops.aten._jagged_to_padded_dense_forward`](https://www.internalfb.com/code/fbsource/[92c2a067ab04e3eebc999254fed4ae2fbea6def3]/fbcode/deeplearning/fbgemm/fbgemm_gpu/fb/inductor_lowerings/elementwise_ops.py?lines=26), and [`torch.ops.aten._padded_dense_to_jagged_forward`](https://www.internalfb.com/code/fbsource/[16a15f9537d5a41100caaf394a398a0ab447d865]/xplat/caffe2/torch/_inductor/jagged_lowerings.py?lines=251). This diff implements two benchmarks: 1. The baseline PyTorch benchmark uses `unbind` to call `torch.softmax` on each variable-length tensor in the nested tensor. This implementation is extremely slow, resulting in very high latency for all input shapes. 2. The more efficient PyTorch benchmark leverages `torch.sum` as well as aten lowerings to pad a jagged tensor, perform operations that execute a softmax, and unpad back into a jagged tensor format. First, the benchmark pads the input tensor using [`torch.ops.aten._jagged_to_padded_dense_forward`](https://www.internalfb.com/code/fbsource/[92c2a067ab04e3eebc999254fed4ae2fbea6def3]/fbcode/deeplearning/fbgemm/fbgemm_gpu/fb/inductor_lowerings/elementwise_ops.py?lines=26). It then stabilizes the padded tensor by subtracting the maximum value in the padded tensor. Next, it performs a softmax operation by dividing the exponent of the padded tensor by the `sum` of the exponent tensor. Lastly, it unpads the padded dense tensor using [`torch.ops.aten._padded_dense_to_jagged_forward`](https://www.internalfb.com/code/fbsource/[16a15f9537d5a41100caaf394a398a0ab447d865]/xplat/caffe2/torch/_inductor/jagged_lowerings.py?lines=251), which returns the `values` tensor as a result. The efficient PyTorch benchmark in this diff avoids a GPU/CPU sync as long as I provide to [`torch.ops.aten._padded_dense_to_jagged_forward`](https://www.internalfb.com/code/fbsource/[16a15f9537d5a41100caaf394a398a0ab447d865]/xplat/caffe2/torch/_inductor/jagged_lowerings.py?lines=251) the [`total_L`](https://www.internalfb.com/code/fbsource/[4b8f2012e316d83f01a16078b334cb485d31b04c]/fbcode/caffe2/test/test_nestedtensor.py?lines=5419) parameter, which is the total length of the `values` tensor of a nested tensor. Notes - This [Stack Overflow post](https://stackoverflow.com/questions/49036993/pytorch-softmax-what-dimension-to-use) was helpful in understanding how to perform softmax along a specific dimension! - [`log_softmax`](https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html) may be faster and is worth a shot in the future - This implementation works for nested tensors where variable-length tensors have `seqlen = 0` Note on 0-seqlen nested tensors: As discussed offline, we chose to allow 0-seqlen nested tensors (a nested tensor with at least one tensor of dimension (0, M) for a nested tensor of shape (B, *, M)) to no-op cleanly while supporting offsets with duplicated values (e.g. [0, 2, 2, 3, 5]). The PyTorch softmax implementation pads 0-seqlen tensors with NaNs, which are then entirely removed from the result via the unpadding. The custom Triton implementations will perform no operations on the 0-seqlen tensors, effectively performing a no-op. Reviewed By: davidberard98 Differential Revision: D59288946 fbshipit-source-id: 92bd7ba02bdadb945a34a5948b8732692e54afe7
Author
Parents
Loading