jax
c4ac0dd6 - Implement the extension to the custom_partitioning API.

Commit
1 year ago
Implement the extension to the custom_partitioning API. Add a sharding rule string and trailing factor_sizes to def_partition, to provide a sharding rule specification when Shardy is used. We use this information to construct a SdyShardingRule and invoke SdyShardingRule.build during MLIR lowering. Extend custom_partitioner tests in pjit_test.py for Shardy sharding rule. PiperOrigin-RevId: 713399604
Author
Parents
Loading