b2f6e2bd - Migrate neg's CUDA implementation to ATen. (#23617)

Migrate neg's CUDA implementation to ATen. (#23617)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23617

This doesn't appear to cause any performance regression; the difference in the benchmarks below is negligible.

Benchmark script:

```python
import timeit

for n, t in [(10, 100000), (1000, 10000)]:
    print('a.neg() (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32',
                      'torch.int64', 'torch.float', 'torch.double') \
                + (('torch.half',) if device == 'cuda' else ()):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(
                f'a.neg()\nif "{device}" == "cuda": torch.cuda.synchronize()',
                setup=f'import torch; a = torch.ones({n}, device="{device}", dtype={dtype})',
                number=t))
```

Before:

```
a.neg() (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times    2.5537249100016197
device: cpu, dtype: torch.uint8, 100000 times   2.512518662999355
device: cpu, dtype: torch.int16, 100000 times   2.548207502000878
device: cpu, dtype: torch.int32, 100000 times   2.5974994509997487
device: cpu, dtype: torch.int64, 100000 times   2.6533011499996064
device: cpu, dtype: torch.float, 100000 times   2.6474813019995054
device: cpu, dtype: torch.double, 100000 times  2.6949866009999823
device: cuda, dtype: torch.int8, 100000 times   5.820120684998983
device: cuda, dtype: torch.uint8, 100000 times  5.732108927997615
device: cuda, dtype: torch.int16, 100000 times  5.791249125999457
device: cuda, dtype: torch.int32, 100000 times  5.816761754998879
device: cuda, dtype: torch.int64, 100000 times  5.935873205999087
device: cuda, dtype: torch.float, 100000 times  6.276509613999224
device: cuda, dtype: torch.double, 100000 times 6.122782447000645
device: cuda, dtype: torch.half, 100000 times   6.161522764999972
a.neg() (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times     0.3766637519984215
device: cpu, dtype: torch.uint8, 10000 times    0.37288786600038293
device: cpu, dtype: torch.int16, 10000 times    0.3485262310023245
device: cpu, dtype: torch.int32, 10000 times    0.41810554200128536
device: cpu, dtype: torch.int64, 10000 times    0.5609612200023548
device: cpu, dtype: torch.float, 10000 times    0.39054008099992643
device: cpu, dtype: torch.double, 10000 times   0.4946578170020075
device: cuda, dtype: torch.int8, 10000 times    0.5843639539998549
device: cuda, dtype: torch.uint8, 10000 times   0.5780841570012853
device: cuda, dtype: torch.int16, 10000 times   0.5819949180004187
device: cuda, dtype: torch.int32, 10000 times   0.5827294059999986
device: cuda, dtype: torch.int64, 10000 times   0.5861426519986708
device: cuda, dtype: torch.float, 10000 times   0.5929420489992481
device: cuda, dtype: torch.double, 10000 times  0.594638443999429
device: cuda, dtype: torch.half, 10000 times    0.5903799709994928
```

After:

```
a.neg() (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times    2.4983287129980454
device: cpu, dtype: torch.uint8, 100000 times   2.479393904999597
device: cpu, dtype: torch.int16, 100000 times   2.5382055320005747
device: cpu, dtype: torch.int32, 100000 times   2.5587980189993687
device: cpu, dtype: torch.int64, 100000 times   2.637738788002025
device: cpu, dtype: torch.float, 100000 times   2.602799075997609
device: cpu, dtype: torch.double, 100000 times  2.6648931070012623
device: cuda, dtype: torch.int8, 100000 times   5.793338211999071
device: cuda, dtype: torch.uint8, 100000 times  5.782462584000314
device: cuda, dtype: torch.int16, 100000 times  5.824340334998851
device: cuda, dtype: torch.int32, 100000 times  5.851659068001027
device: cuda, dtype: torch.int64, 100000 times  5.8898071570001775
device: cuda, dtype: torch.float, 100000 times  5.913144636000652
device: cuda, dtype: torch.double, 100000 times 5.963339805999567
device: cuda, dtype: torch.half, 100000 times   5.87889370099947
a.neg() (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times     0.37244726499920944
device: cpu, dtype: torch.uint8, 10000 times    0.36641623199830065
device: cpu, dtype: torch.int16, 10000 times    0.3449854829996184
device: cpu, dtype: torch.int32, 10000 times    0.4127863069988962
device: cpu, dtype: torch.int64, 10000 times    0.5551902160004829
device: cpu, dtype: torch.float, 10000 times    0.38593814199703047
device: cpu, dtype: torch.double, 10000 times   0.48877579500185675
device: cuda, dtype: torch.int8, 10000 times    0.5862828740027908
device: cuda, dtype: torch.uint8, 10000 times   0.5836667540024791
device: cuda, dtype: torch.int16, 10000 times   0.5918155769977602
device: cuda, dtype: torch.int32, 10000 times   0.5961457039993547
device: cuda, dtype: torch.int64, 10000 times   0.5963898690024507
device: cuda, dtype: torch.float, 10000 times   0.5985483309996198
device: cuda, dtype: torch.double, 10000 times  0.6027148480025062
device: cuda, dtype: torch.half, 10000 times    0.5961164370019105
```

Test Plan: Imported from OSS

Differential Revision: D16617574

Pulled By: ezyang

fbshipit-source-id: c90aa410f6385ce94fe6b84ebeceffa5effd0267
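As a quick sanity check on the "negligible" claim, the relative change for the CUDA timings at numel == 10 can be computed from the tables above. This snippet is not part of the original commit; it is just arithmetic on the numbers it reports:

```python
# Before/After totals (seconds) for CUDA at numel == 10, 100000 runs,
# copied from the benchmark output in the commit message.
before = {'int8': 5.820120684998983, 'uint8': 5.732108927997615,
          'int16': 5.791249125999457, 'int32': 5.816761754998879,
          'int64': 5.935873205999087, 'float': 6.276509613999224,
          'double': 6.122782447000645, 'half': 6.161522764999972}
after = {'int8': 5.793338211999071, 'uint8': 5.782462584000314,
         'int16': 5.824340334998851, 'int32': 5.851659068001027,
         'int64': 5.8898071570001775, 'float': 5.913144636000652,
         'double': 5.963339805999567, 'half': 5.87889370099947}

for dtype in before:
    # Negative percentages mean the migrated kernel is faster.
    change = (after[dtype] - before[dtype]) / before[dtype] * 100
    print(f'{dtype}: {change:+.1f}%')
```

Every CUDA dtype is flat or faster after the migration; the largest change (torch.float, roughly -5.8%) is an improvement, which is consistent with the "no performance regression" claim.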
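The benchmark relies on `timeit.timeit` semantics: the `setup` string runs once to build the namespace, the statement runs `number` times in that namespace, and the return value is total elapsed seconds (so dividing by `number` gives per-call time). A torch-free sketch of the same pattern, with plain list negation standing in for `a.neg()` (names and sizes here are illustrative only):

```python
import timeit

for n, t in [(10, 1000), (1000, 100)]:
    # setup runs once; the statement runs t times against the same `a`.
    total = timeit.timeit(
        'b = [-x for x in a]',
        setup=f'a = list(range({n}))',
        number=t,
    )
    per_call_us = total / t * 1e6
    print(f'n={n}: {t} runs took {total:.6f}s ({per_call_us:.2f} us/call)')
```

Note that for the CUDA case the original script appends `torch.cuda.synchronize()` to the timed statement, since CUDA kernels launch asynchronously and the timing would otherwise measure only launch overhead.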