xla
13485454 - fix(xla): convert group-local to global ranks in broadcast (#9657)

Commit
137 days ago
fix(xla): convert group-local to global ranks in broadcast (#9657) Related AWS Neuron ticket: https://t.corp.amazon.com/V1941917988/overview broadcast was passing group-local ranks directly to xm.collective_broadcast() which expects global ranks, causing data curroption in single-member process groups TEST: ``` import os import torch import torch.distributed as dist import torch_xla as xla import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.runtime as xr def main(): dist.init_process_group(backend="xla") rank = dist.get_rank() world_size = dist.get_world_size() tp = dist.new_group(ranks=[rank]) tp_rank = dist.get_rank(group=tp) tp_size = dist.get_world_size(group=tp) print( f">>>> pid={os.getpid()}, rank={rank}\n" f">>> world_size={world_size}\n" f">>> tp_rank={tp_rank}, tp_size={tp_size}, tp_members={dist.get_process_group_ranks(tp)}" ) do_train, do_valid, do_test = 0.1, 0.2, 0.3 # breakpoint() flags = torch.tensor([do_train, do_valid, do_test], dtype=torch.float32, device='xla') # breakpoint() dist.broadcast(flags, rank, group=tp) print(f">>>> pid={os.getpid()}, rank={rank}\n" f">>> do_train={flags[0].item()}, do_valid={flags[1].item()}, do_test={flags[2].item()}\n" f">>> global_ordinal={xr.global_ordinal()}") if __name__ == "__main__": main() ``` Results after this fix: ``` torchrun --nproc-per-node=2 --nnodes=1 ./bug.py W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] ***************************************** W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] ***************************************** >>>> pid=1081679, rank=0 >>> world_size=2 >>> tp_rank=0, tp_size=1, tp_members=[0] >>>> pid=1081680, rank=1 >>> world_size=2 >>> tp_rank=0, tp_size=1, tp_members=[1] . . . 2.19.8089.0+8ab9f450/MODULE_10344927339446294134+e30acd3a/model.neff >>>> pid=1081680, rank=1 >>> do_train=0.10000000149011612, do_valid=0.20000000298023224, do_test=0.30000001192092896 >>> global_ordinal=1 >>>> pid=1081679, rank=0 >>> do_train=0.10000000149011612, do_valid=0.20000000298023224, do_test=0.30000001192092896 ``` Now both ranks have the correct values. Previously Rank1 was all zeros.
Author
Parents
Loading