pytorch
b691d090 - fix: reset prefetch flag upon reshard (#111354)

fix: reset prefetch flag upon reshard (#111354)

The `_prefetched` flag should be reset upon reshard. Otherwise, for zero2, the next access to the corresponding parameter will skip the "unshard" operation, resulting in a wrong parameter shape. The need for unsharding is also mentioned [in the comment of `FlatParameterHandle.unshard`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_flat_param.py#L1241-L1242). Since [`FlatParameterHandle` already guards against an unnecessary all-gather](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_flat_param.py#L1240), this shouldn't incur extra communication overhead.

_Personally I also find `_prefetched` a bit misnamed; it should really be `_unsharded`._

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111354
Approved by: https://github.com/awgu
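The interaction described above can be illustrated with a toy model. This is a hedged sketch, not the real FSDP implementation: `FakeHandle`, `prefetch`, `pre_forward`, and the exact flag semantics are simplified stand-ins for the behavior the commit message describes (a stale `_prefetched` flag causing a subsequent access to skip unsharding).

```python
class FakeHandle:
    """Toy model of the prefetch/unshard/reshard interaction.

    Simplified illustration only -- not PyTorch's FlatParameterHandle.
    """

    def __init__(self):
        self._prefetched = False  # an unshard was already issued ahead of time
        self._unsharded = False   # whether the full parameter is materialized

    def unshard(self):
        # Guarded against an unnecessary all-gather: no-op if already unsharded.
        if self._unsharded:
            return
        self._unsharded = True  # stands in for the actual all-gather

    def prefetch(self):
        # Unshard early and record that the upcoming access is covered.
        self.unshard()
        self._prefetched = True

    def pre_forward(self):
        # On parameter access, skip unsharding only if a prefetch covered it.
        if not self._prefetched:
            self.unshard()

    def reshard(self):
        self._unsharded = False
        # The fix from this commit, modeled here: clear the prefetch flag so
        # the next access does not wrongly skip unshard. Without this line,
        # a stale flag leaves the parameter sharded (wrong shape) on next use.
        self._prefetched = False
```

In this model, dropping the `self._prefetched = False` line in `reshard` reproduces the bug: after a prefetch followed by a reshard, the next `pre_forward` would skip `unshard` and operate on the sharded parameter.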