Enhance ProcessGroupWrapper with additional checks + refactor (#60237)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60237
Closes https://github.com/pytorch/pytorch/issues/58711
This diff refactors the collective consistency checking in `ProcessGroupWrapper` as described in the above issue. In particular, we no longer run separate verification checks (`all_gather`s) for shapes, op type, etc. Instead, we implement a function `serialize_fingerprint` to serialize all this data into a single tensor and only verify that.
This has the benefit of being a lot more extensible, the developer does not need to add separate `all_gather` calls in order to verify additional data in the future. We can also provide some sort of mechanism where we allow data that needs to be verified to be "registered" in the `CollectiveFingerPrint` struct and make it even easier to add additional data, we can consider doing this if there are significant additions to `process group wrapper`.
We now also begin to check tensor `dtypes` and device types for consistency as well. Tests are refactored/added accordingly.
ghstack-source-id: 132520261
Test Plan: CI
Reviewed By: cbalioglu
Differential Revision: D28597287
fbshipit-source-id: b09f14f628df9e2457623ba81fc13fd4e214f3c9