jax
6e671774
- [pmap] Don't check devices on every invocation of a pmapped function.
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
28 days ago
[pmap] Don't check devices on every invocation of a pmapped function. Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`. PiperOrigin-RevId: 860926820
Author
danielsuo
Committer
Google-ML-Automation
Parents
31c17496
Loading