34e5b099 - [reland] Make allreduce compatible with make_fx (#84221)

Lands after #83122. This PR explores solutions to two issues:

1. Collective comm ops are in-place ops and do not return a tensor, so `make_fx` cannot include them in the traced graph. The current solution is to make comm ops return a tuple of `(output_tensors, work_handle)`, which [`proxy_call`](https://github.com/pytorch/pytorch/blob/90821aab100a436424113e2306eac63f5e247ee5/torch/fx/experimental/proxy_tensor.py#L170-L172) can handle. This does not change the behavior of existing c10d Python/C++ APIs, so I added the code directly to `Ops.cpp`. (A sketch of the idea follows this message.)

2. `make_fx` does not recognize `ProcessGroup::Work` and ignores the `wait()` call on the work when tracing the graph. This can break correctness: the traced function could consume a tensor before it is ready. The current solution is a `CommTensor` tensor subclass that explicitly calls `wait()`. In this PR I only do this in the tests, as more discussion is needed before adding it to the c10d Python implementation. (See the second sketch below.)

Kudos to Chillee and wanchaol.

Edit: `print_tabular` breaks CI; removed it from the tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84221
Approved by: https://github.com/wanchaol
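A minimal, runnable sketch of point 1. Here `my_allreduce` is a hypothetical Python stand-in for the dispatcher-registered c10d op (the real op lives in `Ops.cpp` and launches an actual collective); the point is only that an op which *returns* its outputs, rather than solely mutating its inputs, gives the proxy-tensor machinery values to record in the graph:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def my_allreduce(tensors):
    # Hypothetical stand-in for the c10d dispatcher op: the real op launches
    # a collective and returns the (possibly still in-flight) output tensors
    # together with a ProcessGroup::Work handle. The math is faked here so
    # the example runs without an initialized process group.
    work = None  # placeholder for the real work handle
    return [t.clone() for t in tensors], work

def fn(x):
    # Because the op returns its outputs instead of only mutating its
    # inputs, the traced graph can thread them through to later ops.
    outputs, _work = my_allreduce([x])
    return outputs[0] + 1

traced = make_fx(fn)(torch.ones(4))
print(traced.graph)  # contains clone and add nodes
```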
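And a minimal sketch of the `CommTensor` idea from point 2 (the subclass in this PR's tests differs in detail): a tensor subclass that remembers a pending work handle and waits on it before any op consumes the data.

```python
import torch
from torch.utils._pytree import tree_map

class CommTensor(torch.Tensor):
    """Illustrative subclass: waits on a pending collective before use."""

    @staticmethod
    def __new__(cls, tensor, work=None):
        # Plain (non-wrapper) subclass sharing the original storage.
        r = torch.Tensor._make_subclass(cls, tensor, tensor.requires_grad)
        r._work = work  # pending work handle, e.g. from an async collective
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            if isinstance(t, CommTensor):
                if t._work is not None:
                    t._work.wait()  # block until the collective completes
                return t.as_subclass(torch.Tensor)
            return t

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

# Hypothetical usage with an async collective (requires an initialized
# process group, so it is left commented out):
#   work = dist.all_reduce(x, async_op=True)
#   y = CommTensor(x, work) + 1   # wait() runs before the add reads x
```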