pytorch
45b40be0 - [FSDP()] Fix `fully_shard` fwd hook registration (#90201)

Commit
2 years ago
I need to rebase later after Shen's PRs land.

The idea is to only register the pre/post-forward hook on the _root modules_ among the modules that consume a `FlatParameter`. (Yes, the term _root module_ is heavily overloaded. We may want to clarify that at some point. Here, _root_ is being used in the graph sense, meaning parent-less, and the scope is only among the modules consuming a `FlatParameter`.)

This avoids unnecessary pre/post-forward hooks running, which would lead to errors because the unshard is not truly idempotent.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90201
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
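To illustrate the graph-sense meaning of _root_ described above, here is a minimal pure-Python sketch. It is not the FSDP implementation: `Node`, `find_roots`, and the module names are hypothetical stand-ins, and the sketch only shows how one might pick the parent-less modules within a set of consumers, which are the ones that would receive the pre/post-forward hooks.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Set


@dataclass(eq=False)  # eq=False keeps identity-based hashing so nodes can go in sets
class Node:
    """Hypothetical stand-in for a module in a module tree (not torch.nn.Module)."""
    name: str
    children: List["Node"] = field(default_factory=list)


def descendants(node: Node) -> Iterator[Node]:
    """Yield every node strictly below `node` in the tree."""
    for child in node.children:
        yield child
        yield from descendants(child)


def find_roots(consumers: Set[Node]) -> List[Node]:
    """Return the consumers that are not a descendant of any other consumer."""
    non_roots = set()
    for node in consumers:
        for below in descendants(node):
            if below in consumers:
                non_roots.add(below)
    return sorted((n for n in consumers if n not in non_roots), key=lambda n: n.name)


# Assumed example tree: model -> (block -> (linear1, linear2), head)
linear1 = Node("block.linear1")
linear2 = Node("block.linear2")
block = Node("block", [linear1, linear2])
head = Node("head")
model = Node("model", [block, head])

# Case 1: a parent and its children all consume the same flat parameter.
# Only the parent is parent-less within the set, so only it would get hooks.
roots_nested = [n.name for n in find_roots({block, linear1, linear2})]

# Case 2: the consumers are siblings or unrelated, so each one is a root.
roots_flat = [n.name for n in find_roots({linear1, linear2, head})]
```

In the first case a single hook pair on `block` would cover the whole group, so the nested modules would not trigger a second unshard during the same forward pass.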