fix(xla): convert group-local to global ranks in broadcast (#9657)
Related AWS Neuron ticket:
https://t.corp.amazon.com/V1941917988/overview
`broadcast` was passing group-local ranks directly to `xm.collective_broadcast()`, which expects global ranks, causing data corruption in single-member process groups.
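The essence of the fix is converting the group-local root rank to a global rank before invoking the collective. A minimal sketch of that pattern (illustrative only, not the actual backend change; `_broadcast_with_global_root` and `replica_groups` are hypothetical placeholders):
```python
import torch.distributed as dist
import torch_xla.core.xla_model as xm


def _broadcast_with_global_root(tensor, group_local_root, group, replica_groups):
    # Map the group-local root rank to the global rank that
    # xm.collective_broadcast() addresses participants by.
    global_root = dist.get_global_rank(group, group_local_root)
    # `replica_groups` stands in for whatever replica-group spec the backend
    # already builds for this process group; only the root conversion changes.
    xm.collective_broadcast([tensor],
                            root_ordinal=global_root,
                            groups=replica_groups)
```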
TEST:
```
import os
import torch
import torch.distributed as dist
import torch_xla as xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr
def main():
    dist.init_process_group(backend="xla")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Each rank creates its own single-member process group.
    tp = dist.new_group(ranks=[rank])
    tp_rank = dist.get_rank(group=tp)
    tp_size = dist.get_world_size(group=tp)
    print(
        f">>>> pid={os.getpid()}, rank={rank}\n"
        f">>> world_size={world_size}\n"
        f">>> tp_rank={tp_rank}, tp_size={tp_size}, tp_members={dist.get_process_group_ranks(tp)}"
    )

    do_train, do_valid, do_test = 0.1, 0.2, 0.3
    # breakpoint()
    flags = torch.tensor([do_train, do_valid, do_test], dtype=torch.float32, device='xla')
    # breakpoint()

    # `rank` is the global source rank; within the single-member group its
    # group-local rank is always 0.
    dist.broadcast(flags, rank, group=tp)
    print(f">>>> pid={os.getpid()}, rank={rank}\n"
          f">>> do_train={flags[0].item()}, do_valid={flags[1].item()}, do_test={flags[2].item()}\n"
          f">>> global_ordinal={xr.global_ordinal()}")


if __name__ == "__main__":
    main()
```
Results after this fix:
```
torchrun --nproc-per-node=2 --nnodes=1 ./bug.py
W0926 18:50:41.903000 1081605 torch/distributed/run.py:766]
W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] *****************************************
W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] *****************************************
>>>> pid=1081679, rank=0
>>> world_size=2
>>> tp_rank=0, tp_size=1, tp_members=[0]
>>>> pid=1081680, rank=1
>>> world_size=2
>>> tp_rank=0, tp_size=1, tp_members=[1]
.
.
.
2.19.8089.0+8ab9f450/MODULE_10344927339446294134+e30acd3a/model.neff
>>>> pid=1081680, rank=1
>>> do_train=0.10000000149011612, do_valid=0.20000000298023224, do_test=0.30000001192092896
>>> global_ordinal=1
>>>> pid=1081679, rank=0
>>> do_train=0.10000000149011612, do_valid=0.20000000298023224, do_test=0.30000001192092896
```
Both ranks now receive the correct values. Previously rank 1 got all zeros because its group-local root rank (0) was treated as a global rank, which points at a process outside its single-member group [1].