pytorch
7ec7a820 - Test FSDP with submodule non-reentrant checkpointing (#89781)

Comment changes are shownComment changes are hidden
Commit
2 years ago
Test FSDP with submodule non-reentrant checkpointing (#89781) With combining FSDP with reentrant checkpointing, the post backward hook might run twice, and then hit [this error](https://github.com/pytorch/pytorch/blob/e20ec44544c17d6d3d411f88b870e05043bda731/torch/distributed/fsdp/_runtime_utils.py#L487). This is because reentrant backward uses nested autograd GraphTasks. The inner GraphTask is not aware of the outer one and therefore will flush pending `AccumulateGrad` invocations on exit, which in turn triggers the post backward hooks registered by FSDP. Later, the outer GraphTask will trigger that again, leading to the above error. PR #89791 relaxes the FSDP training state check, but we still run into grad value check failures occasionally. Therefore, this PR only lands the test for non-reentrant test, and we can enable the reentrant test when the accuracy issues are addressed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89781 Approved by: https://github.com/rohan-varma
Author
Committer
Parents
  • test/distributed/fsdp
    • File
      test_fsdp_checkpoint.py
Loading