jax
e61ca913 - Implement split_axis for all_to_all

Commit
5 years ago
Implement split_axis for all_to_all This allows us to use `all_to_all` over a mix of vmapped and pmapped dimensions, which will be useful for `gmap`.
Author
Committer
Parents
Loading