[Distributed] Adopts traceable all_reduce (#4915)
Summary:
This pull request adopts the traceable all_reduce as mentioned in pytorch/pytorch#93173.
1. It registers an implementation for c10d_functional.all_reduce via the TORCH_LIBRARY_IMPL interface.
2. It then hooks xm.all_reduce to call the torch.ops.c10d_functional.all_reduce op, which routes to the implementation registered in step 1.
3. Currently it supports only the basic case where scale == 1.0, groups == [], and pin_layout is true.
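
The routing condition in steps 2 and 3 can be sketched in plain Python as below. This is a minimal illustration, not the code in this change: the helper names `uses_traceable_path` and `route_all_reduce` are hypothetical, the two callables are stand-ins for torch.ops.c10d_functional.all_reduce and for the pre-existing xm.all_reduce path, and falling back to the pre-existing path for unsupported arguments is an assumption.

```python
from typing import Any, Callable, Sequence


def uses_traceable_path(scale: float, groups: Sequence[Sequence[int]],
                        pin_layout: bool) -> bool:
    """Hypothetical helper: True only for the basic case described above."""
    return scale == 1.0 and len(groups) == 0 and bool(pin_layout)


def route_all_reduce(inputs: Any, scale: float, groups: Sequence[Sequence[int]],
                     pin_layout: bool,
                     traceable_op: Callable[[Any], Any],
                     legacy_op: Callable[[Any], Any]) -> Any:
    """Hypothetical dispatcher: pick the traceable op when the arguments allow it.

    traceable_op stands in for torch.ops.c10d_functional.all_reduce;
    legacy_op stands in for the pre-existing implementation (assumed fallback).
    """
    if uses_traceable_path(scale, groups, pin_layout):
        return traceable_op(inputs)
    return legacy_op(inputs)


# Stand-in callables that only tag which path was taken.
def fake_traceable(x):
    return ("traceable", x)


def fake_legacy(x):
    return ("legacy", x)


basic = route_all_reduce([1, 2], 1.0, [], True, fake_traceable, fake_legacy)
scaled = route_all_reduce([1, 2], 0.5, [], True, fake_traceable, fake_legacy)
grouped = route_all_reduce([1, 2], 1.0, [[0, 1]], True, fake_traceable, fake_legacy)
```

With these stand-ins, only the first call takes the traceable path. The scaled and grouped calls take the assumed fallback path.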
Test Plan:
PJRT_DEVICE=TPU python test/test_mp_replication.py