[jax2tf] Add support for first-order AD to converted functions (#3593)
* [jax2tf] Add support for first-order AD to converted functions
On conversion, optionally, we convert the jax.vjp of the converted
function and set the result as tf.custom_gradient.
* Only first-order diff is supported for now.
* Add more tests, including for round-tripping through SavedModel.
* Minor whitespace fix