jax
9e43153e - [Scaled Matmul] add a sharding rule and fix custom partition method

Commit
215 days ago
[Scaled Matmul] add a sharding rule and fix custom partition method The sharding rule is needed when Shardy is enabled, and the fix to the custom partition method will ensure that the decisions made by propagation are respected in partitioning (e.g. reduce-scatter dimension based on output sharding). NOTE: If the change to the custom partition method is causing issues when shardy is enabled, feel free to condition the new behavior on shardy being enabled. PiperOrigin-RevId: 784585003
Author
Parents
Loading