flax
4feaadbc - [JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs.

Commit
3 years ago
[JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs. PiperOrigin-RevId: 484062717
Author
Committer
Parents
Loading