DDP Communication Hook Main Structure (#40848)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40848
Sub-tasks 1 and 2 of [39272](https://github.com/pytorch/pytorch/issues/39272)
ghstack-source-id: 107787878
Test Plan:
1\. Perf tests to validate that the new code (the `if` conditions before `allreduce`) doesn't slow down today's DDP. Run the following command with the diff unpatched and patched (with V25):
* **Unpatched Runs:**
```
hg checkout D22514243
flow-cli canary pytorch.benchmark.main.workflow --parameters-json '{"model_arch": "resnet50", "batch_size": 32, "world_size": 1, "use_fp16": false, "print_percentile": true, "backend": "gloo"}' --entitlement pytorch_ftw_gpu --name test_torchelastic_gloo_masterD22514243 --run-as-secure-group pytorch_distributed
```
* **Run 1 (unpatched):** `elastic_gang:benchmark_single.elastic_operator` Ran for 2 mins 59 s
f204539235
```
sum:
8 GPUs: p25: 0.156 205/s p50: 0.160 200/s p75: 0.164 194/s p90: 0.169 189/s p95: 0.173 185/s
fwds:
8 GPUs: p25: 0.032 1011/s p50: 0.032 1006/s p75: 0.032 1000/s p90: 0.032 992/s p95: 0.033 984/s
bwds:
8 GPUs: p25: 0.121 265/s p50: 0.125 256/s p75: 0.129 248/s p90: 0.134 239/s p95: 0.137 232/s
opts:
8 GPUs: p25: 0.003 11840/s p50: 0.003 11550/s p75: 0.004 8037/s p90: 0.006 5633/s p95: 0.007 4631/s
```
* **Run 2 (unpatched):** `elastic_gang:benchmark_single.elastic_operator` Ran for 3 mins 1 s
f204683840
```
sum:
8 GPUs: p25: 0.145 220/s p50: 0.147 217/s p75: 0.150 213/s p90: 0.154 207/s p95: 0.157 204/s
fwds:
8 GPUs: p25: 0.032 1015/s p50: 0.032 1009/s p75: 0.032 1002/s p90: 0.032 994/s p95: 0.032 990/s
bwds:
8 GPUs: p25: 0.107 297/s p50: 0.111 288/s p75: 0.115 278/s p90: 0.119 268/s p95: 0.122 262/s
opts:
8 GPUs: p25: 0.003 11719/s p50: 0.004 9026/s p75: 0.006 5160/s p90: 0.009 3700/s p95: 0.010 3184/s
```
* **Patched Runs:**
```
hg checkout D22328310
flow-cli canary pytorch.benchmark.main.workflow --parameters-json '{"model_arch": "resnet50", "batch_size": 32, "world_size": 1, "use_fp16": false, "print_percentile": true, "backend": "gloo"}' --entitlement pytorch_ftw_gpu --name test_torchelastic_gloo_localD22328310 --run-as-secure-group pytorch_distributed
```
* **Run 1 (patched):** `elastic_gang:benchmark_single.elastic_operator` Ran for 3 mins 30 s
f204544541
```
sum:
8 GPUs: p25: 0.148 216/s p50: 0.152 210/s p75: 0.156 205/s p90: 0.160 200/s p95: 0.163 196/s
fwds:
8 GPUs: p25: 0.032 1011/s p50: 0.032 1005/s p75: 0.032 999/s p90: 0.032 991/s p95: 0.033 984/s
bwds:
8 GPUs: p25: 0.112 286/s p50: 0.116 275/s p75: 0.120 265/s p90: 0.125 256/s p95: 0.128 250/s
opts:
8 GPUs: p25: 0.003 11823/s p50: 0.003 10948/s p75: 0.004 7225/s p90: 0.007 4905/s p95: 0.008 3873/s
```
* **Run 2 (patched):** `elastic_gang:benchmark_single.elastic_operator` Ran for 3 mins 14 s
f204684520
```
sum:
8 GPUs: p25: 0.146 219/s p50: 0.147 217/s p75: 0.150 214/s p90: 0.152 210/s p95: 0.153 208/s
fwds:
8 GPUs: p25: 0.032 1013/s p50: 0.032 1008/s p75: 0.032 1002/s p90: 0.032 996/s p95: 0.032 990/s
bwds:
8 GPUs: p25: 0.107 299/s p50: 0.110 290/s p75: 0.114 280/s p90: 0.117 274/s p95: 0.119 269/s
opts:
8 GPUs: p25: 0.003 11057/s p50: 0.005 6490/s p75: 0.008 4110/s p90: 0.010 3309/s p95: 0.010 3103/s
```
* **Run 3 (patched):** `elastic_gang:benchmark_single.elastic_operator` Ran for 2 mins 54 s
f204692872
```
sum:
8 GPUs: p25: 0.145 220/s p50: 0.147 217/s p75: 0.150 213/s p90: 0.154 207/s p95: 0.156 204/s
fwds:
8 GPUs: p25: 0.032 1001/s p50: 0.032 995/s p75: 0.032 988/s p90: 0.033 980/s p95: 0.033 973/s
bwds:
8 GPUs: p25: 0.108 295/s p50: 0.111 287/s p75: 0.114 280/s p90: 0.119 269/s p95: 0.121 264/s
opts:
8 GPUs: p25: 0.003 11706/s p50: 0.003 9257/s p75: 0.005 6333/s p90: 0.008 4242/s p95: 0.009 3554/s
```
* **Memory:**
* Unpatched:
```
CUDA Memory Summary After first iteration: |===========================================================================|
| PyTorch CUDA memory summary, device ID 0 |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 | cudaMalloc retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 428091 KB | 2892 MB | 9825 MB | 9407 MB |
| from large pool | 374913 KB | 2874 MB | 9752 MB | 9386 MB |
| from small pool | 53178 KB | 52 MB | 73 MB | 21 MB |
|---------------------------------------------------------------------------|
| Active memory | 428091 KB | 2892 MB | 9825 MB | 9407 MB |
| from large pool | 374913 KB | 2874 MB | 9752 MB | 9386 MB |
| from small pool | 53178 KB | 52 MB | 73 MB | 21 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory | 3490 MB | 3490 MB | 3490 MB | 0 B |
| from large pool | 3434 MB | 3434 MB | 3434 MB | 0 B |
| from small pool | 56 MB | 56 MB | 56 MB | 0 B |
|---------------------------------------------------------------------------|
| Non-releasable memory | 315332 KB | 343472 KB | 2295 MB | 1987 MB |
| from large pool | 311166 KB | 340158 KB | 2239 MB | 1935 MB |
| from small pool | 4166 KB | 4334 KB | 56 MB | 52 MB |
|---------------------------------------------------------------------------|
| Allocations | 704 | 705 | 1390 | 686 |
| from large pool | 60 | 131 | 395 | 335 |
| from small pool | 644 | 645 | 995 | 351 |
|---------------------------------------------------------------------------|
| Active allocs | 704 | 705 | 1390 | 686 |
| from large pool | 60 | 131 | 395 | 335 |
| from small pool | 644 | 645 | 995 | 351 |
|---------------------------------------------------------------------------|
| GPU reserved segments | 102 | 102 | 102 | 0 |
| from large pool | 74 | 74 | 74 | 0 |
| from small pool | 28 | 28 | 28 | 0 |
|---------------------------------------------------------------------------|
| Non-releasable allocs | 34 | 54 | 430 | 396 |
| from large pool | 15 | 48 | 208 | 193 |
| from small pool | 19 | 19 | 222 | 203 |
|===========================================================================|
```
* Patched:
```
CUDA Memory Summary After first iteration: |===========================================================================|
| PyTorch CUDA memory summary, device ID 0 |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 | cudaMalloc retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 428091 KB | 2892 MB | 9825 MB | 9407 MB |
| from large pool | 374913 KB | 2874 MB | 9752 MB | 9386 MB |
| from small pool | 53178 KB | 52 MB | 73 MB | 21 MB |
|---------------------------------------------------------------------------|
| Active memory | 428091 KB | 2892 MB | 9825 MB | 9407 MB |
| from large pool | 374913 KB | 2874 MB | 9752 MB | 9386 MB |
| from small pool | 53178 KB | 52 MB | 73 MB | 21 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory | 3490 MB | 3490 MB | 3490 MB | 0 B |
| from large pool | 3434 MB | 3434 MB | 3434 MB | 0 B |
| from small pool | 56 MB | 56 MB | 56 MB | 0 B |
|---------------------------------------------------------------------------|
| Non-releasable memory | 315332 KB | 343472 KB | 2295 MB | 1987 MB |
| from large pool | 311166 KB | 340158 KB | 2239 MB | 1935 MB |
| from small pool | 4166 KB | 4334 KB | 56 MB | 52 MB |
|---------------------------------------------------------------------------|
| Allocations | 704 | 705 | 1390 | 686 |
| from large pool | 60 | 131 | 395 | 335 |
| from small pool | 644 | 645 | 995 | 351 |
|---------------------------------------------------------------------------|
| Active allocs | 704 | 705 | 1390 | 686 |
| from large pool | 60 | 131 | 395 | 335 |
| from small pool | 644 | 645 | 995 | 351 |
|---------------------------------------------------------------------------|
| GPU reserved segments | 102 | 102 | 102 | 0 |
| from large pool | 74 | 74 | 74 | 0 |
| from small pool | 28 | 28 | 28 | 0 |
|---------------------------------------------------------------------------|
| Non-releasable allocs | 34 | 54 | 431 | 397 |
| from large pool | 15 | 48 | 208 | 193 |
| from small pool | 19 | 19 | 223 | 204 |
|===========================================================================|
```
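As a rough sanity check on the numbers above, the p50 `sum` latencies of the unpatched and patched runs can be compared directly. This is a minimal sketch, not part of the benchmark workflow: the values are copied from the run outputs above, the unit is presumably seconds per iteration, and `relative_change` is a helper name introduced here for illustration.

```python
from statistics import mean

# p50 "sum" latencies copied from the run outputs above
# (presumably seconds per iteration at batch size 32).
unpatched_p50 = [0.160, 0.147]          # Run 1, Run 2 (unpatched)
patched_p50 = [0.152, 0.147, 0.147]     # Run 1, Run 2, Run 3 (patched)


def relative_change(baseline, candidate):
    """Return (mean(candidate) - mean(baseline)) / mean(baseline)."""
    base = mean(baseline)
    return (mean(candidate) - base) / base


change = relative_change(unpatched_p50, patched_p50)
print(f"mean p50 unpatched: {mean(unpatched_p50):.4f}")
print(f"mean p50 patched:   {mean(patched_p50):.4f}")
print(f"relative change:    {change:+.1%}")
```

With only five runs this is not a statistical test. The two unpatched runs alone differ by 0.013 at p50, so any patched-vs-unpatched difference of that size should be read as run-to-run noise rather than a speedup or slowdown.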
2\. As of v18: `python test/distributed/test_c10d.py`
```
....................s.....s.....................................................s................................
----------------------------------------------------------------------
Ran 114 tests in 215.983s
OK (skipped=3)
```
3\. Additional tests in `python test/distributed/test_c10d.py`:
* `test_ddp_comm_hook_future_passing_cpu`: This unit test verifies that the Future object is passed properly on CPU. The callback function creates a Future object and sets a value on it.
* `_test_ddp_comm_hook_future_passing_gpu`: This helper verifies that the Future object is passed properly on GPU. The callback function creates a Future object and sets a value on it.
* `test_ddp_comm_hook_future_passing_gpu_gloo`: This unit test executes `_test_ddp_comm_hook_future_passing_gpu` using the gloo backend.
* `test_ddp_comm_hook_future_passing_gpu_nccl`: This unit test executes `_test_ddp_comm_hook_future_passing_gpu` using the nccl backend.
* `test_ddp_invalid_comm_hook_init`: This unit test makes sure that `register_comm_hook` properly checks the format of the hook defined by the user. The Python hook must be callable. This test also checks that the bucket annotation, if defined, is validated properly.
* `test_ddp_invalid_comm_hook_return_type`: This test checks that the return annotation, if defined, is validated properly. It also checks that an internal error is thrown if the return type is incorrect and the user hasn't specified a return type annotation.
* `test_ddp_comm_hook_register_just_once`: A DDP communication hook can only be registered once. This test validates that an error is thrown when `register_comm_hook` is called more than once.
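The checks these tests exercise can be illustrated with a small standalone model. This is only a sketch of the behavior described above, not the DDP implementation: `HookRegistry`, `Bucket`, and `run_hook` are hypothetical names, the `(state, bucket)` hook signature is an assumption, and `concurrent.futures.Future` stands in for the Future type that a real hook would return.

```python
import inspect
from concurrent.futures import Future


class Bucket:
    """Hypothetical stand-in for the gradient bucket passed to a hook."""

    def __init__(self, values):
        self.values = values


class HookRegistry:
    """Toy model of the registration checks; not PyTorch's implementation."""

    def __init__(self):
        self._state = None
        self._hook = None

    def register_comm_hook(self, state, hook):
        # A hook may only be registered once.
        if self._hook is not None:
            raise RuntimeError("comm hook can only be registered once")
        # The hook must be callable.
        if not callable(hook):
            raise TypeError("comm hook must be callable")
        sig = inspect.signature(hook)
        # Annotations are optional, but are validated when present.
        bucket_param = sig.parameters.get("bucket")
        if (bucket_param is not None
                and bucket_param.annotation is not inspect.Parameter.empty
                and bucket_param.annotation is not Bucket):
            raise ValueError("bucket annotation must be Bucket")
        if (sig.return_annotation is not inspect.Signature.empty
                and sig.return_annotation is not Future):
            raise ValueError("return annotation must be Future")
        self._state, self._hook = state, hook

    def run_hook(self, bucket):
        # Without a return annotation, a wrong return type is only
        # detectable at call time.
        result = self._hook(self._state, bucket)
        if not isinstance(result, Future):
            raise RuntimeError("comm hook must return a Future")
        return result


def simple_hook(state, bucket: Bucket) -> Future:
    """Create a Future and set a value on it, as the test callbacks do."""
    fut = Future()
    fut.set_result([v * 2 for v in bucket.values])
    return fut


registry = HookRegistry()
registry.register_comm_hook(None, simple_hook)
print(registry.run_hook(Bucket([1, 2])).result())
```

Annotation checks run at registration time, while the return-type check can only run when the hook is invoked. That split mirrors the two failure modes described for `test_ddp_invalid_comm_hook_return_type`.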
Reviewed By: ezyang
Differential Revision: D22328310
fbshipit-source-id: 77a6a71808e7b6e947795cb3fcc68c8c8f024549