Update type hints for torch.optim.optimizer.Optimizer (#32900)
Summary:
This PR fixes the type hints for the `torch.optim.optimizer.Optimizer` class; the issue was also reported in https://github.com/pytorch/pytorch/issues/23731
To test the change I used the following optimizer implementation, which is fully covered with type hints:
```python
from typing import Optional, Callable, Union, Iterable
from torch import Tensor
from torch.optim.optimizer import Optimizer
OptClosure = Optional[Callable[[], float]]
_params_t = Union[Iterable[Tensor], Iterable[dict]]
class SGD(Optimizer):
    def __init__(self, params: _params_t, lr: float = 0.1) -> None:
        defaults = dict(lr=lr)
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state: dict) -> None:
        super(SGD, self).__setstate__(state)

    def step(self, closure: OptClosure = None) -> Optional[float]:
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                p.data.add_(-group['lr'], d_p)
        return loss
```
Without the fix, `mypy` reports a number of type inconsistencies and missing attributes:
```bash
$ mypy torch_optimizer/sgd.py
torch_optimizer/sgd.py:14: error: Too many arguments for "__init__" of "Optimizer"
torch_optimizer/sgd.py:17: error: "__setstate__" undefined in superclass
torch_optimizer/sgd.py:19: error: Return type "Optional[float]" of "step" incompatible with return type "None" in supertype "Optimizer"
torch_optimizer/sgd.py:24: error: "SGD" has no attribute "param_groups"
Found 4 errors in 1 file (checked 1 source file)
```
With the fix, no issues are reported:
```bash
$ mypy torch_optimizer/sgd.py
Success: no issues found in 1 source file
```
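For reference, the four errors above point at what the `Optimizer` stub needs to declare: a `defaults` parameter on `__init__`, a `__setstate__` method, a `step` that may return `Optional[float]`, and a `param_groups` attribute. A minimal, illustrative sketch of such stub declarations is below; the names and signatures are inferred from the errors, not the exact contents of the merged stub, and `Tensor` is aliased to `Any` here only so the sketch is self-contained:

```python
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

# Placeholder so this sketch stands alone; a real stub would
# import Tensor from torch instead.
Tensor = Any

_params_t = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]

class Optimizer:
    # Declaring the attribute fixes: "SGD" has no attribute "param_groups"
    param_groups: List[dict]

    # Accepting `defaults` fixes: Too many arguments for "__init__"
    def __init__(self, params: _params_t, defaults: dict) -> None: ...

    # Declaring the method fixes: "__setstate__" undefined in superclass
    def __setstate__(self, state: dict) -> None: ...

    # Optional[float] makes subclass overrides of `step` compatible
    def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]: ...
```

With declarations along these lines, the subclass in the example type-checks: the `super().__init__(params, defaults)` call, the `__setstate__` override, the `step` return type, and the `self.param_groups` access all match the base class.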
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32900
Differential Revision: D19697175
Pulled By: ezyang
fbshipit-source-id: d5e2b3c421f69da3df8c32b3d53b4b6d15d61a41