[FSDP] full_state_dict impl (#73324)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73324
Implements `state_dict` and `load_state_dict` APIs for FSDP, with the following limitations:
1. Does not support `state_dict_device` (i.e. specifying which device params should be on) which fairscale does currently support
2. Does not yet support offload of state_dict onto CPU
3. Loads state_dict on all ranks currently. In the future we could add support for loading this on only rank 0, to avoid redundancy across ranks as usually only one rank is responsible for saving/loading the model. Along with (2) this would enable larger models to have state_dict called.
As discussed in FSDP checkpoint API proposal, `state_dict` will basically be a `full_state_dict` where full parameters are returned on all ranks. As a result this implies that the model must actually be able to fit on a single GPU.
ghstack-source-id: 150012240
Test Plan: ci
Reviewed By: zhaojuanmao
Differential Revision: D34433514
fbshipit-source-id: 3eb1d679b2236264f9f423e761d1720f9aaec73a
(cherry picked from commit a451d5a08ebfa14a229a25fea35b9ca59fe91a59)