jax
e61ca913
- Implement split_axis for all_to_all
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
5 years ago
Implement split_axis for all_to_all This allows us to use `all_to_all` over a mix of vmapped and pmapped dimensions, which will be useful for `gmap`.
References
#4429 - Add support for `all_to_all` in vmap
Author
apaszke
Committer
apaszke
Parents
fa38f250
Loading