e667c006 - [FSDP()][2/N] Refactor training state (#87916)

[FSDP()][2/N] Refactor training state (#87916)

This PR actually has meaningful changes. We stratify `TrainingState` into two levels: one per FSDP instance and one per `FlatParamHandle`/`FlatParameter`.

- At the FSDP instance level, we only care about `IDLE`, FSDP computation (i.e. `FORWARD_BACKWARD`), or `SUMMON_FULL_PARAMS`. These dynamically modify behavior (e.g. `summon_full_params()` forces full precision).
- At the `FlatParamHandle` level, we care about the training state for invariants and debugging. Hence, we keep `IDLE`, `FORWARD`, `BACKWARD_PRE`, `BACKWARD_POST`, and `SUMMON_FULL_PARAMS`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87916
Approved by: https://github.com/mrshenli
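A minimal sketch of the two-level split described above. The commit message only names `TrainingState`; the handle-level enum name `HandleTrainingState` is an assumption for illustration, not necessarily the identifier used in the actual change:

```python
import enum


class TrainingState(enum.Enum):
    # Per-FSDP-instance state: coarse states that dynamically modify
    # behavior (e.g. summon_full_params() forces full precision).
    IDLE = enum.auto()
    FORWARD_BACKWARD = enum.auto()  # any FSDP forward/backward computation
    SUMMON_FULL_PARAMS = enum.auto()


# NOTE: this enum name is a hypothetical stand-in for the per-handle state.
class HandleTrainingState(enum.Enum):
    # Per-FlatParamHandle state: finer-grained states kept for
    # checking invariants and for debugging.
    IDLE = enum.auto()
    FORWARD = enum.auto()
    BACKWARD_PRE = enum.auto()
    BACKWARD_POST = enum.auto()
    SUMMON_FULL_PARAMS = enum.auto()
```

Under this split, an FSDP instance only needs to distinguish the coarse phases that change behavior, while each handle tracks the exact forward/backward stage so that state-transition invariants can be asserted precisely.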