jax
b23c4237 - [sharding_in_types] If an indexing operation hits into `gather_p`, error out saying to use `.at[...].get(out_spec=...)` instead.

Commit
1 year ago
[sharding_in_types] If an indexing operation hits into `gather_p`, error out saying to use `.at[...].get(out_spec=...)` instead. This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`. Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 716295953
Author
Parents
Loading