fully_shard load state_dict (#90945)
Ensures that load_state_dict for fully_shard works:
- Don't add back FSDP prefix
- Small fix to ensure the mixed precision check for buffers works
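
For illustration only (not the code in this PR): a minimal sketch of why adding the wrapper prefix back breaks loading when the module is not wrapped. It assumes the wrapper-based FSDP prefix is `_fsdp_wrapped_module.`; the helper names `find_key_mismatches` and `add_prefix` are hypothetical.

```python
# Hypothetical illustration; plain Python, no torch needed.
WRAPPER_PREFIX = "_fsdp_wrapped_module."  # assumed prefix of the wrapper-based FSDP


def find_key_mismatches(module_keys, checkpoint_keys):
    """Return (missing, unexpected) keys, similar to what load_state_dict reports."""
    module_keys, checkpoint_keys = set(module_keys), set(checkpoint_keys)
    missing = sorted(module_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - module_keys)
    return missing, unexpected


def add_prefix(keys, prefix):
    """Prepend a prefix to every key, as a prefix-restoring hook would."""
    return [prefix + key for key in keys]


# With fully_shard the module is not wrapped, so its keys carry no prefix.
module_keys = ["linear.weight", "linear.bias"]
checkpoint_keys = ["linear.weight", "linear.bias"]

# Keys left untouched: everything matches.
clean = find_key_mismatches(module_keys, checkpoint_keys)

# Prefix added back: every key is both missing and unexpected.
prefixed = find_key_mismatches(module_keys, add_prefix(checkpoint_keys, WRAPPER_PREFIX))
```

Here `clean` has no mismatches, while `prefixed` reports every parameter as missing under its plain name and unexpected under its prefixed name.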
Follow-ups:
- state_dict_type does not work yet, which blocks rank0_only and CPU offload as well as the other state dict implementations
- No tests yet for wrapping with AC (activation checkpointing), mixed precision, integration with distributed checkpoint, etc.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90945
Approved by: https://github.com/awgu