jax
0d5f15f5 - Fix the abstract eval and translation rule for all_to_all

Commit
5 years ago
Fix the abstract eval and translation rule for all_to_all The previous rules assumed that `split_axis == concat_axis` (i.e. that the used collective is equivalent to `pswapaxes`). Since we expose this as part of our API, we should probably make sure that we handle other cases too. Fixes #1332.
Author
Committer
Parents
Loading