Migrate `diag` and `trace` from TH to ATen (CUDA) (#36876)
Summary:
Closes https://github.com/pytorch/pytorch/issues/24549 and #24649
## diag
Benchmark run with the same build settings on the same system.
- gcc: 7.5.0 (Ubuntu 7.5.0-3ubuntu1~18.04)
- CUDA: 10.1
- GPU: GTX 1050 Ti
```python
import math
import time

import torch

for n, t in [(100, 20000),
             (400, 20000)]:
    for dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.float, torch.double):
        # Input Setup
        a = torch.arange(n, dtype=dtype, device="cuda")
        b = a.reshape((int(math.sqrt(n)), int(math.sqrt(n))))
        print(f'torch.diag a.numel() == {n} for {t} times {dtype}')
        for inp, inp_name in [(a, '1-D'), (b, '2-D')]:
            # Drain any pending GPU work before starting the timer
            torch.cuda.synchronize()
            start = time.time()
            # Iterate
            for _ in range(t):
                torch.diag(inp)
            # Final Synchronize Before Teardown
            torch.cuda.synchronize()
            print(inp_name + " Took:", time.time() - start)
```
| Dtype | Before (TH) | After (ATen) |
|------|--------|-------|
| int8-Elems:100 | 1-D Took: 0.20730137825012207<br />2-D Took: 0.12553787231445312<br /> | 1-D Took: 0.33618664741516113<br />2-D Took: 0.1264970302581787<br /> |
| int16-Elems:100 | 1-D Took: 0.2127547264099121<br />2-D Took: 0.12582707405090332<br /> | 1-D Took: 0.2146449089050293<br />2-D Took: 0.12558245658874512<br /> |
| int32-Elems:100 | 1-D Took: 0.2106609344482422<br />2-D Took: 0.12958312034606934<br /> | 1-D Took: 0.2121574878692627<br />2-D Took: 0.1264948844909668<br /> |
| int64-Elems:100 | 1-D Took: 0.20768976211547852<br />2-D Took: 0.1256253719329834<br /> | 1-D Took: 0.2077159881591797<br />2-D Took: 0.12476921081542969<br /> |
| float32-Elems:100 | 1-D Took: 0.2137584686279297<br />2-D Took: 0.12708187103271484<br /> | 1-D Took: 0.21565628051757812<br />2-D Took: 0.1275336742401123<br /> |
| float64-Elems:100 | 1-D Took: 0.21710658073425293<br />2-D Took: 0.12845087051391602<br /> | 1-D Took: 0.219193696975708<br />2-D Took: 0.1264345645904541<br /> |
| int8-Elems:400 | 1-D Took: 0.20585918426513672<br />2-D Took: 0.1257162094116211<br /> | 1-D Took: 0.20970797538757324<br />2-D Took: 0.12455391883850098<br /> |
| int16-Elems:400 | 1-D Took: 0.20943427085876465<br />2-D Took: 0.12425971031188965<br /> | 1-D Took: 0.21483230590820312<br />2-D Took: 0.12662172317504883<br /> |
| int32-Elems:400 | 1-D Took: 0.21058869361877441<br />2-D Took: 0.1312875747680664<br /> | 1-D Took: 0.2092602252960205<br />2-D Took: 0.12785696983337402<br /> |
| int64-Elems:400 | 1-D Took: 0.287722110748291<br />2-D Took: 0.12862586975097656<br /> | 1-D Took: 0.28710484504699707<br />2-D Took: 0.12852025032043457<br /> |
| float32-Elems:400 | 1-D Took: 0.21535277366638184<br />2-D Took: 0.1278238296508789<br /> | 1-D Took: 0.2140669822692871<br />2-D Took: 0.1268482208251953<br /> |
| float64-Elems:400 | 1-D Took: 0.28638601303100586<br />2-D Took: 0.13219022750854492<br /> | 1-D Took: 0.28608059883117676<br />2-D Took: 0.13063836097717285<br /> |
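For context on what the benchmark exercises: `torch.diag` has two modes, which is why the benchmark times both a 1-D and a 2-D input. Given a 1-D tensor it builds a square matrix with the input on the main diagonal; given a 2-D tensor it extracts the main diagonal. A minimal plain-Python sketch of those semantics (no torch or CUDA dependency, illustrative only):

```python
def diag_from_1d(v):
    # 1-D input: build an n x n matrix with v on the main diagonal
    n = len(v)
    return [[v[i] if i == j else 0 for j in range(n)] for i in range(n)]

def diag_from_2d(m):
    # 2-D input: extract the main diagonal (length = min(rows, cols))
    rows = len(m)
    cols = len(m[0]) if m else 0
    return [m[i][i] for i in range(min(rows, cols))]

print(diag_from_1d([1, 2]))           # [[1, 0], [0, 2]]
print(diag_from_2d([[1, 2], [3, 4]])) # [1, 4]
```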
## trace
Benchmark run with the same build settings on the same system.
- gcc: 7.5.0 (Ubuntu 7.5.0-3ubuntu1~18.04)
- CUDA: 10.1
- GPU: GTX 1050 Ti
```python
import math
import time

import torch

for n, t in [(10000, 20000),
             (40000, 20000)]:
    for dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.float, torch.double):
        # Input Setup
        a = torch.arange(n, dtype=dtype, device="cuda")
        a = a.reshape((int(math.sqrt(n)), int(math.sqrt(n))))
        print(f'torch.trace a.numel() == {n} for {t} times {dtype}')
        # Drain any pending GPU work before starting the timer
        torch.cuda.synchronize()
        start = time.time()
        # Iterate
        for _ in range(t):
            torch.trace(a)
        # Final Synchronize Before Teardown
        torch.cuda.synchronize()
        print("Took:", time.time() - start)
```
| Dtype | Before (TH) | After (ATen) |
|------|--------|-------|
| int8-Elems:10000 | Took: 0.4376576900482178<br /> | Took: 0.42725276947021484<br /> |
| int16-Elems:10000 | Took: 0.4334981441497803<br /> | Took: 0.4376239776611328<br /> |
| int32-Elems:10000 | Took: 0.43313121795654297<br /> | Took: 0.43097853660583496<br /> |
| int64-Elems:10000 | Took: 0.28386616706848145<br /> | Took: 0.2827033996582031<br /> |
| float32-Elems:10000 | Took: 0.2905247211456299<br /> | Took: 0.2914285659790039<br /> |
| float64-Elems:10000 | Took: 0.29450368881225586<br /> | Took: 0.2907843589782715<br /> |
| int8-Elems:40000 | Took: 0.4255516529083252<br /> | Took: 0.41020774841308594<br /> |
| int16-Elems:40000 | Took: 0.4287736415863037<br /> | Took: 0.42923426628112793<br /> |
| int32-Elems:40000 | Took: 0.43021249771118164<br /> | Took: 0.42778849601745605<br /> |
| int64-Elems:40000 | Took: 0.2852292060852051<br /> | Took: 0.28212475776672363<br /> |
| float32-Elems:40000 | Took: 0.29549574851989746<br /> | Took: 0.29524707794189453<br /> |
| float64-Elems:40000 | Took: 0.29451632499694824<br /> | Took: 0.2894322872161865<br /> |
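`torch.trace` itself is a simple reduction: the sum of the main-diagonal elements of a 2-D matrix. A plain-Python sketch of the semantics being migrated (illustrative only, no torch dependency):

```python
def trace(m):
    # Sum of the main-diagonal elements of a 2-D matrix
    rows = len(m)
    cols = len(m[0]) if m else 0
    return sum(m[i][i] for i in range(min(rows, cols)))

print(trace([[1, 2], [3, 4]]))  # 5
```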
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36876
Differential Revision: D21940588
Pulled By: ngimel
fbshipit-source-id: f0ec59b1d16a51690390a002b7c46eec93f0b092