964c7e3e - [BE][DTensor] fix DTensor equal op (#99014)

## What problem does this PR solve?

#97170 fixed the `equal` operator's return type (old: Tensor, now: bool) by giving it the correct sharding propagation, consistent with the `aten::equal` op. However, correctness only held at the local-result level:

* `equal` returned True if the local copy of DTensor A equaled the local copy of DTensor B.

This is not the correct semantic of `equal`, which should return True only if all local copies of A are equal to the corresponding local copies of B.

## What does this PR do?

1. For non-participating ranks, if the return type is scalar, `local_results` is set to `None`, meaning the default value is a result reduced over participating ranks only.
2. For all ranks, if the return type is scalar and the `op_call` is `aten::equal` (because `aten::equal` is the only op that returns a scalar value and needs communication), all-gather the `local_results` within the default pg and reduce them with `operator.and_`. The result becomes the new `local_result`.

## Result/Impact

For non-participating ranks, when the return type is scalar:

1. If the op is `aten::equal`, the return value is the same as on all other ranks.
2. If the op is not `aten::equal`, the return value is None. Before this PR, this would raise `NotImplementedError`, but that path had not been tested.

For participating ranks, when the return type is scalar:

1. If the op is `aten::equal`, the return value is the equality of the two DTensor operands: True if all copies are equal, False otherwise.
2. If the op is not `aten::equal`, it is simply the local computation result.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99014
Approved by: https://github.com/wanchaol
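The reduction described above can be sketched in plain Python, without the distributed machinery. This is a minimal illustration, not the actual DTensor code: `combine_equal_results` is a hypothetical helper standing in for the all-gather on the default process group followed by the `operator.and_` reduction, and `None` entries model non-participating ranks whose `local_results` are excluded from the reduce.

```python
import operator
from functools import reduce

def combine_equal_results(gathered_results):
    """Hypothetical stand-in for the PR's logic: after all-gathering each
    rank's local `aten::equal` result, drop the None entries contributed
    by non-participating ranks and AND the rest together, so the global
    result is True only if every participating rank's shards were equal."""
    participating = [r for r in gathered_results if r is not None]
    return reduce(operator.and_, participating)

# All participating ranks report local equality -> global True
print(combine_equal_results([True, None, True, True]))   # True
# One rank's local shards differ -> global False
print(combine_equal_results([True, False, None, True]))  # False
```

The key point the PR makes is exactly this AND semantics: a single rank whose local shards differ must flip the global `equal` result to False for every rank, which is why the local results have to be communicated rather than returned directly.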