add jax.custom_gradient wrapper for jax.custom_vjp
There was a deprecatd version of this wrapper implemented in terms of
jax.custom_transforms (which itself is deprecated, and hopefully soon to
be removed), but this commit adds an implementation in terms of
jax.custom_vjp. One drawback it has relative to jax.custom_vjp is that
it doesn't support Python control flow in the backward-pass function.