pytorch
363d5300 - Fix decision logic for `should_cast_forward_inputs` in `_root_pre_forward()` and `_pre_forward()` (#99546)

Fixes #99545

There is currently no topological constraint dictating that FSDP instances own ``FlatParamHandle``s directly. If all parameters are managed by descendant FSDP instances, leaving an FSDP instance with no direct ``state._handles``, the ``should_cast_forward_inputs`` decisions below, in ``_root_pre_forward()`` and ``_pre_forward()`` respectively, can be incorrect[^1].

For [``_root_pre_forward()``](https://github.com/pytorch/pytorch/blob/436edc5ac3de4c4e677ed136473bafe72002cc93/torch/distributed/fsdp/_runtime_utils.py#L514):
https://github.com/pytorch/pytorch/blob/436edc5ac3de4c4e677ed136473bafe72002cc93/torch/distributed/fsdp/_runtime_utils.py#L602-L604

For [``_pre_forward()``](https://github.com/pytorch/pytorch/blob/436edc5ac3de4c4e677ed136473bafe72002cc93/torch/distributed/fsdp/_runtime_utils.py#L384):
https://github.com/pytorch/pytorch/blob/436edc5ac3de4c4e677ed136473bafe72002cc93/torch/distributed/fsdp/_runtime_utils.py#L420-L422

See the [related issue](https://github.com/pytorch/pytorch/issues/99545) for a reproduction.

### Remediation

In this PR, I amend the two decision statements referenced above (in both `_root_pre_forward()` and `_pre_forward()`) to account for FSDP instances without direct handles:

```python
should_cast_forward_inputs = len(state._handles) > 0 and all(
    not handle._force_full_precision for handle in state._handles
)
```

If one configures ``MixedPrecision`` in the example above with ``cast_forward_inputs=True``, then with the adjusted ``should_cast_forward_inputs`` logic FSDP returns to the expected behavior and produces no error.

Though the check is identical in ``_root_pre_forward()`` and ``_pre_forward()`` and hence could be refactored into a shared function, I figured it may make sense to retain separate statements to preserve the ability to add root-specific behavior in the future.
I can update this PR with whichever approach the team prefers.

### Implementation considerations and questions

1. Rather than write a test with an arguably poor utility-to-resource-usage profile, I have not added any tests with this PR. The new decision logic is exercised by all existing tests (which of course continue to pass after this PR), so I think the utility of new tests is fairly modest. Let me know if you think new tests should be added and I'm happy to do so.
2. As discussed above, the decision statement shared by ``_pre_forward()`` and ``_root_pre_forward()`` could be factored out into a separate function. Given the simplicity of the statement, and to retain the current flexibility for root-specific decisions, it may not be worth the refactor, so I haven't done it yet. Let me know if you'd like me to do so.
3. The note below could be updated to indicate the utility of setting ``cast_forward_inputs=True`` for the situations addressed by this PR, but I haven't done so since I'm not sure it's worth complicating the current usage guidance. I'd be happy to add verbiage describing the use case if the team wants it.
   https://github.com/pytorch/pytorch/blob/cde35b406902d421ea5ae5f10a114da95e7171f1/torch/distributed/fsdp/api.py#L175-L181

Thanks again to the PyTorch distributed team for your immensely valuable contributions to the open-source ML community!

[^1]: Though one could keep the existing decision logic and impose a new topological constraint requiring that all FSDP instances have direct `_handles`, I think retaining the current wrapping flexibility (e.g., for programmatic wrapping of modules that may or may not already have all parameters handled by descendant FSDP instances) is convenient and useful enough to warrant updating the decision logic as discussed here instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99546
Approved by: https://github.com/awgu
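For context on why the `len(state._handles) > 0` guard matters: Python's `all()` returns `True` on an empty iterable, so an FSDP instance whose parameters are all owned by descendants would vacuously decide to cast forward inputs under the old logic. The sketch below illustrates this with hypothetical `FakeHandle`/`FakeState` stand-ins (not the real FSDP classes, which require a distributed setup), comparing the old and amended decisions:

```python
from dataclasses import dataclass, field

@dataclass
class FakeHandle:
    # Stand-in for a FlatParamHandle's relevant attribute
    _force_full_precision: bool = False

@dataclass
class FakeState:
    # Stand-in for FSDP state; _handles may be empty when all
    # parameters are managed by descendant FSDP instances
    _handles: list = field(default_factory=list)

def should_cast_old(state: FakeState) -> bool:
    # Old logic: all() over an empty iterable is vacuously True,
    # so a handle-less instance incorrectly decides to cast
    return all(not h._force_full_precision for h in state._handles)

def should_cast_new(state: FakeState) -> bool:
    # Amended logic from this PR: require at least one direct handle
    return len(state._handles) > 0 and all(
        not h._force_full_precision for h in state._handles
    )

root = FakeState(_handles=[])  # no direct handles
print(should_cast_old(root))   # True  (incorrect)
print(should_cast_new(root))   # False (fixed)

leaf = FakeState(_handles=[FakeHandle(_force_full_precision=False)])
print(should_cast_new(leaf))   # True (unchanged for the common case)
```

This keeps the behavior identical for instances that do own handles while fixing the vacuous-truth case.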