jax
8fe3c59c - Add explicit derivative for jax.numpy.linalg.pinv. (#2794)

Commit
5 years ago
Add explicit derivative for jax.numpy.linalg.pinv. (#2794) * Add explicit derivative for jax.numpy.linalg.pinv. * Fix type confusion problems in the JVP rule for SVD that meant it produced 64-bit tangents for 32-bit primals.
Author
Parents
Loading