pytorch
b2116f58 - Port FSDP::summon_full_params from fairscale to pytorch. (#71225)

Commit
4 years ago
Port FSDP::summon_full_params from fairscale to pytorch. (#71225) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71225 Bring FSDP::summon_full_params from fairscale. It doesn't support summoning with full precision as this mode is not yet supported by other parts of PT's FSDP port. One thing that needs figuring out is the semantics we want W.R.T. reshard_after_forward. Right now I'm always discarding the full tensor at the end of _summon_full_params. Fixes: https://github.com/pytorch/pytorch/issues/69779 Test Plan: Ported the fairscale tests plus added a few more. Reviewed By: zhaojuanmao, rohan-varma Differential Revision: D33350378 fbshipit-source-id: d826b7cc1762baa1e6a820651beb715c6428482a (cherry picked from commit 23c78adda226e57528b3c48238f35ca55d04ba05)
Author
Rodrigo Kumpera
Committer
Parents
Loading