jax
a2a5068e - Changed ``pl.BlockSpec`` to accept ``block_shape`` before ``index_map``

Commit
1 year ago
Changed ``pl.BlockSpec`` to accept ``block_shape`` before ``index_map`` So, instead of pl.BlockSpec(lambda i, j: ..., (42, 24)) ``pl.BlockSpec`` now expects pl.BlockSpec((42, 24), lambda i, j: ...) I will update Pallas tests in a follow up. PiperOrigin-RevId: 648486321
Author
Committer
Parents
Loading