flax
fcc2e0e5 - Create sharding via Partitioned.get_sharding()

Commit
1 year ago
Create sharding via Partitioned.get_sharding() This change modifies the global get_sharding() function to call into Partitioned.get_sharding(). This allows subclasses of Partitioned to override the way sharding is created. PiperOrigin-RevId: 705204218
Author
hhb hhb
Committer
Parents
Loading