Fix decision logic for `should_cast_forward_inputs` in `_root_pre_forward()` and `_pre_forward()` (#99546)
Fixes #99545
There is currently no topological constraint dictating that FSDP instances own ``FlatParamHandle``s directly. If all parameters are managed by descendant FSDP instances, leaving an FSDP instance with no direct ``state._handles``, the ``should_cast_forward_inputs`` decisions below, in ``_root_pre_forward()`` and ``_pre_forward()`` respectively, can be incorrect [^1].
For [``_root_pre_forward()``](https://github.com/pytorch/pytorch/blob/436edc5ac3de4c4e677ed136473bafe72002cc93/torch/distributed/fsdp/_runtime_utils.py#L514):
https://github.com/pytorch/pytorch/blob/436edc5ac3de4c4e677ed136473bafe72002cc93/torch/distributed/fsdp/_runtime_utils.py#L602-L604
For [``_pre_forward``](https://github.com/pytorch/pytorch/blob/436edc5ac3de4c4e677ed136473bafe72002cc93/torch/distributed/fsdp/_runtime_utils.py#L384):
https://github.com/pytorch/pytorch/blob/436edc5ac3de4c4e677ed136473bafe72002cc93/torch/distributed/fsdp/_runtime_utils.py#L420-L422
See the [related issue](https://github.com/pytorch/pytorch/issues/99545) for reproduction.
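For convenience, here is a minimal sketch of the problematic topology (module shapes and names are illustrative, not taken from the issue): every parameter is owned by a wrapped child, so the outer FSDP instance ends up with an empty ``state._handles``.
```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Assumes a default process group has already been initialized
# (e.g. via torchrun + dist.init_process_group).
mp = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
# Wrap every child so that each parameter is managed by a descendant
# FSDP instance...
for i, child in enumerate(model):
    model[i] = FSDP(child, mixed_precision=mp)
# ...leaving the outer (root) FSDP instance with no direct handles,
# the case the old decision logic mishandled.
root = FSDP(model, mixed_precision=mp)
```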
### Remediation
In this PR, I amend the two decision statements referenced above (in both `_root_pre_forward()` and `_pre_forward()`) to account for FSDP instances without direct handles:
```python
should_cast_forward_inputs = len(state._handles) > 0 and all(
    not handle._force_full_precision for handle in state._handles
)
```
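The ``len(state._handles) > 0`` guard is the substantive change: ``all()`` over an empty iterable is vacuously ``True``, so under the previous logic a handle-less instance would still decide to cast its forward inputs. A quick illustration of the two behaviors on an empty handle list:
```python
# Vacuous truth: with no handles, the old expression evaluated to True.
assert all(not h._force_full_precision for h in []) is True
# The added guard short-circuits the empty case to False instead.
assert (len([]) > 0 and all(not h._force_full_precision for h in [])) is False
```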
With the ``should_cast_forward_inputs`` adjustment above, configuring ``MixedPrecision`` with ``cast_forward_inputs=True`` in the linked issue's example restores the expected behavior and produces no error.
Though the check is the same in both ``_root_pre_forward()`` and ``_pre_forward()`` and hence could be refactored into a separate function, I figured it may make sense to retain separate statements to preserve the ability to introduce root-specific behavior in the future. I'm happy to update this PR with whichever approach the team prefers.
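For reference, such a helper might look like the following (the function name is hypothetical, not part of this PR):
```python
def _should_cast_forward_inputs(state) -> bool:
    # Cast only if this instance directly manages at least one handle
    # and none of those handles forces full precision.
    return len(state._handles) > 0 and all(
        not handle._force_full_precision for handle in state._handles
    )
```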
### Implementation considerations and questions
1. Rather than write a test that would arguably have a poor utility-to-resource-usage profile, I have not added any tests in this PR. The new decision logic is exercised by all existing tests (which of course continue to pass after this PR), so I think the utility of new tests is fairly modest. Let me know if you think new tests should be added and I'm happy to do so (one possible shape for such a test is sketched after this list).
2. As discussed above, the decision statement shared by ``_pre_forward()`` and ``_root_pre_forward()`` could be factored out into a separate function. Given the simplicity of the statement, and to retain the current flexibility for root-specific decisions, it might not be worth the refactor, so I haven't done it yet. Let me know if you'd like me to.
3. The note below could be updated to indicate the utility of setting ``cast_forward_inputs=True`` for the situations addressed in this PR, but I haven't done so since I'm not sure it's worth complicating the current usage guidance. I'd be happy to add verbiage describing the use case if the team wants it.
https://github.com/pytorch/pytorch/blob/cde35b406902d421ea5ae5f10a114da95e7171f1/torch/distributed/fsdp/api.py#L175-L181
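Regarding item 1, if the team does want a dedicated regression test, one possible shape (all names are illustrative and assume the test harness initializes a process group, as the existing FSDP tests do) would be to build the handle-less-root topology from the sketch above and assert that forward/backward completes:
```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

def _test_forward_with_handleless_root():
    mp = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    for i, child in enumerate(model):
        model[i] = FSDP(child, mixed_precision=mp)
    root = FSDP(model, mixed_precision=mp)
    assert len(root._handles) == 0  # root owns no handles directly
    root(torch.randn(4, 8)).sum().backward()  # should not raise
```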
Thanks again to the PyTorch distributed team for your immensely valuable contributions to the open-source ML community!
[^1]: Though one could keep the existing decision logic and impose a new topological constraint requiring that all FSDP instances have direct `_handles`, I think retaining the current wrapping flexibility is convenient and useful enough (e.g. for programmatic wrapping of modules that may or may not already have all of their parameters handled by descendant FSDP instances) to justify updating the decision logic as discussed here instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99546
Approved by: https://github.com/awgu