Introduce an new system of logical axis names for use with pjit.
This introduces flax machinery for locally defining logical axes assignments
to parameter definitions and a version of with_sharding_constraint that works
with logical axis names.
PiperOrigin-RevId: 403524153