pytorch
4011597d - [Composable API] Refactor `test_fully_shard.py` to use common models (#90386)

Commit
2 years ago
[Composable API] Refactor `test_fully_shard.py` to use common models (#90386) Unlike for FSDP, where we already diverged to using per-test-file models, let us try to use the same set of models for the composable API effort. This can improve debugging efficiency because we know which module structures we support and which we do not _across all of our composable APIs_. This PR had to perform some surgery for `test_materialize_meta_module`. Writing a correct parameter initialization function for meta device initialization is not easy, and we should revisit this. The old implementation, which followed the style of the previous unit tests--namely, using `module.to_empty()`--is actually incorrect for nested FSDP applications because `module.to_empty()` will re-initialize already materialized parameters and the module materialization proceeds bottom up. The existing unit test in `test_fsdp_meta.py` passes because it sets every parameter to ones (`self.weight.fill_(1)`), which is idempotent to re-initialization. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90386 Approved by: https://github.com/mrshenli
Author
Andrew Gu
Committer
Parents
Loading