jax
a6a8f485
- [JAX] Don't include ShardingSpecs or out_indices in the data passed to the C++ pmap() fast path.
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
2 years ago
[JAX] Don't include ShardingSpecs or out_indices in the data passed to the C++ pmap() fast path. The pmap() fast path doesn't even look the ShardingSpec or the out_indices since the jax.Sharding rework. PiperOrigin-RevId: 553206145
References
#16926 - [JAX] Don't include ShardingSpecs or out_indices in the data passed to the C++ pmap() fast path.
Author
hawkinsp
Committer
a-googler
Parents
f498442d
Loading