Port FSDP::summon_full_params from fairscale to pytorch. (#71225)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71225
Bring FSDP::summon_full_params over from fairscale.
Summoning with full precision is not supported, because that mode is not yet supported by other parts of PyTorch's FSDP port.
One open question is what semantics we want with respect to reshard_after_forward. Right now I always discard the full tensor at the end of _summon_full_params, regardless of that setting.
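For illustration only, the "gather on entry, always discard on exit" behaviour can be sketched as a toy context manager. This is not the FSDP implementation: ToyShardedParam and toy_summon_full_params are hypothetical names, plain Python lists stand in for tensors, and the all-gather is simulated by concatenating shards that are already in memory. Writing modifications back to the local shard is not modelled.

```python
from contextlib import contextmanager


class ToyShardedParam:
    """Hypothetical stand-in for one rank's view of a sharded parameter."""

    def __init__(self, shards, rank):
        # All ranks' shards; stands in for what an all-gather would fetch.
        self.shards = shards
        self.rank = rank
        # Outside the context manager, only the local shard is materialized.
        self.data = list(shards[rank])


@contextmanager
def toy_summon_full_params(param):
    """Expose the full (unsharded) value inside the with-block only."""
    local_shard = param.data
    # Simulated all-gather: concatenate every rank's shard in rank order.
    param.data = [x for shard in param.shards for x in shard]
    try:
        yield param
    finally:
        # Always discard the full value on exit, even if the body raised.
        # This mirrors the "always discard" choice described above; a real
        # implementation might instead consult reshard_after_forward.
        param.data = local_shard
```

Inside the with-block, param.data holds the concatenation of all shards; after the block it holds only the local shard again.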
Fixes: https://github.com/pytorch/pytorch/issues/69779
Test Plan: Ported the fairscale tests and added a few more.
Reviewed By: zhaojuanmao, rohan-varma
Differential Revision: D33350378
fbshipit-source-id: d826b7cc1762baa1e6a820651beb715c6428482a
(cherry picked from commit 23c78adda226e57528b3c48238f35ca55d04ba05)