jax
c51c17c1 - Adds jax_ragged_dot_use_ragged_dot_instruction config/flag (only applicable to TPUs) that forces ragged_dot() lowering to use chlo.ragged_dot instruction.

Commit
230 days ago
Adds jax_ragged_dot_use_ragged_dot_instruction config/flag (only applicable to TPUs) that forces ragged_dot() lowering to use chlo.ragged_dot instruction. The flag will be deprecated once we fully roll out chlo.ragged_dot instruction. PiperOrigin-RevId: 778542373
Parents
Loading