[pytorch] CUDA kernel for torch.cat on contiguous tensors with wide loads (#102815)
This PR adds a CUDA kernel for `CatArrayBatchedCopy` that uses vectorized memory loads to maximize HBM bandwidth. It also simplifies the kernel code by removing the path that handles non-contiguous inputs. The new kernel is used when all of the following conditions are met (see the sketch after the list):
- all input tensors are contiguous
- the input data types are 32-bit or 64-bit
- all inputs are aligned to a 16-byte boundary
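
As a rough illustration (not the actual dispatch logic, which lives in the CUDA source), the conditions can be checked from Python roughly like this; the helper name `hits_fast_cat_path` is hypothetical:

```python
import torch

def hits_fast_cat_path(tensors):
    # Hypothetical helper mirroring the three conditions above; the real
    # check happens inside the CUDA implementation of torch.cat.
    return all(
        t.is_contiguous()
        and t.element_size() in (4, 8)  # 32-bit or 64-bit element types
        and t.data_ptr() % 16 == 0      # input starts on a 16-byte boundary
        for t in tensors
    )

a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")
print(hits_fast_cat_path([a, b]))      # fresh allocations: True
print(hits_fast_cat_path([a[1:], b]))  # 4-byte offset view: False
```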
We tested on a large set of problem sizes; there is a net gain for the 32-bit types and a marginal gain for the 64-bit types. Based on our analysis, 32-bit cats are by far the dominant case in practice.
Results:
<img width="1320" alt="Screenshot 2023-06-02 at 8 10 21 AM" src="https://github.com/pytorch/pytorch/assets/23515689/6f083f7c-2e1a-4513-a994-e0cb072d9b5d">
The SASS code confirms that the input tensors are read with wide loads and that the stores to global memory are unrolled to maximize oversubscription:
<img width="1648" alt="Screenshot 2023-06-02 at 8 16 29 AM" src="https://github.com/pytorch/pytorch/assets/23515689/10325ee6-d3a0-402a-af0d-29cd1a32813b">
Test Code:
```python
import sys
import torch
l_inputs = [
    # (shape, cat dim, number of tensors, timed iterations)
    ((1024,), 0, 2, 100),
    ((4096,), 0, 2, 100),
    ((16384,), 0, 4, 100),
    ((32000,), 0, 8, 100),
    ((128 * 1024,), 0, 2, 100),
    ((256 * 1024,), 0, 3, 100),
    ((1 * 1024 * 1024,), 0, 2, 100),
    ((4 * 1024 * 1024,), 0, 2, 100),
    ((16 * 1024 * 1024,), 0, 2, 100),
    ((32 * 1024 * 1024,), 0, 2, 100),
    ((128 * 1024 * 1024,), 0, 2, 50),
    ((64, 256), 0, 4, 100),
    ((400, 400), 0, 2, 100),
    ((640, 1080), 0, 2, 100),
    ((128, 4096), 1, 2, 100),
    ((512, 512), 1, 2, 100),
    ((699, 713), 1, 2, 100),
    ((1024, 1024), 1, 2, 100),
    ((2000, 1000), 1, 2, 100),
    ((4096, 4096), 1, 2, 100),
    ((16384, 16384), 1, 2, 50),
    ((384, 256, 16), 1, 2, 100),
    ((400, 200, 13), 1, 2, 100),
    ((128, 64, 256), 0, 2, 100),
    ((512, 256, 256), 1, 2, 100),
    ((512, 1024, 1024), 2, 2, 10),
    ((1024, 512, 1024), 2, 2, 10),
    ((1024, 1024, 512), 2, 2, 10),
    ((128, 64, 64, 32), 0, 2, 50),
    ((128, 64, 128, 16), 1, 2, 50),
    ((100, 45, 45, 32), 3, 2, 50),
    ((128, 32, 256, 32), 3, 2, 50),
]

prof_inputs = [
    ((1234567,), 0, 2, 5),
    ((16 * 1024 * 1024,), 0, 3, 5),
    ((1013, 1013), 0, 2, 5),
    ((1024, 1024), 1, 2, 5),
    ((69, 74, 128), 0, 2, 5),
    ((128, 128, 128), 2, 2, 5),
]

def generate_tensors(dim_tuple, cat_type, num_tensors):
    # Note: the list repeats the same tensor object num_tensors times,
    # which is fine here since torch.cat only reads the inputs.
    if cat_type in [torch.int8, torch.int32, torch.int64]:
        l_tensors = [
            torch.randint(
                high=torch.iinfo(cat_type).max,
                size=dim_tuple,
                dtype=cat_type,
                device="cuda",
            )
        ] * num_tensors
        return l_tensors
    else:
        l_tensors = [
            torch.randn(dim_tuple, dtype=cat_type, device="cuda")
        ] * num_tensors
        return l_tensors

def test_simple_cat(
    dim_tuple, cat_dim: int, num_tensors: int, iterations: int, cat_type
):
    torch.cuda.synchronize()
    # Allocate a tensor well past the 40 MB L2 cache of an A100 GPU;
    # zeroing it between iterations flushes the cache
    l2_cache_flusher = torch.empty(
        int(80 * (1024**2)), dtype=torch.float, device="cuda"
    )

    # All the tensors in the list get read once and written once
    total_MB = 2 * num_tensors
    for dim in dim_tuple:
        total_MB *= dim
    total_MB /= 1024 * 1024
    # Scale by the number of bytes per element
    if cat_type in [torch.int8, torch.int32, torch.int64]:
        total_MB *= torch.iinfo(cat_type).bits / 8
    else:
        total_MB *= torch.finfo(cat_type).bits / 8

    l_tensors = generate_tensors(dim_tuple, cat_type, num_tensors)
    c = torch.cat(l_tensors, dim=cat_dim)
    torch.cuda.synchronize()

    # Check correctness against the CPU implementation
    l_tensors_cpu = []
    for t in l_tensors:
        l_tensors_cpu.append(t.detach().to("cpu"))
    c_cpu = torch.cat(l_tensors_cpu, dim=cat_dim)
    c_cpu_dev = c.detach().to("cpu")
    if not torch.equal(c_cpu, c_cpu_dev):
        mismatches = torch.count_nonzero(c_cpu != c_cpu_dev)
        print("Error: num mismatches for {0} = {1}".format(dim_tuple, mismatches))
        return 0.0

    # Measure a few iterations; use distinct events per iteration
    # (multiplying a one-element list would alias a single event)
    l_ev_start = [torch.cuda.Event(enable_timing=True) for _ in range(iterations)]
    l_ev_stop = [torch.cuda.Event(enable_timing=True) for _ in range(iterations)]
    l_cat_times = []
    torch.cuda.synchronize()
    for i in range(iterations):
        l2_cache_flusher.zero_()
        torch.cuda._sleep(1_000_000)
        l_ev_start[i].record()
        c = torch.cat(l_tensors, dim=cat_dim)
        l_ev_stop[i].record()
    torch.cuda.synchronize()
    for i in range(iterations):
        t_cat = l_ev_start[i].elapsed_time(l_ev_stop[i]) / 1000
        l_cat_times.append(t_cat)
    min_cat_time = min(l_cat_times)

    # Return the estimated bandwidth in GB/s
    estimated_bw_GBps = total_MB / min_cat_time / 1024
    return estimated_bw_GBps

def main(argv):
    if len(argv) > 0:
        if "profile" in str(argv[0]):
            for l_input in prof_inputs:
                gbps = test_simple_cat(
                    l_input[0], l_input[1], l_input[2], l_input[3], torch.float
                )
                print(
                    "Bandwidth (GB/s) for {0} fp32 | {1:.2f}".format(
                        (l_input[0], l_input[1]), gbps
                    )
                )
            return

    for l_input in l_inputs:
        gbps_int8 = test_simple_cat(
            l_input[0], l_input[1], l_input[2], l_input[3], torch.int8
        )
        gbps_fp16 = test_simple_cat(
            l_input[0], l_input[1], l_input[2], l_input[3], torch.float16
        )
        gbps_fp32 = test_simple_cat(
            l_input[0], l_input[1], l_input[2], l_input[3], torch.float32
        )
        gbps_int32 = test_simple_cat(
            l_input[0], l_input[1], l_input[2], l_input[3], torch.int32
        )
        gbps_fp64 = test_simple_cat(
            l_input[0], l_input[1], l_input[2], l_input[3], torch.float64
        )
        gbps_long = test_simple_cat(
            l_input[0], l_input[1], l_input[2], l_input[3], torch.long
        )
        print(
            "Bandwidth (GB/s) for {0} int8;fp16;fp32;int32;fp64;long"
            "|{1:.2f}|{2:.2f}|{3:.2f}|{4:.2f}|{5:.2f}|{6:.2f}".format(
                (l_input[0], l_input[1]),
                gbps_int8,
                gbps_fp16,
                gbps_fp32,
                gbps_int32,
                gbps_fp64,
                gbps_long,
            )
        )


if __name__ == "__main__":
    main(sys.argv[1:])
```
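Running the script with no arguments sweeps all six dtypes over `l_inputs`; passing `profile` as the first argument runs only the smaller `prof_inputs` set in fp32, which is convenient when capturing traces under a profiler.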
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102815
Approved by: https://github.com/ngimel, https://github.com/malfet