775a22c7 - Add all_gather_into_tensor in place of _all_gather_base (#85686)

### Description
- This PR renames `_all_gather_base` to `all_gather_into_tensor` so that its meaning is clearer.
- The `all_gather_into_tensor` API differs from the `all_gather` API in the output it accepts -- a single, large tensor instead of a list of tensors.
- This PR also adds a deprecation warning to `_all_gather_base`.

### Issue
`_all_gather_base` was implemented in https://github.com/pytorch/pytorch/pull/33924 to avoid unnecessary flattening. There was a previous effort (#82639) to merge `_all_gather_base` with the existing `all_gather` API by detecting the parameter type passed in for the output. There are, however, two blockers that make the merge difficult:
(i) The merge breaks backward compatibility: we would need to change the parameter name `tensor_list` in `all_gather` to a general name, `output`, that can cover both a tensor and a tensor list.
(ii) The `all_gather` API recently added uneven tensor support, which relies on the tensor boundaries implied by the list. We are, however, unsure whether to add such support to `_all_gather_base`, because it would require users to pass in additional tensor boundary information.

In view of the above, we decided to productize `_all_gather_base` as a separate function, but with a clearer name.

### Testing
Added tests:
- `test_all_gather_into_cat_tensor_cuda` -- output formed as with `torch.cat`. For example:
```
>>> tensor_in
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> tensor_out
tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
```
- `test_all_gather_into_stack_tensor_cuda` -- output formed as with `torch.stack`. For example:
```
>>> tensor_out2
tensor([[1, 2],
        [3, 4]], device='cuda:0') # Rank 0
tensor([[1, 2],
        [3, 4]], device='cuda:1') # Rank 1
```
The output form is determined by the shape of the output tensor passed by the user; no flag is needed.
Cc @rohan-varma @mrshenli @crcrpar @ptrblck @H-Huang Pull Request resolved: https://github.com/pytorch/pytorch/pull/85686 Approved by: https://github.com/rohan-varma, https://github.com/crcrpar