jax
16b3f00e - Register GPU/TPU lowering for pallas_call_p lazily

Commit
1 year ago
Register GPU/TPU lowering for pallas_call_p lazily Prior to this change we had to import jax.experimental.pallas.{gpu,tpu} in jax.experimental.pallas only to get the lowering rules registered. PiperOrigin-RevId: 620957622
Author
Committer
Parents
Loading