jax
7f7fd997 - [jax2tf] Add support for first-order AD to converted functions (#3593)

Commit
5 years ago
[jax2tf] Add support for first-order AD to converted functions (#3593) * [jax2tf] Add support for first-order AD to converted functions On conversion, optionally, we convert the jax.vjp of the converted function and set the result as tf.custom_gradient. * Only first-order diff is supported for now. * Add more tests, including for round-tripping through SavedModel. * Minor whitespace fix
Author
Parents
Loading