0a1bc5f5 - Miscellaneous __torch_function__ fixes

I figured these out by unconditionally turning on a no-op torch function mode in the test suite and then fixing errors as they showed up (a minimal sketch of such a mode follows the list). Here's what I found:

- _parse_to failed an internal assert when __torch_function__'ed, because it reports its name as "to" to the argument parser; added a name override so we know how to find the correct name.
- Infix operator magic methods on Tensor did not uniformly handle __torch_function__ and the conversion of TypeError to NotImplemented. Now we always do the __torch_function__ handling in _wrap_type_error_to_not_implemented, and your implementation of __torch_function__ gets its TypeErrors converted to NotImplemented (for better or for worse; see https://github.com/pytorch/pytorch/issues/75462). A sketch of the resulting fallback behavior appears after this list.
- A few places (in grad and in distributions) tested whether an object was Tensor-like in the wrong way; they now use is_tensor_like. Also updated the has_torch_function docs to push people toward is_tensor_like (see the comparison below).
- is_grads_batched was silently dropped from grad in handle_torch_function; now fixed (a usage sketch follows below).
- Report that you have a torch function, even when torch function is disabled, if a mode is enabled. This makes it possible for a mode to return NotImplemented, pass control to a subclass that does some processing, and then pass back to the mode even after the subclass disables __torch_function__ (so the tensors are treated "as if" they were regular Tensors). This brings the C++ handling in line with the Python behavior.
- Make the Python implementation of the overloaded-types computation match the C++ version: when torch function is disabled, there are no overloaded types (because they all report they are not overloaded).
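To make the debugging setup above concrete, here is a minimal sketch of a no-op torch function mode, assuming the torch.overrides.TorchFunctionMode API; the class name NoopMode is invented for this sketch, and the exact mode API has shifted across PyTorch releases, so treat this as approximate rather than as the harness actually used:

```python
import torch
from torch.overrides import TorchFunctionMode

class NoopMode(TorchFunctionMode):
    # A mode that intercepts every torch function call and simply
    # re-dispatches it, changing nothing. Running a test suite under
    # such a mode surfaces code paths that mishandle __torch_function__.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        return func(*args, **kwargs)

with NoopMode():
    x = torch.randn(3)
    # Before this fix, Tensor.to's argument parsing (_parse_to) could hit
    # an internal assert under a mode, because it reports its name as "to".
    y = x.to(torch.float64)
```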
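The infix-operator change can be illustrated with a hypothetical Tensor subclass whose __torch_function__ raises TypeError; the names Picky and Fallback are invented for this sketch. With the fix, the TypeError is converted to NotImplemented, so Python falls back to the other operand's reflected method:

```python
import torch

class Picky(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # With this change, a TypeError raised here during an infix
        # operator is converted to NotImplemented by
        # _wrap_type_error_to_not_implemented, so Python goes on to try
        # the reflected method on the other operand.
        raise TypeError("Picky cannot handle this operation")

class Fallback:
    def __radd__(self, other):
        return "handled by Fallback.__radd__"

t = torch.randn(2).as_subclass(Picky)
print(t + Fallback())  # expected: "handled by Fallback.__radd__"
```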
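For the Tensor-like checks, the point is that isinstance(x, torch.Tensor) misses objects that merely implement __torch_function__, while torch.overrides.is_tensor_like catches them. A small comparison sketch, using the hypothetical class DuckTensor:

```python
import torch
from torch.overrides import is_tensor_like

class DuckTensor:
    # Not a Tensor subclass, but it participates in __torch_function__
    # dispatch, so it counts as "Tensor-like".
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented

print(isinstance(DuckTensor(), torch.Tensor))  # False: misses Tensor-likes
print(is_tensor_like(DuckTensor()))            # True
print(is_tensor_like(torch.randn(2)))          # True
```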
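For the is_grads_batched fix, here is a plain usage sketch of torch.autograd.grad with that flag (computing a Jacobian via batched vector-Jacobian products); per the item above, before this change the flag was dropped when dispatch went through handle_torch_function, i.e. when Tensor-like arguments were involved:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x.sin()

# Each row of grad_outputs is treated as a separate vector for a
# vmapped vector-Jacobian product; the result is the Jacobian of y.
v = torch.eye(3)
(jac,) = torch.autograd.grad(y, x, grad_outputs=v, is_grads_batched=True)
print(jac)  # diag(cos(x))
```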
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75484
Approved by: https://github.com/zou3519