jax
c8247786 - Speedup JVP for triangular solve (#1466)

Commit
6 years ago
Speedup JVP for triangular solve (#1466) * Speedup JVP for triangular solve There is still room for improvement, e.g., by combining the currently separate JVPs for a and b into a single expression (which will allow for saving an inner triangular solve when both arguments are being differentiated), but this is already significantly faster in the typical case of only solving a single vector. On my laptop's CPU, I measure 2.98 ms before vs 1.18 ms after on a 500x500 matrix: rs = onp.random.RandomState(0) a = rs.randn(500, 500) b = rs.randn(500) @jax.jit def loss(a, b): return np.sum(jax.scipy.linalg.solve_triangular(a, b)) grad = jax.jit(jax.grad(loss)) %timeit jax.device_get(grad(a, b)) * comment * Optimal order for left_side=False, too * Test the JVP for lax_linalg.triangular_solve directly
Author
Parents
Loading