Fix potential hang when exiting main process (#33721)
Summary:
The following script reproduces the hang:
```py
import multiprocessing, logging
logger = multiprocessing.log_to_stderr()
logger.setLevel(multiprocessing.SUBDEBUG)
import torch
class Dataset:
    def __len__(self):
        return 23425

    def __getitem__(self, idx):
        return torch.randn(3, 128, 128), idx % 100

ds = Dataset()
trdl = torch.utils.data.DataLoader(ds, batch_size=64, num_workers=300, pin_memory=True, shuffle=True)
for e in range(1000):
    for ii, (x, y) in enumerate(trdl):
        print(f'tr {e: 5d} {ii: 5d} avg y={y.mean(dtype=torch.double).item()}')
        if ii % 2 == 0:
            print("="*200 + "BEFORE ERROR" + "="*200)
            1/0
```
The process hangs in the **main process** while joining the putting (feeder) thread of `data_queue`. The root cause is that the **worker processes** have already put so many items into the queue that the `put` at https://github.com/pytorch/pytorch/blob/062ac6b472af43c9cf83d285e661e24244551f85/torch/utils/data/dataloader.py#L928 blocks in the queue's background feeder thread. Meanwhile, the `pin_memory_thread` exits as soon as `pin_memory_thread_done_event` is set, without ever consuming the `(None, None)` sentinel, so the queue is never drained. Hence, the main process needs the same treatment the workers already receive at
https://github.com/pytorch/pytorch/blob/062ac6b472af43c9cf83d285e661e24244551f85/torch/utils/data/_utils/worker.py#L198 .
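The hang mechanism can be shown in isolation, outside of DataLoader. A `multiprocessing.Queue` hands items to a background feeder thread, and by default process exit joins that thread, blocking until every buffered item has been flushed to the (possibly never-drained) pipe. Calling `cancel_join_thread()` (the same call made in `worker.py` above) skips that join. This is a minimal sketch of that behavior, not the DataLoader code itself:

```python
import multiprocessing


def producer(q):
    # Enqueue far more data than the underlying pipe can buffer, so the
    # feeder thread still holds undelivered items when this function returns.
    for _ in range(10000):
        q.put(bytes(1024))
    # Without this call, process exit joins the queue's feeder thread,
    # which blocks until the consumer drains the pipe -- if nobody ever
    # reads from the queue, the process hangs on exit.
    q.cancel_join_thread()


if __name__ == "__main__":
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=producer, args=(q,))
    p.start()
    p.join(timeout=10)  # exits promptly because of cancel_join_thread()
    print("producer exited:", not p.is_alive())
```

Dropping the `cancel_join_thread()` line makes `p.join(timeout=10)` time out, which is the same shutdown deadlock the DataLoader main process hits when `pin_memory_thread` exits without consuming the sentinel.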
After the patch, the script finishes correctly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33721
Differential Revision: D20089209
Pulled By: ezyang
fbshipit-source-id: e73fbfdd7631afe1ce5e1edd05dbdeb7b85ba961