Add more features to the C++ jax.jit. (#4169)
This mainly follows https://github.com/google/jax/pull/4089 by adding:
- support for disable_jit from C++
- support for jax._cpp_jit on methods.
- supporting applying @jax.jit on top-level functions, by delaying the retrieval of the device and backend.
- concurrency support.
I am not aware of any feature missing (but I suspect there are still some differences due to the differences between xla_computation and _xla_callable.)
See:
- https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see
cr/328899906 + benchmarks for how numbers were generated)
- The results of the Jax tests when enabling this:
http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 fails, 92 passes, but many common cause of failure).