unify compare kernels (#69111)
Summary:
This unifies 6 compare ops (NE, EQ, LT, LE, GE, GT) into 2 kernels, reducing context size. Performance is ~5% worse for low-width broadcasted cases, and on par for non-broadcasted cases.
With this PR, benchmarks for contiguous, 1M-MM, 1M-M1, op with scalar (size in MB and bandwidth in GB/s):
```
5.0, 795.9
10.0, 650.5
15.0, 706.2
20.0, 731.6
25.0, 744.9
30.0, 758.1
35.0, 762.6
40.0, 768.8
45.0, 775.7
50.0, 780.7
55.0, 781.7
60.0, 783.0
65.0, 784.8
70.0, 790.7
75.0, 789.2
80.0, 794.4
85.0, 794.2
90.0, 797.4
95.0, 796.3
100.0, 798.0
3.0, 363.7 1.0, 122.2 3.0, 385.5
6.0, 420.4 2.0, 142.9 6.0, 755.5
9.0, 438.3 3.0, 151.6 9.0, 684.5
12.0, 449.5 4.0, 156.4 12.0, 702.9
15.0, 463.7 5.0, 159.6 15.0, 716.8
18.0, 472.7 6.0, 161.4 18.0, 737.0
21.0, 477.6 7.0, 162.4 21.0, 745.6
24.0, 480.9 8.0, 164.1 24.0, 755.4
27.0, 483.7 9.0, 163.7 27.0, 760.7
30.0, 487.3 10.0, 165.9 30.0, 770.4
33.0, 491.4 11.0, 166.3 33.0, 774.3
36.0, 492.9 12.0, 166.2 36.0, 779.0
39.0, 494.7 13.0, 166.7 39.0, 782.5
42.0, 491.3 14.0, 166.7 42.0, 789.0
45.0, 495.1 15.0, 167.5 45.0, 790.0
48.0, 499.7 16.0, 167.7 48.0, 791.8
51.0, 496.2 17.0, 166.9 51.0, 794.0
54.0, 497.6 18.0, 167.7 54.0, 797.4
57.0, 497.1 19.0, 167.5 57.0, 798.6
60.0, 498.8 20.0, 168.8 60.0, 802.1
```
Master
```
5.0, 743.4
10.0, 665.7
15.0, 702.3
20.0, 727.5
25.0, 740.7
30.0, 757.5
35.0, 760.3
40.0, 768.5
45.0, 775.7
50.0, 776.8
55.0, 781.1
60.0, 786.5
65.0, 786.8
70.0, 790.1
75.0, 789.7
80.0, 789.1
85.0, 793.2
90.0, 793.8
95.0, 795.9
100.0, 796.0
3.0, 383.1 1.0, 129.0 3.0, 337.0
6.0, 445.0 2.0, 149.6 6.0, 670.6
9.0, 445.3 3.0, 159.6 9.0, 678.6
12.0, 474.9 4.0, 164.1 12.0, 705.5
15.0, 480.8 5.0, 167.2 15.0, 718.3
18.0, 490.3 6.0, 169.1 18.0, 733.3
21.0, 493.9 7.0, 168.5 21.0, 742.5
24.0, 503.8 8.0, 171.9 24.0, 756.4
27.0, 506.7 9.0, 171.3 27.0, 759.8
30.0, 508.7 10.0, 172.4 30.0, 767.1
33.0, 515.7 11.0, 174.2 33.0, 773.7
36.0, 516.7 12.0, 170.4 36.0, 781.7
39.0, 519.1 13.0, 174.4 39.0, 782.1
42.0, 515.7 14.0, 174.1 42.0, 787.0
45.0, 519.2 15.0, 172.7 45.0, 788.1
48.0, 522.2 16.0, 175.4 48.0, 791.7
51.0, 519.6 17.0, 175.1 51.0, 795.7
54.0, 518.5 18.0, 174.8 54.0, 795.8
57.0, 519.1 19.0, 174.4 57.0, 796.6
60.0, 521.5 20.0, 175.6 60.0, 800.1
```
<details>
<summary>Benchmarking script </summary>
```
import torch
from matplotlib import pyplot as plt
from torch.utils.benchmark import Timer, Compare
import math
import click
print(torch.cuda.get_device_capability()) # check that we are on Volta (compute capability 7,0)
#torch.cuda.set_device(1)
# don't benchmark on anything too small, you'll see only overhead
@click.command()
@click.option('--op_str', default="torch.gt")
@click.option('--dtype_str', default="float", type=click.Choice(['float', 'half']))
def bench(op_str, dtype_str):
    """Benchmark the elementwise op named by `op_str` on CUDA tensors.

    Phase 1 times the op on contiguous 1-D tensors of growing size
    (1 MB steps, 20 sizes); phase 2 times three broadcast patterns
    ((1,M) vs (M,M), (1,M) vs (M,1), and (M,M) vs a Python scalar).
    Each phase prints "size-in-MB, bandwidth-in-GB/s" per measurement.

    Args:
        op_str: fully qualified op to eval inside the timed stmt,
            e.g. "torch.gt".
        dtype_str: input dtype, one of 'float' or 'half' (enforced by
            click.Choice, so no else-branch is needed below).
    """
    if dtype_str == "float":
        dtype = torch.float
    elif dtype_str == "half":
        dtype = torch.half
    MB = 1024 * 1024
    # --- phase 1: contiguous 1-D inputs -------------------------------
    size = MB
    results = []
    sizes = []
    for _ in range(20):
        # return cached blocks to the allocator so the next (larger)
        # allocation does not fragment / OOM
        torch.cuda.memory.empty_cache()
        a = torch.randn(int(size), device="cuda", dtype=dtype)
        b = torch.randn(int(size), device="cuda", dtype=dtype)
        t = Timer(stmt=f"{op_str}(a,b)", label=op_str, sub_label=f"{size/MB} MB",
                  description="contiguous", globals={"a": a, "b": b})
        res = t.blocked_autorange()
        results.append(res)
        sizes.append(size)
        size += MB
        del a  # to save memory for next iterations
        del b
    c = Compare(results)
    #print(c)
    bw = []
    sizes_mb = []  # renamed from `bytes`: don't shadow the builtin
    element_size = torch.tensor([], dtype=dtype).element_size()
    # comparison ops write a bool output: 1 byte per element
    output_element_size = 1
    for res, size in zip(results, sizes):
        # two inputs read + one bool output written
        bytes_io = 2 * size * element_size + output_element_size * size
        sizes_mb.append(bytes_io / MB)
        # we'll report bandwidth in GB/s
        bw.append(bytes_io / res.median * 1e-9)
        print(f"{bytes_io/MB:7.1f}, {bw[-1]:7.1f}")
    # --- phase 2: broadcasted and scalar inputs -----------------------
    sizes = []
    results = [[], [], []]  # one list per pattern: MMM1, M11M, scalar
    size = MB
    for _ in range(20):
        torch.cuda.memory.empty_cache()
        M = math.floor(math.sqrt(size))  # keep M*M elements ~= size
        a = torch.randn(1, M, device="cuda", dtype=dtype)
        b = torch.randn(M, M, device="cuda", dtype=dtype)
        b1 = torch.randn(M, 1, device="cuda", dtype=dtype)
        tb = Timer(stmt=f"{op_str}(a,b)", label = op_str, sub_label=f"{M*M/MB} MB", description="MMM1", globals = {"a":a, "b":b})
        t1 = Timer(stmt=f"{op_str}(a,b1)", label = op_str, sub_label=f"{M*M/MB} MB", description="M11M", globals = {"a":a, "b1":b1})
        ts = Timer(stmt=f"{op_str}(b,1.)", label = op_str, sub_label=f"{M*M/MB} MB", description="scalar", globals = {"a":a, "b":b})
        res = [t.blocked_autorange() for t in (tb, t1, ts)]
        for (rl, r) in zip(results, res):
            rl.append(r)
        sizes.append(M)
        size += MB
        del a  # to save memory for next iterations
        del b
        del b1  # fix: previously leaked into the next iteration
    comps = [Compare(r) for r in results]
    #[print(c) for c in comps]
    bw = [[], [], []]
    for res, res1, ress, size in zip(results[0], results[1], results[2], sizes):
        # here `size` is M (one side of the M x M output)
        # (1,M) + (M,M) inputs read, (M,M) bool output written
        bytes_io = (size + size * size) * element_size + output_element_size * size * size
        # (1,M) + (M,1) inputs read, (M,M) bool output written
        bytes_io1 = (size + size) * element_size + output_element_size * size * size
        # (M,M) input read (scalar is free), (M,M) bool output written
        bytes_ios = (size * size) * element_size + output_element_size * size * size
        bytes_iol = (bytes_io, bytes_io1, bytes_ios)
        for (bw_elem, bytes_elem, res_elem) in zip(bw, bytes_iol, (res, res1, ress)):
            bw_elem.append(bytes_elem / res_elem.median * 1e-9)
        print(f"{bytes_iol[0]/MB:7.1f}, {bw[0][-1]:7.1f}", f"{bytes_iol[1]/MB:7.1f}, {bw[1][-1]:7.1f}",
              f"{bytes_iol[2]/MB:7.1f}, {bw[2][-1]:7.1f}")
if __name__ == '__main__':
    bench()
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69111
Reviewed By: mruberry
Differential Revision: D32851098
Pulled By: ngimel
fbshipit-source-id: cfb83922b2e8eb6a0ad0621ff07c2dada9c8e626