jax
80d1fbac - Handle `sharding` param in convert_element_type's batching rule properly by adding the explicit mesh axis on dim 0

Commit
1 year ago
Handle `sharding` param in convert_element_type's batching rule properly by adding the explicit mesh axis on dim 0 PiperOrigin-RevId: 749125322
Author
Parents
Loading