jax
64775d02 - Async dispatch expensive computations on the JAX CPU backend.

Commit
1 year ago
Async dispatch expensive computations on the JAX CPU backend. Before the change, on CPU backend we always run computations inline unless there are multiple CPU devices with potential collectives. Now, we will use `HloCostAnalysis` to estimate the cost of the computation and do async dispatch if it is expensive. Add a JAX flag for users to opt-out by adding `jax.config.update('jax_cpu_enable_async_dispatch', False)` in their programs. PiperOrigin-RevId: 625064815
Author
Committer
Parents
Loading