jax
49561f67 - [pmap] Remove any wrappings that have auxiliary values.

Commit
11 days ago
[pmap] Remove any wrappings that have auxiliary values. We want to avoid StoreException or StoreEmpty errors. Previous code added another transformation to reset stores. New code does away with this and also gives a chance to remove some unnecessary flatten/unflattens. Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`. PiperOrigin-RevId: 861864367
Author
Parents
Loading