[WIP][FSDP] Mixed precision enablement (#74452)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74452
Useful clarifications while reviewing the diff:
How fairscale implements MP for buffers:
- Accepts a buffer_dtype argument in the ctor, which is the dtype used for buffer computation. By default this is the compute_dtype.
- During _lazy_init, for the root module, _cast_buffers is called, which casts buffers to buffer_dtype.
- During state_dict, buffers are cast to torch.float32, the checkpoint is taken, and the buffers are then restored to buffer_dtype (see the sketch after this list).
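A minimal sketch of that flow is below. This is illustrative only, not fairscale's actual code; the wrapper class and the exact casting logic are assumptions based on the description above:

```python
import torch
import torch.nn as nn


class FairscaleStyleBufferMP(nn.Module):
    """Sketch of fairscale-style buffer mixed precision (illustrative only)."""

    def __init__(self, module: nn.Module, compute_dtype=torch.float16, buffer_dtype=None):
        super().__init__()
        self.module = module
        self.compute_dtype = compute_dtype
        # By default buffers are computed in compute_dtype.
        self.buffer_dtype = buffer_dtype if buffer_dtype is not None else compute_dtype

    def _cast_buffers(self, dtype: torch.dtype) -> None:
        for buf in self.module.buffers():
            if torch.is_floating_point(buf):
                buf.data = buf.data.to(dtype=dtype)

    def _lazy_init(self) -> None:
        # Root module casts buffers to buffer_dtype before the first forward.
        self._cast_buffers(self.buffer_dtype)

    def state_dict(self, *args, **kwargs):
        # Checkpoint buffers in full precision, then restore buffer_dtype.
        self._cast_buffers(torch.float32)
        sd = super().state_dict(*args, **kwargs)
        self._cast_buffers(self.buffer_dtype)
        return sd
```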
How PT FSDP implements MP for buffers in this diff:
- Rather than a buffer_dtype ctor argument, we accept MixedPrecision.buffer_dtype, which is the compute dtype for buffers.
- During _lazy_init, similar to fairscale, we cast the buffers to the dtype given by the MP config. When no mixed precision is configured, the default behavior is maintained.
- In _cast_buffers, we record the original buffer dtype in a member variable, and then cast the buffers to a new dtype if the user specified one.
- During state_dict, we restore the buffers to the remembered dtype (stored as self._orig_buffer_dtype) before the checkpoint is taken. After state_dict, we cast the buffers back to the mixed precision dtype, since buffers remain in that dtype even after the forward/backward passes (this is done for consistency).
- The improvement here is that we remember and restore the dtype the model's buffers originally had. However, we currently assume all buffers share the same dtype; this assumption can be relaxed depending on use cases. A minimal sketch of this bookkeeping follows this list.
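The sketch below assumes a MixedPrecision config carrying a buffer_dtype field and the self._orig_buffer_dtype member described above; the wrapper class and helper signatures are illustrative, not the actual FSDP implementation:

```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class MixedPrecision:
    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None


class FSDPBufferMPSketch(nn.Module):
    """Sketch of the buffer dtype bookkeeping in this diff (illustrative only)."""

    def __init__(self, module: nn.Module, mixed_precision: Optional[MixedPrecision] = None):
        super().__init__()
        self.module = module
        self.mixed_precision = mixed_precision
        self._orig_buffer_dtype: Optional[torch.dtype] = None

    def _cast_buffers(self, dtype: Optional[torch.dtype]) -> None:
        for buf in self.module.buffers():
            if self._orig_buffer_dtype is None:
                # Remember the original dtype (assumes all buffers share one dtype).
                self._orig_buffer_dtype = buf.dtype
            if dtype is not None and torch.is_floating_point(buf):
                buf.data = buf.data.to(dtype=dtype)

    def _lazy_init(self) -> None:
        # With an MP config, cast buffers to buffer_dtype; otherwise keep the
        # default behavior (no cast).
        buffer_dtype = self.mixed_precision.buffer_dtype if self.mixed_precision else None
        self._cast_buffers(buffer_dtype)

    def state_dict(self, *args, **kwargs):
        # Restore the original dtype for the checkpoint, then cast back so
        # buffers stay in the mixed precision dtype after state_dict.
        self._cast_buffers(self._orig_buffer_dtype)
        sd = super().state_dict(*args, **kwargs)
        if self.mixed_precision is not None:
            self._cast_buffers(self.mixed_precision.buffer_dtype)
        return sd
```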
Why rebuild_full_params checks for summon_full_params training state:
- summon_full_params needs to return the full module parameters in the original precision for checkpointing to work as expected (users generally don't want to checkpoint the fp16 params). Thus, _rebuild_full_params performs this check. This is exactly the same reasoning as the "force_full_precision" arg in fairscale.
- Concretely, if we're in summon_full_params, we 1) don't cast shards to param_dtype, and 2) all_gather with full precision inputs into a full precision output, rather than into _full_param_padded (see the sketch below).
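A hedged sketch of that check, written as a free function rather than the actual _rebuild_full_params method; the TrainingState enum, tensor names, and exact collective call are stand-ins for the real FSDP internals (and assume an initialized process group):

```python
import enum
from typing import Optional

import torch
import torch.distributed as dist


class TrainingState(enum.Enum):
    IDLE = enum.auto()
    FORWARD = enum.auto()
    SUMMON_FULL_PARAMS = enum.auto()


def rebuild_full_param_sketch(
    local_shard: torch.Tensor,
    full_param_padded: torch.Tensor,  # preallocated in param_dtype (e.g. fp16)
    training_state: TrainingState,
    param_dtype: Optional[torch.dtype],
    process_group=None,
) -> torch.Tensor:
    # Inside summon_full_params we force full precision, mirroring fairscale's
    # force_full_precision: skip the cast to param_dtype and all-gather into a
    # fresh full precision buffer instead of the low precision full_param_padded.
    force_full_precision = training_state == TrainingState.SUMMON_FULL_PARAMS
    if force_full_precision or param_dtype is None:
        shard = local_shard  # 1) keep the shard in its original precision
        output = local_shard.new_empty(  # 2) fresh full precision output buffer
            local_shard.numel() * dist.get_world_size(process_group)
        )
    else:
        shard = local_shard.to(param_dtype)
        output = full_param_padded
    dist.all_gather_into_tensor(output, shard, group=process_group)
    return output
```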
Test coverage:
[ ] Test1
ghstack-source-id: 152654758
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D35000703
fbshipit-source-id: 4bd7937ff36bdb3afd60eda981afc9d8731b823a
(cherry picked from commit 6ed6721aaf18f323656686200465fc78cef1d0dd)