xla
7cf9f10a - [Pallas] Introduce jax_import_guard (#6747)

Commit
1 year ago
[Pallas] Introduce jax_import_guard (#6747) Summary: Importing JAX will lock the TPU devices and prevent any pytorxh/xla's TPU computations. To address it, we need to acquire the TPU first. Test Plan: python test/test_pallas.py
Author
Parents
Loading