jax
5981df7b - Removed unnecessary jax.tree.map calls from *_callback_impl functions

Commit
1 year ago
Removed unnecessary jax.tree.map calls from *_callback_impl functions jax.device_put works for any PyTree. PiperOrigin-RevId: 626510762
Author
Committer
Parents
Loading