Fix optim state_dict casting to allow step to cast to CPU (#102619)
I'm guessing this should fix https://github.com/pytorch/pytorch/pull/88015#issuecomment-1569523106 but am waiting on @ychfan to supply more details so I could write a good test case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102619
Approved by: https://github.com/albanD