jax
88cd785d - [Pallas TPU] Refactor Pallas Device ID Logic to use JAX lowering

Commit
42 days ago
[Pallas TPU] Refactor Pallas Device ID Logic to use JAX lowering This change modifies `_device_id_to_logical` within the Pallas/Mosaic lowering codebase. Previously, it operated directly on MLIR values. The logic is now implemented using native JAX operations (like `jax.lax.div/rem`) operating on JAX arrays. PiperOrigin-RevId: 889532032
Author
Parents
Loading