jax
a628f21c - [pallas] rewrite pull_block_spec rule to work in terms of block_index transforms

Commit
47 days ago
[pallas] rewrite pull_block_spec rule to work in terms of block_index transforms Instead of modifying block spec index_maps the pull rules are written in terms of post-evaluation block index transforms (applied after original index_map is evaluated, so we needn't consider grids or kernel scalar prefetch variables when pulling block specs). PiperOrigin-RevId: 885206213
Author
Parents
Loading