jax
692d820f - [pmap] Inline _prepare_pmap and clean up unused structs.

Commit
93 days ago
[pmap] Inline _prepare_pmap and clean up unused structs. Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`. PiperOrigin-RevId: 861777664
Author
Parents
Loading