Bug fixes and generalizations of nn.partitioning api.
- We didn't previously support assigning a single logical axis to multiple mesh axes.
- We required tedious assignments of logically named axes to None, instead None (ie not-sharded on this dimension) is now a default axis assignment.
- We now allow general pytrees of logical axis annotations.
PiperOrigin-RevId: 443782364