Allow optimizer state conversion to accommodate optimizers that have no tensor state (e.g. SGD) (#111501)
Fixes #111499
This PR slightly alters the new fused `all_gather` `optim_state_dict` implementation to support optimizers without tensor state (e.g. SGD) in a `use_orig_params=True` context.
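For context, here's a minimal runnable illustration of the difference (plain `SGD` with `momentum=0` versus `Adam`):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
sgd = torch.optim.SGD(model.parameters(), lr=0.1)   # momentum=0: no per-param state
adam = torch.optim.Adam(model.parameters(), lr=0.1)

model(torch.randn(2, 4)).sum().backward()
sgd.step()
adam.step()

# SGD without momentum tracks no per-parameter state at all, so there
# are no tensors for the fused all_gather path to gather.
print(sgd.state_dict()["state"])                     # {}
# Adam tracks exp_avg/exp_avg_sq tensors plus a scalar step.
print(sorted(adam.state_dict()["state"][0].keys()))  # ['exp_avg', 'exp_avg_sq', 'step']
```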
The principal change is to short-circuit `_allgather_orig_param_states` if an empty `state_buffers` dict is returned after completing `_convert_all_state_info` here:
https://github.com/pytorch/pytorch/blob/93e5065ba0a16db280157a2a28702b684fca3bb1/torch/distributed/fsdp/_optim_utils.py#L1481-L1484
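Sketched minimally (names mirror the source, but the body is simplified and the actual gather logic is elided; a companion sketch of `_convert_all_state_info` follows below):

```python
from typing import Any, Dict

# Simplified sketch of the short-circuit, not the actual FSDP source:
# _convert_all_state_info populates output_states with scalar/non-tensor
# entries and returns buffers only for positive-dimension tensor state.
def _allgather_orig_param_states_sketch(
    input_states: Dict[str, Dict[str, Any]],
) -> Dict[str, Dict[str, Any]]:
    output_states: Dict[str, Dict[str, Any]] = {fqn: {} for fqn in input_states}
    dtype, state_buffers = _convert_all_state_info_sketch(input_states, output_states)

    # The fix: optimizers like SGD yield an empty state_buffers, so there
    # is nothing to all-gather and the converted states can be returned
    # now instead of proceeding with dtype=None and failing downstream.
    if len(state_buffers) == 0:
        return output_states

    # ... flatten state_buffers, all-gather across ranks, unflatten (elided)
    return output_states
```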
To allow `_convert_all_state_info` to accommodate optimizers with no tensor state, I also hoist `dtype` to function scope and make the return type `Optional`, since `dtype` is only assigned when a tensor state is encountered.
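A companion sketch of that change (again simplified; the real function consumes `StateInfo` gathered from all ranks):

```python
from typing import Any, Dict, List, Optional, Tuple
import torch

# Simplified sketch: dtype now lives at function scope and is returned as
# Optional, since it stays None for optimizers with no tensor state.
def _convert_all_state_info_sketch(
    input_states: Dict[str, Dict[str, Any]],
    output_states: Dict[str, Dict[str, Any]],
) -> Tuple[Optional[torch.dtype], Dict[str, List[torch.Tensor]]]:
    dtype: Optional[torch.dtype] = None
    state_buffers: Dict[str, List[torch.Tensor]] = {}
    for fqn, states in input_states.items():
        for name, value in states.items():
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                dtype = value.dtype  # set only when tensor state exists
                state_buffers.setdefault(name, []).append(value)
            else:
                # Scalar and non-tensor state passes straight through.
                output_states[fqn][name] = value
    # For SGD: dtype is None and state_buffers is empty.
    return dtype, state_buffers
```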
As discussed in the issue this PR fixes, I'm [extending](https://github.com/pytorch/pytorch/blob/93e5065ba0a16db280157a2a28702b684fca3bb1/test/distributed/fsdp/test_fsdp_optim_state.py#L1915) `test_state_dict_with_none_tensor_state` to test with both Adam and SGD optimizers, validating that scalar and non-tensor states continue to be restored for both optimizer types.
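Roughly, the extended test exercises this flow (a hedged sketch assuming an initialized process group; the state keys and helper are illustrative, not the test's actual code):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def check_non_tensor_state_round_trip(device="cuda"):
    # Both optimizer types must preserve scalar and None-valued state
    # through the fused optim_state_dict path.
    for optimizer_cls in (torch.optim.Adam, torch.optim.SGD):
        model = FSDP(nn.Linear(8, 8).to(device), use_orig_params=True)
        optim = optimizer_cls(model.parameters(), lr=1e-2)
        model(torch.randn(2, 8, device=device)).sum().backward()
        optim.step()
        for p in model.parameters():
            optim.state[p]["custom_scalar"] = 2.74  # illustrative keys
            optim.state[p]["custom_none"] = None
        osd = FSDP.optim_state_dict(model, optim)
        for state in osd["state"].values():
            assert state["custom_scalar"] == 2.74
            assert state["custom_none"] is None
```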
Thanks to the distributed team as always for their adroit design and exceptionally valuable contributions to the open source ML community. Hope you all feel appreciated commensurate with the compounding progress your work enables.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111501
Approved by: https://github.com/fegin