[Scaled Matmul] add a sharding rule and fix custom partition method
The sharding rule is needed when Shardy is enabled, and the fix to the custom partition method will ensure that the decisions made by propagation are respected in partitioning (e.g. reduce-scatter dimension based on output sharding).
NOTE: If the change to the custom partition method is causing issues when shardy is enabled, feel free to condition the new behavior on shardy being enabled.
PiperOrigin-RevId: 784585003