Support differentiation through jax.lax.all_to_all (#3733)
* Support differentiation through jax.lax.all_to_all
Credit to @levskaya for the solution.
* Test gradient of all_to_all
We are testing all_to_all through pswapaxes, since general all_to_all is problematic according to https://github.com/google/jax/issues/1332.
* Removed trailing spaces