jax
802a14cd
- Re-pack gradients of jax.experimental.sparse.grad() to match original pytrees & test cases
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
1 year ago
Re-pack gradients of jax.experimental.sparse.grad() to match original pytrees & test cases
References
#19760 - Update sparse.grad() to support re-packing gradients into PyTrees
Author
Blair-Johnson
Committer
Blair-Johnson
Parents
85e83b50
Loading