ZeRO3: Improve mismatch detection (#7525)
ZeRO3 tracks DDP (SPMD) behavior by matching values different training
states across ranks. Some of these states are represented as lists, and
mismatches sometimes manifests as hangs during error detection. This PR
improves error detection by first validating the list lengths across
ranks before validating the list contents.
Motivated by
https://github.com/deepspeedai/DeepSpeed/issues/7461#issuecomment-3235146207
---------
Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>