pytorch
dc7066a8 - Add support for multiple inputs to out_wrapper and strict dtype checking (#79941)

Commit
2 years ago
Add support for multiple inputs to out_wrapper and strict dtype checking (#79941) When a function returns multiple parameters in PyTorch, the `out` parameter takes a tuple of tensors (see `linalg.svd` for example). The current implementation in `out_wrapper_multi` modelled this wrong, as it assumed that it would take a number of different named parameters. This PR implements the correct behaviour in `out_wrapper`. As a small side-effect, we now need to call `@out_wrapper()` when the output is just one tensor. This PR also implements an additional optional parameter that checks whether the dtype of the given `out` is exactly the dtype that the meta function requires. This is the behaviour that we currently have in PyTorch, and this check is necessary in eager when we call with these tensors into external libraries. We also make the functions with several outputs return a namedtuple, similar to what we do in PyTorch. Pull Request resolved: https://github.com/pytorch/pytorch/pull/79941 Approved by: https://github.com/mruberry, https://github.com/ezyang
Author
Committer
Parents
Loading