jax
8fe3c59c
- Add explicit derivative for jax.numpy.linalg.pinv. (#2794)
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
5 years ago
Add explicit derivative for jax.numpy.linalg.pinv. (#2794) * Add explicit derivative for jax.numpy.linalg.pinv. * Fix type confusion problems in the JVP rule for SVD that meant it produced 64-bit tangents for 32-bit primals.
References
#2794 - Add explicit derivative for jax.numpy.linalg.pinv.
Author
hawkinsp
Parents
c3ab1fc5
Loading