flax
5012a650 - Introduce an new system of logical axis names for use with pjit.

Commit
4 years ago
Introduce an new system of logical axis names for use with pjit. This introduces flax machinery for locally defining logical axes assignments to parameter definitions and a version of with_sharding_constraint that works with logical axis names. PiperOrigin-RevId: 402700753
References
Author
Committer
Parents
Loading