jax
f81f2589 - [Pallas TPU] Support memory space constraints on pallas_call inputs.

Commit
316 days ago
[Pallas TPU] Support memory space constraints on pallas_call inputs. This CL adds: - a new function `with_memory_space_constraint` that allow to add a memory space constraint on a Pallas call input. - `HBM` enum value for Pallas TPU memory spaces This CL only supports HBM and VMEM at the moment. In general these annotations only work on TPU because the mechanism that enforces them is only present in XLA TPU. PiperOrigin-RevId: 770694413
Author
Parents
Loading