Speed up JVP for triangular solve (#1466)
* Speed up JVP for triangular solve
There is still room for improvement, e.g., by combining the currently separate
JVPs for a and b into a single expression (which would save an inner
triangular solve when both arguments are being differentiated), but this is
already significantly faster in the typical case of solving for a single
right-hand-side vector.
On my laptop's CPU, I measure 2.98 ms before vs 1.18 ms after on
a 500x500 matrix:
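import numpy as onp
import jax
import jax.numpy as np
import jax.scipy.linalg

rs = onp.random.RandomState(0)
a = rs.randn(500, 500)
b = rs.randn(500)

@jax.jit
def loss(a, b):
    return np.sum(jax.scipy.linalg.solve_triangular(a, b))

grad = jax.jit(jax.grad(loss))
%timeit jax.device_get(grad(a, b))  # IPython magic; run in IPython/Jupyter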
* comment
* Optimal order for left_side=False, too
* Test the JVP for lax_linalg.triangular_solve directly
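
For readers who want to see why the order of operations matters, here is a minimal sketch, not the code from this change. It assumes the standard identity for the tangent of x = a^{-1} b with respect to a, namely dx = -a^{-1} (da x), and compares it against what `jax.jvp` returns. The 5x5 size, the diagonal shift, and the names `da`, `dx_cheap`, and `dx_costly` are illustrative assumptions. It goes through the public `jax.scipy.linalg.solve_triangular` wrapper rather than `lax_linalg.triangular_solve`.

```python
import numpy as onp
import jax
import jax.numpy as np
from jax.scipy.linalg import solve_triangular

rs = onp.random.RandomState(0)
# Shift the diagonal so the upper-triangular system is well conditioned
# (an assumption made only to keep the float32 comparison stable).
a = np.asarray(onp.triu(rs.randn(5, 5)) + 5.0 * onp.eye(5), dtype=np.float32)
b = np.asarray(rs.randn(5), dtype=np.float32)
# Tangent direction for a, restricted to the triangle that is actually read.
da = np.asarray(onp.triu(rs.randn(5, 5)), dtype=np.float32)

# Tangent of x = solve(a, b) with respect to a, as computed by JAX.
x, dx = jax.jvp(lambda m: solve_triangular(m, b), (a,), (da,))

# Same tangent written out by hand. Multiplying da @ x first means the
# triangular solve sees only a vector on its right-hand side.
dx_cheap = -solve_triangular(a, da @ x)

# Mathematically equal, but the solve now runs against the full matrix da.
dx_costly = -solve_triangular(a, da) @ x
```

Both hand-written forms should agree with `dx` up to floating-point error. The first needs a matrix-vector product plus one solve against a vector, while the second solves against all columns of `da`, which is the kind of difference that shows up when `b` is a single vector.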