jax
bdd65453 - Add more features to the C++ jax.jit. (#4169)

Commit
5 years ago
Add more features to the C++ jax.jit. (#4169) This mainly follows https://github.com/google/jax/pull/4089 by adding: - support for disable_jit from C++ - support for jax._cpp_jit on methods. - supporting applying @jax.jit on top-level functions, by delaying the retrieval of the device and backend. - concurrency support. I am not aware of any feature missing (but I suspect there are still some differences due to the differences between xla_computation and _xla_callable.) See: - https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see cr/328899906 + benchmarks for how numbers were generated) - The results of the Jax tests when enabling this: http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 fails, 92 passes, but many common cause of failure).
Author
Parents
Loading