pytorch
5caf2e55 - [FSDP] fix: fix for fsdp zero2 validation error (#110139)

Commit
1 year ago
[FSDP] fix: fix for fsdp zero2 validation error (#110139) # Problem When sharding_strategy is set to SHARD_GRAD_OP and forward_prefetch is turned on, the validation after the train has an incorrect weight shape. <img width="1508" alt="image" src="https://github.com/pytorch/pytorch/assets/41232043/57a9c3bb-cb5c-46df-ac26-922740686f9e"> # Analyze When using `SHARD_GRAD_OP`, the `free_unsharded_flat_param` in `_post_forward_reshard` is often False, so it does not set the handle's `_prefetched` flag to False after the forward. The normal train phase sets this flag to False in the `_post_backward_final_callback`, and the validation phase doesn't execute the hook, so after the first iter of the validation is done, the flag of the handle of the prefetched will remain True. This will cause the handle to skip the `_unshard` in the next `_pre_forward_unshard`, and the `_prefetch_handle` will not do a prefetch, which will result in an incorrect weight shape. Pull Request resolved: https://github.com/pytorch/pytorch/pull/110139 Approved by: https://github.com/awgu
Author
Committer
Parents
Loading