jax
f02d5b46 - Support differentiation through jax.lax.all_to_all (#3733)

Commit
5 years ago
Support differentiation through jax.lax.all_to_all (#3733) * Support differentiation through jax.lax.all_to_all Credit to @levskaya for the solution. * Test gradient of all_to_all We are testing all_to_all through pswapaxes, since general all_to_all is problematic according to https://github.com/google/jax/issues/1332. * Removed trailing spaces
Author
Parents
Loading