pytorch
3f052ba0 - Remove unnecessary dtype checks for complex types & disable complex dispatch for CPU min/max pointwise ops (#50465)

Commit
3 years ago
Remove unnecessary dtype checks for complex types & disable complex dispatch for CPU min/max pointwise ops (#50465) Summary: Fixes https://github.com/pytorch/pytorch/issues/50064 **PROBLEM DESCRIPTION:** 1. Had not removed dtype checks for complex types in the previous PR (https://github.com/pytorch/pytorch/issues/50347) for this issue. These type-checks were added in https://github.com/pytorch/pytorch/issues/36377, but are no longer necessary, as we now rely upon dispatch macros to produce error messages. 2. dtype checks in `clamp_max()` and `clamp_min()` for complex inputs had not been removed either. 3. For min/max pointwise ops in TensorCompareKernel.cpp, complex dispatch had not been removed for min/max functions. ### **FIX DESCRIPTION:** **FIX SUMMARY:** 1. Removed dtype checks added in https://github.com/pytorch/pytorch/issues/36377, and added 3 more in TensorCompare.cpp. 2. Removed dtype checks for complex inputs in `clamp_max()` and `clamp_min()`. 3. Disabled complex dispatch for min/max pointwise ops in TensorCompareKernel.cpp. 4. Error messages in the exceptions raised due to min/max ops not being implemented are now checked for containing the text _not support_ (which can also be present in _not supported_), or _not implemented_, so one of them should be a part of error messages, in order for them to be informative. **REASON FOR NOT CHANGING DISPATCH FOR CUDA AND CLAMP OPS**: As for the CUDA min/max operations, their kernels do not seem to be compiled & dispatched for complex types anyway, so no further changes seem to be required. Basically, the dispatch macros currently being used don't have cases for complex types. For example, 1. the reduce CUDA ops use [AT_DISPATCH_ALL_TYPES_AND2 (https://github.com/pytorch/pytorch/commit/678fe9f0771a5cd98ead214363d70480ba03000d)](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h#L548-L575) in [ReduceMinMaxKernel.cu](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu), and that macro doesn't allow complex types. 2. In [MinMaxElementwiseKernel.cu](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu), the CUDA pointwise ops use [`AT_DISPATCH_FLOATING_TYPES_AND2 (https://github.com/pytorch/pytorch/commit/678fe9f0771a5cd98ead214363d70480ba03000d)`](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h#L240-L263) for non-integral & non-boolean types, and this marco doesn't have a case for complex types either. 3. [clamp CUDA ops](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/UnaryOpsKernel.cu#L170-L211) use `AT_DISPATCH_ALL_TYPES_AND2 (https://github.com/pytorch/pytorch/commit/678fe9f0771a5cd98ead214363d70480ba03000d)`, which doesn't have a case for complex types. Similarly, [CPU clamp min/max ops](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp#L428-L458) use the `AT_DISPATCH_ALL_TYPES_AND `dispatch macro, which doesn't have a case for complex types. **REASON FOR ADDING 3 dtype CHECKS:** There are a few cases in which the methods corresponding to `min_stub()` or `max_stub()` are not called, so dispatch macros don't get invoked, resulting in no exceptions being raised. Hence, `dtype` checks are necessary at 3 places to raise exceptions: 1. https://github.com/pytorch/pytorch/blob/52dcc7299925de055d330781d2fe0dad71182829/aten/src/ATen/native/TensorCompare.cpp#L342 2. https://github.com/pytorch/pytorch/blob/52dcc7299925de055d330781d2fe0dad71182829/aten/src/ATen/native/TensorCompare.cpp#L422 3. https://github.com/pytorch/pytorch/blob/52dcc7299925de055d330781d2fe0dad71182829/aten/src/ATen/native/TensorCompare.cpp#L389 The first dtype check requirement can be verified from the following example Python code based on `test_complex_unsupported()`: ``` import unittest import torch class MyTestCase(unittest.TestCase): def test_1(self): t = torch.tensor((1 + 1j), device='cpu', dtype=torch.complex128) with self.assertRaises(Exception): torch.max(t, dim=0) if __name__ == '__main__': unittest.main() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/50465 Reviewed By: mruberry Differential Revision: D25938106 Pulled By: ngimel fbshipit-source-id: 95e2df02ba8583fa3ce87d4a2fdcd60b912dda46
Parents
Loading