jax
51e27923 - Simplify pjit's batching rule now that xmap is deleted. Also do cleanup around adding manual axes under shard_map

Commit
1 year ago
Simplify pjit's batching rule now that xmap is deleted. Also do cleanup around adding manual axes under shard_map PiperOrigin-RevId: 655776234
Author
Committer
Parents
Loading