jax
5eac4772 - [jax2tf] Implementation of random_gamma (#4192)

Commit
5 years ago
[jax2tf] Implementation of random_gamma (#4192) * [jax2tf] implementation of random_gamma The simplest implementation is by converting the JAX own impl_rule, which rewrites gamma into other JAX primitives. On TPU with use_vmap=True the performance is the same for JAX and TF, provided we use tf.function(compile=True).
Author
Parents
Loading