flax
1ec5ef29 - Fixes spmd to work correctly with xaot compilation by using global mesh's device instead of jax.devices()[0]

Commit
344 days ago
Fixes spmd to work correctly with xaot compilation by using global mesh's device instead of jax.devices()[0] PiperOrigin-RevId: 729357183
Author
Committer
Parents
Loading