xla
6c7ca993 - [Distributed] Adopts traceable all_reduce (#4915)

Commit
2 years ago
[Distributed] Adopts traceable all_reduce (#4915) Summary: This pull request adopts the traceable all_reduce as mentioned in pytorch/pytorch#93173. 1. It registers an implementation to c10d_functional.all_reduce via TORCH_LIBRARY_IMPL interface. 2. It then hooks xm.all_reduce to use the torch.ops.c10d_functional.all_reduce op, which will route to the above implementation. 3. Currently it only supports a very basic usage that assumes scale == 1.0 and groups == [] and pin_layout. Test Plan: PJRT_DEVICE=TPU python test/test_mp_replication.py
Author
Parents
Loading