flax
11f8151d - Introduce an new system of logical axis names for use with pjit.

Commit
4 years ago
Introduce an new system of logical axis names for use with pjit. This introduces flax machinery for locally defining logical axes assignments to parameter definitions and a version of with_sharding_constraint that works with logical axis names. PiperOrigin-RevId: 403524153
Author
Committer
Parents
Loading