jax
23563247 - Implement explicit mode sharding rule for scatter

Commit
269 days ago
Implement explicit mode sharding rule for scatter Associated changes: - Check on updates batching dims added to scatter_shape_rule - Sharding rule for scatter and scatter_add/mul/max... added - resolve_mesh added as function to extract mesh (if uniquely defined) from a group of meshes PiperOrigin-RevId: 786300765
Parents
Loading