Rewrite NCCL watchdog to more reliably throw timeout (#97066)
Fixes #97191
This PR aims to propagate collective exceptions (async error or timeout) up to the program, so as to avoid silent stuck job.
### Previous output in #97191
```
Rank 0 is the problematic rank
Rank 4 completed
Rank 5 completed
Rank 3 completed
Rank 6 completed
Rank 2 completed
Rank 7 completed
Rank 1 completed
[E ProcessGroupNCCL.cpp:464] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=10000) ran for 10917 milliseconds before timing out.
Rank 0 completed
[E ProcessGroupNCCL.cpp:478] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:483] To avoid data inconsistency, we are taking the entire process down.
```
Although it says that it is taking the process down, it sometimes fails to do so.
### New output after this PR:
```
...
[E ProcessGroupNCCL.cpp:459] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=10000) ran for 10599 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:473] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:479] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:818] [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=10000) ran for 10599 milliseconds before timing out.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 194470) of binary: /data/home/kw2501/repos/pytorch-dev-env/bin/python
Traceback (most recent call last):
File "/pytorch-dev-env/bin/torchrun", line 33, in <module>
sys.exit(load_entry_point('torch', 'console_scripts', 'torchrun')())
File "/pytorch-dev/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
return f(*args, **kwargs)
File "/pytorch-dev/torch/distributed/run.py", line 794, in main
run(args)
File "/pytorch-dev/torch/distributed/run.py", line 785, in run
elastic_launch(
File "/pytorch-dev/torch/distributed/launcher/api.py", line 134, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/pytorch-dev/torch/distributed/launcher/api.py", line 250, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
hang.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2023-03-20_22:00:42
host : node0
rank : 0 (local_rank: 0)
exitcode : -6 (pid: 194470)
error_file: <N/A>
traceback : Signal 6 (SIGABRT) received by PID 194470
============================================================
```
The log suggests that TorchX monitor is triggered, and job is torn down.
### Major changes in this PR:
1. Merge ncclWatchDog thread and workCleanupLoop thread into one so that the watch action and the throw action are streamlined.
Previously, ncclWatchDog is responsible for watching comm error and timeout, and workCleanupLoop is responsible for watching Work item error and throwing exception. This two-thread design is not streamlined, raising the chance of missing the throw. Also, it is duplicated to watch at multiple level.
2. Rethrow exception at watchdog thread.
3. Clean up a bunch of duplicated functions, e.g. `checkAndThrowException` and `handleNcclException`.
4. Turn on ASYNC_ERROR_HANDLING by default
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97066
Approved by: https://github.com/rohan-varma