flax
4feaadbc
- [JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs.
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
3 years ago
[JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs. PiperOrigin-RevId: 484062717
Author
hawkinsp
Committer
a-googler
Parents
3747c877
Loading