1c017f0c - Migrate max and min (binary) from TH to ATen. (#30851)

Summary: This migrates the binary (two-tensor, elementwise) variants of `max` and `min` from TH to ATen. The TH implementation will be removed once the unary `max` and `min` have been migrated as well.

Benchmark (Debian 10, Release build, gcc 7.4, no turbo):

```python
import timeit

for device in ('cpu', 'cuda'):
    print(f'device: {device}')
    for op in ('max', 'min'):
        for dtype in ('torch.double', 'torch.float', 'torch.int16', 'torch.int32', 'torch.int64'):
            for n, t in [(10_000, 200_000), (100_000, 20_000)]:
                print(f'torch.{op}(a, b), numel() == {n} for {t} times, dtype={dtype}')
                print(timeit.timeit(
                    f'torch.{op}(a, b)' + (';torch.cuda.synchronize()' if device == 'cuda' else ''),
                    setup=f'import torch; '
                          f'a = torch.arange({n}, dtype={dtype}, device="{device}"); '
                          f'b = torch.ones({n}, dtype={dtype}, device="{device}") * ({n} // 2)',
                    number=t))
    print()
```
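For context, the ops being benchmarked are the binary, elementwise forms of `max`/`min`, which compare two tensors position by position (in contrast to the unary reductions mentioned above). A quick sanity check of that behavior:

```python
import torch

a = torch.tensor([1.0, 5.0, 3.0])
b = torch.tensor([4.0, 2.0, 3.0])

# Each output element is the max/min of the corresponding pair of inputs.
print(torch.max(a, b))  # tensor([4., 5., 3.])
print(torch.min(a, b))  # tensor([1., 2., 3.])
```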
Before:

```
device: cpu
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.double
2.241763713000182
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.double
1.7138833169992722
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.float
2.2183356810000987
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.float
1.7031846980007685
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.int16
1.7704679510006827
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.int16
1.289198366999699
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.int32
1.7937613740014058
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.int32
1.2930124340000475
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.int64
1.8032857640009752
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.int64
1.2908709189996443
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.double
1.8829010000008566
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.double
1.2994690759987861
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.float
1.8037853410005482
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.float
1.2929310759991495
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.int16
1.8075240359994496
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.int16
1.2932477679987642
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.int32
1.7868400779989315
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.int32
1.2885970789993735
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.int64
1.8389664830010588
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.int64
1.29402057399966
device: cuda
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.double
4.787109836999662
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.double
1.842438002999188
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.float
3.429616614999759
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.float
1.835390076999829
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.int16
2.940423873000327
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.int16
1.4108991760003846
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.int32
2.9318018840003788
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.int32
1.4168134739993548
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.int64
2.9610764919998473
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.int64
1.4189234130008117
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.double
2.960172712999338
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.double
1.4162539499993727
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.float
2.8985912560001452
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.float
1.4113489299998037
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.int16
2.9160250799995993
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.int16
1.4128787690005993
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.int32
2.8806865219994506
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.int32
1.4086357010000938
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.int64
2.9362181240012433
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.int64
1.4151225870009512
```

After:

```
device: cpu
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.double
2.2685823729998447
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.double
1.72004808300062
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.float
2.212242640000113
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.float
1.7089235590001408
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.int16
1.7767087259999244
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.int16
1.2916517639996528
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.int32
1.8265984959998605
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.int32
1.3002885240002797
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.int64
1.8084679720004715
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.int64
1.3012119999993956
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.double
1.8800218449996464
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.double
1.3060645710002063
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.float
2.4905043950002437
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.float
1.9126290209997023
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.int16
1.7972335520007618
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.int16
1.2918074379995232
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.int32
1.8047651860006226
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.int32
1.2992197730000044
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.int64
1.8526509560006161
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.int64
1.3030709570002728
device: cuda
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.double
4.700986622000528
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.double
1.8415469050005413
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.float
3.3051693249999516
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.float
1.8321999460004008
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.int16
2.8086475109994353
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.int16
1.405110773999695
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.int32
2.913458047999484
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.int32
1.4236377289998927
torch.max(a, b), numel() == 10000 for 200000 times, dtype=torch.int64
2.9386842409994642
torch.max(a, b), numel() == 100000 for 20000 times, dtype=torch.int64
1.4230227469997772
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.double
3.0341797270002644
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.double
1.4289592409995748
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.float
3.6091147850002017
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.float
2.036691903999781
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.int16
2.8256167649997224
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.int16
1.4078955400000268
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.int32
2.8631781489993955
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.int32
1.4210130069996012
torch.min(a, b), numel() == 10000 for 200000 times, dtype=torch.int64
3.0112479260005784
torch.min(a, b), numel() == 100000 for 20000 times, dtype=torch.int64
1.4297719679998409
```
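The two dumps are easiest to read pairwise. A throwaway helper along these lines (not part of the PR; `before.txt` and `after.txt` are hypothetical files holding the two pasted blocks) prints the after/before ratio for each case:

```python
# Hypothetical comparison helper, assuming before.txt and after.txt each
# contain one of the benchmark dumps above in its printed format.
def parse_bench(text):
    timings = {}
    device, label = None, None
    for line in text.splitlines():
        line = line.strip()
        if line.startswith('device:'):
            device = line.split()[-1]
        elif line.startswith('torch.'):
            label = f'[{device}] {line}'   # labels repeat per device, so qualify them
        elif label is not None:
            timings[label] = float(line)   # the timing line follows its label
            label = None
    return timings

before = parse_bench(open('before.txt').read())
after = parse_bench(open('after.txt').read())
for label, t_before in before.items():
    print(f'{after[label] / t_before:5.2f}x  {label}')
```

Ratios below 1.0 mean the ATen version came out faster for that case.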
Partly solves https://github.com/pytorch/pytorch/issues/24594 and #24595.
Closes https://github.com/pytorch/pytorch/issues/25016.
Continues https://github.com/pytorch/pytorch/issues/27185.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/30851
Differential Revision: D19515694
Pulled By: ezyang
fbshipit-source-id: 1764897f912d6ae24b0c361f19a1aacf96e0826e