flax
942d6dc2 - Bug fixes and generalizations of nn.partitioning api.

Commit
4 years ago
Bug fixes and generalizations of nn.partitioning api. - We didn't previously support assigning a single logical axis to multiple mesh axes. - We required tedious assignments of logically named axes to None, instead None (ie not-sharded on this dimension) is now a default axis assignment. - We now allow general pytrees of logical axis annotations. PiperOrigin-RevId: 443782364
Author
Committer
Parents
Loading