xla
caf51687 - [SPMD] Avoid recompilations in xs.mark_sharding() (#5300)

Commit
2 years ago
[SPMD] Avoid recompilations in xs.mark_sharding() (#5300) Summary: This pull requests fixes the recompilation issue in xs.mark_sharding(). xtensor->GetXlaData() will compile the program if xtensor is an IR in order to get the BackendData. I believe this is not intended given the error message below suggests only data type xtensors are supported. Test Plan: PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py
Author
Parents
Loading