pytorch
e87887cc - Update type hints for torch.optim.optimizer.Optimizer (#32900)

Commit
6 years ago
Update type hints for torch.optim.optimizer.Optimizer (#32900) Summary: This PR fixes type hints for `torch.optim.optimizer.Optimizer` object, issue also reported in https://github.com/pytorch/pytorch/issues/23731 To test things I used following optimiser implementation, that is fully covered with type hints: ```python from typing import Optional, Callable, Union, Iterable from torch import Tensor from torch.optim.optimizer import Optimizer OptClosure = Optional[Callable[[], float]] _params_t = Union[Iterable[Tensor], Iterable[dict]] class SGD(Optimizer): def __init__(self, params: _params_t, lr: float = 0.1) -> None: defaults = dict(lr=lr) super(SGD, self).__init__(params, defaults) def __setstate__(self, state: dict) -> None: super(SGD, self).__setstate__(state) def step(self, closure: OptClosure = None) -> Optional[float]: loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue d_p = p.grad.data p.data.add_(-group['lr'], d_p) return loss ``` Without fix `mypy` reports bunch of inconsistencies in types and missing properties: ```bash $ mypy torch_optimizer/sgd.py torch_optimizer/sgd.py:14: error: Too many arguments for "__init__" of "Optimizer" torch_optimizer/sgd.py:17: error: "__setstate__" undefined in superclass torch_optimizer/sgd.py:19: error: Return type "Optional[float]" of "step" incompatible with return type "None" in supertype "Optimizer" torch_optimizer/sgd.py:24: error: "SGD" has no attribute "param_groups" Found 4 errors in 1 file (checked 1 source file) ``` with fix not issues: ```bash $ mypy torch_optimizer/sgd.py Success: no issues found in 1 source file ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/32900 Differential Revision: D19697175 Pulled By: ezyang fbshipit-source-id: d5e2b3c421f69da3df8c32b3d53b4b6d15d61a41
Author
Parents
Loading