jax
3c6cdcfc - add jax.custom_gradient wrapper for jax.custom_vjp

Commit
5 years ago
add jax.custom_gradient wrapper for jax.custom_vjp There was a deprecatd version of this wrapper implemented in terms of jax.custom_transforms (which itself is deprecated, and hopefully soon to be removed), but this commit adds an implementation in terms of jax.custom_vjp. One drawback it has relative to jax.custom_vjp is that it doesn't support Python control flow in the backward-pass function.
Author
Committer
Parents
Loading