pytorch
a26e5e21 - Improve type hints for Module forward hooks (#92061)

Commit
2 years ago
Improve type hints for Module forward hooks (#92061) Fixes #91654. Currently, the `hook` parameters of `nn.Module.register_forward_pre_hook` and `nn.Module.register_forward_hook` are typed as `Callable[..., None]`, which 1) does not enable the validation of the signature of `hook` and 2) incorrectly restricts the return type of `hook`, which the docstrings of these methods themselves state can be non-`None`. The typing of the first parameter of `hook` as `TypeVar("T", bound="Module")` allows the binding of `Callable` whose first parameter is a subclass of `Module`. --- Here are some examples of: 1. forward hooks and pre-hook hooks being accepted by mypy according to the new type hints 2. mypy throwing errors d.t. incorrect `hook` signatures 3. false negatives of pre-hooks being accepted as forward hooks 4. false negatives of hooks with kwargs being accepted irrespective of the value provided for `with_kwargs` ```python from typing import Any, Dict, Tuple import torch from torch import nn def forward_pre_hook( module: nn.Linear, args: Tuple[torch.Tensor, ...], ) -> None: ... def forward_pre_hook_return_input( module: nn.Linear, args: Tuple[torch.Tensor, ...], ) -> Tuple[torch.Tensor, ...]: ... def forward_pre_hook_with_kwargs( module: nn.Linear, args: Tuple[torch.Tensor, ...], kwargs: Dict[str, Any], ) -> None: ... def forward_pre_hook_with_kwargs_return_input( module: nn.Linear, args: Tuple[torch.Tensor, ...], kwargs: Dict[str, Any], ) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]: ... def forward_hook( module: nn.Linear, args: Tuple[torch.Tensor, ...], output: torch.Tensor, ) -> None: ... def forward_hook_return_output( module: nn.Linear, args: Tuple[torch.Tensor, ...], output: torch.Tensor, ) -> torch.Tensor: ... def forward_hook_with_kwargs( module: nn.Linear, args: Tuple[torch.Tensor, ...], kwargs: Dict[str, Any], output: torch.Tensor, ) -> None: ... def forward_hook_with_kwargs_return_output( module: nn.Linear, args: Tuple[torch.Tensor, ...], kwargs: Dict[str, Any], output: torch.Tensor, ) -> torch.Tensor: ... model = nn.Module() # OK model.register_forward_pre_hook(forward_pre_hook) model.register_forward_pre_hook(forward_pre_hook_return_input) model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=True) model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=True) model.register_forward_hook(forward_hook) model.register_forward_hook(forward_hook_return_output) model.register_forward_hook(forward_hook_with_kwargs, with_kwargs=True) model.register_forward_hook(forward_hook_with_kwargs_return_output, with_kwargs=True) # mypy(error): [arg-type] model.register_forward_pre_hook(forward_hook) model.register_forward_pre_hook(forward_hook_return_output) model.register_forward_pre_hook(forward_hook_with_kwargs) model.register_forward_pre_hook(forward_hook_with_kwargs_return_output) model.register_forward_hook(forward_pre_hook) model.register_forward_hook(forward_pre_hook_return_input) # false negatives model.register_forward_hook(forward_pre_hook_with_kwargs) model.register_forward_hook(forward_pre_hook_with_kwargs_return_input) model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=False) model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=False) ... ``` --- Though it is not functional as of mypy 0.991, the ideal typing of these methods would use [`typing.Literal`](https://mypy.readthedocs.io/en/stable/literal_types.html#literal-types): ```python T = TypeVar("T", bound="Module") class Module: @overload def register_forward_hook( self, hook: Callable[[T, Tuple[Any, ...], Any], Optional[Any]], *, prepend: bool = ..., with_kwargs: Literal[False] = ..., ) -> RemovableHandle: ... @overload def register_forward_hook( self, hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]], *, prepend: bool = ..., with_kwargs: Literal[True] = ..., ) -> RemovableHandle: ... def register_forward_hook(...): ... ``` which would: 1. validate the signature of `hook` according to the corresponding literal value provided for `with_kwargs` (and fix the false negative examples above) 2. implicitly define the [fallback `bool` signature](https://github.com/python/mypy/issues/6113#issuecomment-1266186192) e.g. to handle if a non-literal is provided for `with_kwargs` Pull Request resolved: https://github.com/pytorch/pytorch/pull/92061 Approved by: https://github.com/albanD
Author
Committer
Parents
Loading