jax
9a9fc023 - Experimental pallas-triton lowering of ragged_dot_general on GPU.

Commit
8 days ago
Experimental pallas-triton lowering of ragged_dot_general on GPU. This is an experiment in practicability of lowering lax ops to pallas. In this case the TPU lowering of [ragged dot general](https://docs.jax.dev/en/latest/_autosummary/jax.lax.ragged_dot_general.html#jax.lax.ragged_dot_general) in XLA is good, but the GPU lowering currently pads the ragged lhs to the worst case and performs a masked matmul, which can be highly inefficient. This PR adds a pallas-triton kernel lowering for the op on GPU which should be more efficient in all cases. The next steps in this exploration are: * exposing `lowering_options` to allow user tuning * testing the real-world performance of this GPU ragged dot general kernel PiperOrigin-RevId: 895423037
Author
Parents
Loading