jax
a6a8f485 - [JAX] Don't include ShardingSpecs or out_indices in the data passed to the C++ pmap() fast path.

Commit
2 years ago
[JAX] Don't include ShardingSpecs or out_indices in the data passed to the C++ pmap() fast path. The pmap() fast path doesn't even look the ShardingSpec or the out_indices since the jax.Sharding rework. PiperOrigin-RevId: 553206145
Author
Committer
Parents
Loading