[ddp launch] solve zombie problem (#49305)
Summary:
I was tired of hunting down zombie processes when working with the ddp launcher, so this PR fixes the various zombie issues.
This PR addresses 2 distinct zombie scenarios caused by ddp launch.py:
1. When the main process is killed, the child processes aren't killed and continue running.
2. When any of the child processes dies (e.g. OOM), the remaining children and the parent keep running, but are effectively stuck.
To solve these problems this PR switches from `wait` to `poll` and uses signal handlers.
The main problem with `wait()` is that it blocks. In my case the 2nd process would OOM while the launcher was still stuck waiting for the first process to finish. That would never happen, because the first process was itself blocked waiting for the 2nd one: a sort of deadlock. My 2nd card is smaller than the first one, so it occasionally OOMs.
Using `asyncio` would probably be the cleanest solution, but as it's relatively new in Python, perhaps polling is good enough.
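To illustrate the difference between `wait` and `poll`, here is a minimal sketch of a polling loop. It is not the code from this PR: `wait_for_all`, `spawn` and `poll_interval` are hypothetical names, and the children are throwaway Python one-liners rather than real training processes.

```python
import subprocess
import sys
import time


def wait_for_all(processes, poll_interval=0.1):
    """Poll the children; if one exits non-zero, kill the rest and return its code."""
    alive = set(processes)
    while alive:
        for proc in list(alive):
            returncode = proc.poll()  # non-blocking: None while still running
            if returncode is None:
                continue
            alive.discard(proc)
            if returncode != 0:
                # one child failed: stop the survivors so nothing is left behind
                for other in alive:
                    other.kill()
                for other in alive:
                    other.wait()  # reap the killed children
                return returncode
        time.sleep(poll_interval)
    return 0  # every child exited cleanly


def spawn(code):
    # hypothetical helper: run a Python snippet as a child process
    return subprocess.Popen([sys.executable, "-c", code])


if __name__ == "__main__":
    procs = [
        spawn("import time; time.sleep(60)"),  # "rank 0" keeps running
        spawn("import sys; sys.exit(1)"),  # "rank 1" dies early, like an OOM
    ]
    print("exit code:", wait_for_all(procs))
```

With a blocking `procs[0].wait()` the parent would sit there for the full 60 seconds before noticing that the second child had already died; with `poll()` it notices on the next pass through the loop.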
I wrote this little script to reproduce the 2 problematic scenarios and a normally running setup. It does 3 different things according to the `--mode` arg:
- `oom` - causes the 2nd process to exit prematurely, emulating OOM
- `clean-finish` - both processes just exit normally
- `False` (lack of arg) - just keeps on running, emulating multiple normally running processes
```
# oom.py
import argparse
from time import sleep
import sys


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", default=False, type=int)
    parser.add_argument("--mode", default=False, type=str)
    args, _ = parser.parse_known_args()

    print(f"{args.local_rank} is starting")
    sleep(3)

    if args.mode == "oom":
        # emulate OOM in 2nd card
        if args.local_rank == 1:
            raise RuntimeError("OOM")

    if args.mode == "clean-finish":
        sleep(1)
        print(f"{args.local_rank} is cleanly finishing")
        sys.exit(0)

    while True:
        # emulate long running process
        print(f"{args.local_rank} is running")
        sleep(1)


if __name__ == "__main__":
    main()
```
Let's begin:
### 1. Normal execution
```
python -m torch.distributed.launch --nproc_per_node=2 ./oom.py --mode=clean-finish
```
All the processes exit upon completion. I won't bother pasting the log here; this just tests that my code didn't break normal execution.
### 2. OOM
```
python -m torch.distributed.launch --nproc_per_node=2 ./oom.py --mode=oom
```
```
POLLING FOR 17547
POLLING FOR 17548
0
0 is starting
1
1 is starting
POLLING FOR 17547
POLLING FOR 17548
POLLING FOR 17548
POLLING FOR 17547
POLLING FOR 17547
POLLING FOR 17548
0 is running
Traceback (most recent call last):
File "./oom.py", line 33, in <module>
main()
File "./oom.py", line 20, in main
raise RuntimeError("OOM")
RuntimeError: OOM
POLLING FOR 17548
process 17548 is no more
Killing subprocess 17547
Killing subprocess 17548
Traceback (most recent call last):
File "/home/stas/anaconda3/envs/main-38/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/stas/anaconda3/envs/main-38/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/distributed/launch.py", line 341, in <module>
main()
File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/distributed/launch.py", line 327, in main
sigkill_handler(signal.SIGTERM, None) # not coming back
File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/distributed/launch.py", line 301, in sigkill_handler
raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/stas/anaconda3/envs/main-38/bin/python', '-u', './oom.py', '--local_rank=1', '--mode=oom']' returned non-zero exit status 1.
```
All processes exited and the trace was printed.
### 3. Exit on SIGINT/SIGTERM
If I start a process and then realize I made a mistake, I want to be able to kill it cleanly, and if any sub-processes have already been spawned I want them to be killed too. Here the signal handler takes care of trapping SIGTERM/SIGINT.
```
python -m torch.distributed.launch --nproc_per_node=2 ./oom.py
```
Here the processes emulate a long normal run.
So let's Ctrl-C the process as soon as it started and see:
```
POLLING FOR 18749
POLLING FOR 18750
0
0 is starting
1
1 is starting
POLLING FOR 18749
POLLING FOR 18750
POLLING FOR 18750
POLLING FOR 18749
POLLING FOR 18749
POLLING FOR 18750
0 is running
1 is running
POLLING FOR 18750
POLLING FOR 18749
0 is running
1 is running
^CTraceback (most recent call last):
Killing subprocess 18749
Traceback (most recent call last):
File "./oom.py", line 33, in <module>
File "./oom.py", line 33, in <module>
Killing subprocess 18750
Parent got kill signal=SIGINT, exiting
```
All processes got killed.
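The kind of handler exercised in this scenario could look roughly like the sketch below. It is an illustration only, not the `sigkill_handler` from `launch.py`: `make_kill_handler` and `install_kill_handler` are hypothetical names, and the printed messages are only loosely modeled on the log above.

```python
import signal
import subprocess
import sys


def make_kill_handler(processes):
    """Build a handler that kills every child before the parent exits."""

    def handler(signum, frame):
        for proc in processes:
            if proc.poll() is None:  # still running
                print(f"Killing subprocess {proc.pid}")
                proc.kill()
        for proc in processes:
            proc.wait()  # reap the children so no zombie entries remain
        print(f"Parent got kill signal={signal.Signals(signum).name}, exiting")
        sys.exit(1)

    return handler


def install_kill_handler(processes):
    # Ctrl-C arrives as SIGINT; a plain `kill <pid>` sends SIGTERM
    handler = make_kill_handler(processes)
    signal.signal(signal.SIGINT, handler)
    signal.signal(signal.SIGTERM, handler)
    return handler
```

The parent would call `install_kill_handler(processes)` right after spawning its children, so that both Ctrl-C and `kill` take the children down with it. SIGKILL (`kill -9`) cannot be trapped, so that case is not covered by a handler like this.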
--------------------------------
So this covers the 2 problematic cases and 1 normal case.
Notes:
- we could probably switch to `sleep(3)` - `1` is probably too fast
- all the debug prints will be removed once you are happy - I left them so that it's easier for you to test that my PR does the right thing.
Thank you!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49305
Reviewed By: izdeby
Differential Revision: D25565617
Pulled By: rohan-varma
fbshipit-source-id: 1ea864113f283d4daac5eef1131c8d745aae4c99