[DDP] Separate error messages for unused params in forward and not all outputs (#52391)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52391
DDP can throw the exception refactored here in two cases:
1) Some params are unused in the forward pass. We provide `find_unused_parameters=True` for this.
2) All params are used in the forward pass, but not all outputs are used in the loss computation. There are a few workarounds for this, but we do not provide native support.
Previously, these two cases shared a single error message, which has historically caused confusion: users reported getting this error even with `find_unused_parameters=True` enabled (which they expected to fix it). This added churn when debugging, because it was not known whether the true cause was (1) or (2).
This commit addresses the issue by emitting a different error message depending on whether we ran with unused parameter detection. This should make the error message clearer and more actionable.
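For illustration, the two cases listed above can be reproduced without DDP by checking which parameters receive no gradient after `backward()`. This is a minimal sketch: the module names `UnusedBranch` and `TwoHeads` and the helper `params_without_grad` are hypothetical and not part of this change. Under DDP, parameters like these are the ones whose gradient reduction would not complete.

```python
import torch
import torch.nn as nn


class UnusedBranch(nn.Module):
    """Case (1): `unused` is never called in forward."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)

    def forward(self, x):
        return self.used(x)


class TwoHeads(nn.Module):
    """Case (2): every param is used in forward, and two outputs are returned."""

    def __init__(self):
        super().__init__()
        self.head_a = nn.Linear(4, 2)
        self.head_b = nn.Linear(4, 2)

    def forward(self, x):
        return self.head_a(x), self.head_b(x)


def params_without_grad(module, loss):
    """Run backward on `loss` and list parameters that received no gradient."""
    module.zero_grad(set_to_none=True)
    loss.backward()
    return sorted(n for n, p in module.named_parameters() if p.grad is None)


x = torch.randn(3, 4)

# Case (1): the loss uses the only output, but `unused` took no part in forward.
m1 = UnusedBranch()
case1 = params_without_grad(m1, m1(x).sum())

# Case (2): forward touches both heads, but the loss ignores the second output.
m2 = TwoHeads()
out_a, out_b = m2(x)
case2 = params_without_grad(m2, out_a.sum())

print(case1)
print(case2)
```

In case (2) the parameters of `head_b` are reachable from the `forward` outputs, so they are not unused in the forward pass, yet they still receive no gradient because `out_b` does not contribute to the loss.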
Error message with `find_unused_parameters=True`:
```
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. Since `find_unused_parameters=True` is enabled, this likely means that not all `forward` outputs participate in computing loss. You can fix this by making sure all `forward` function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
```
Error message without `find_unused_parameters` specified:
```
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by
making sure all `forward` function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
```
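The selection between the two messages above can be sketched roughly as follows. This is an illustrative Python sketch only, not the code in this change: the function name `build_reduction_error` and the constants are hypothetical, and the message text is copied from the examples above.

```python
# Hypothetical sketch of choosing the error text based on the DDP flag.
PREFIX = (
    "Expected to have finished reduction in the prior iteration before "
    "starting a new one. This error indicates that your module has "
    "parameters that were not used in producing loss. "
)

# Shown when unused parameter detection was already enabled, i.e. case (2).
WITH_DETECTION = (
    "Since `find_unused_parameters=True` is enabled, this likely means that "
    "not all `forward` outputs participate in computing loss. You can fix "
    "this by making sure all `forward` function outputs participate in "
    "calculating loss."
)

# Shown when detection was not enabled, so the cause may be (1) or (2).
WITHOUT_DETECTION = (
    "You can enable unused parameter detection by passing the keyword "
    "argument `find_unused_parameters=True` to "
    "`torch.nn.parallel.DistributedDataParallel`, and by making sure all "
    "`forward` function outputs participate in calculating loss."
)

SUFFIX = (
    "\nIf you already have done the above, then the distributed data "
    "parallel module wasn't able to locate the output tensors in the return "
    "value of your module's `forward` function. Please include the loss "
    "function and the structure of the return value of `forward` of your "
    "module when reporting this issue (e.g. list, dict, iterable)."
)


def build_reduction_error(find_unused_parameters: bool) -> str:
    """Return the message text for the given detection setting."""
    hint = WITH_DETECTION if find_unused_parameters else WITHOUT_DETECTION
    return PREFIX + hint + SUFFIX


msg_with = build_reduction_error(True)
msg_without = build_reduction_error(False)
print(msg_with)
print(msg_without)
```

Both messages share the same first two sentences and the same closing paragraph; only the remediation hint in the middle differs.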
ghstack-source-id: 122097900
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D26496688
fbshipit-source-id: 4a9eeeda10293da13d94a692d10cb954e4506d7c