jax
5f12495b - Add `scratch_shapes` to Pallas `pl.core_map`.

Commit
13 days ago
Add `scratch_shapes` to Pallas `pl.core_map`. Also removed the `pl.get_global` for SparseCore because the use case is covered by `scratch_shapes`. Also refactored `pl.kernel` to use core_map with scratch_shapes. Also used a workaround on SC SCS due to existing memory space allocation bug. PiperOrigin-RevId: 868830705
Author
Parents
Loading