jax
6e671774 - [pmap] Don't check devices on every invocation of a pmapped function.

Commit
28 days ago
[pmap] Don't check devices on every invocation of a pmapped function. Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`. PiperOrigin-RevId: 860926820
Author
Parents
Loading