jax
5981df7b
- Removed unnecessary jax.tree.map calls from *_callback_impl functions
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
1 year ago
Removed unnecessary jax.tree.map calls from *_callback_impl functions jax.device_put works for any PyTree. PiperOrigin-RevId: 626510762
References
#20838 - Removed unnecessary jax.tree.map calls from *_callback_impl functions
Author
superbobry
Committer
a-googler
Parents
52f5f703
Loading