xla
c3001d64 - [Backport] Introduce jax_import_guard (#6794)

Commit
1 year ago
[Backport] Introduce jax_import_guard (#6794) Summary: Importing JAX will lock the TPU devices and prevent any pytorch/xla's TPU computations. To address it, we need to acquire the TPU first. Test Plan: python test/test_pallas.py
Author
Parents
Loading