pytorch
0b31f87c - [FSDP] Use correct handle training state when prefetching (#98249)

[FSDP] Use correct handle training state when prefetching (#98249)

This PR ensures that when prefetching a `FlatParamHandle.unshard()`, we temporarily set the `FlatParamHandle._training_state` to the expected training state, as if the `unshard()` were not prefetched, since the `as_params` argument to `_use_unsharded_views()` depends on the handle's training state.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98249
Approved by: https://github.com/rohan-varma
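The pattern described above can be illustrated with a small sketch: temporarily swap the handle's training state for the duration of the prefetched `unshard()`, then restore the previous state. The classes, enum values, and the `_temporary_training_state` helper below are simplified stand-ins for illustration only, not the actual FSDP internals.

```python
# Minimal sketch, assuming a simplified handle whose unshard() behavior
# depends on its training state (as the commit message describes for
# _use_unsharded_views()'s `as_params` argument).
import contextlib
import enum


class HandleTrainingState(enum.Enum):
    IDLE = enum.auto()
    FORWARD = enum.auto()
    BACKWARD_PRE = enum.auto()


class FlatParamHandle:
    """Simplified stand-in for FSDP's FlatParamHandle (hypothetical)."""

    def __init__(self) -> None:
        self._training_state = HandleTrainingState.IDLE

    def unshard(self) -> None:
        # Stand-in for the real unshard(): the views it sets up depend on
        # the current training state, so a prefetched call must see the
        # state it would have seen if it were not prefetched.
        as_params = self._training_state == HandleTrainingState.FORWARD
        print(f"unshard: state={self._training_state.name}, as_params={as_params}")


@contextlib.contextmanager
def _temporary_training_state(handle: FlatParamHandle,
                              state: HandleTrainingState):
    """Set the handle's training state only for the prefetched unshard()."""
    prev_state = handle._training_state
    handle._training_state = state
    try:
        yield
    finally:
        handle._training_state = prev_state


# Usage: prefetch the next handle's unshard() as if it were running in its
# own forward pass, then restore whatever state it had before.
next_handle = FlatParamHandle()
with _temporary_training_state(next_handle, HandleTrainingState.FORWARD):
    next_handle.unshard()
assert next_handle._training_state is HandleTrainingState.IDLE
```

A try/finally (here via a context manager) keeps the swap temporary even if the prefetched `unshard()` raises, so the handle's real training state is never left corrupted.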