pytorch
c25fdc20 - [cuBLAS][cuBLASLt] Allow user-specified cuBLASLt workspace size via `CUBLASLT_WORKSPACE_SIZE` (#101145)

Provide an option to configure the workspace size used by cuBLASLt, rather than fixing it as a compile-time constant of 1MiB, due to observed performance differences on H100 and recommendations from cuBLAS, e.g., https://docs.nvidia.com/cuda/archive/11.8.0/cuda-toolkit-release-notes/index.html#title-cublas-library.

Some quick profiling shows that in some cases up to 32MiB of workspace is needed on H100:

```
import torch
import time

m = 1024
n = 2048
warmup = 20
iters = 200
dtype = torch.bfloat16

for k in (1024, 2048, 4096, 8192, 9376, 16384, 32768):
    a = torch.randn(m, k, device='cuda', dtype=dtype)
    b = torch.randn(n, k, device='cuda', dtype=dtype).transpose(1, 0)
    i = torch.randn(n, device='cuda', dtype=dtype)
    for _ in range(warmup):
        torch.addmm(i, a, b)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    for _ in range(iters):
        torch.addmm(i, a, b)
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    print(f"m:{m}, n:{n}, k:{k} TFLOP/s: {(2*m*n*k)*iters/(t2 - t1)/1e12}")
```

1MiB:
```
m:1024, n:2048, k:1024 TFLOP/s: 62.40964655242158
m:1024, n:2048, k:2048 TFLOP/s: 79.33321703070685
m:1024, n:2048, k:4096 TFLOP/s: 96.69701590181765
m:1024, n:2048, k:8192 TFLOP/s: 83.2892371366678
m:1024, n:2048, k:9376 TFLOP/s: 83.91872373271516
m:1024, n:2048, k:16384 TFLOP/s: 86.57820235279185
m:1024, n:2048, k:32768 TFLOP/s: 88.37227761178467
```

32MiB:
```
m:1024, n:2048, k:1024 TFLOP/s: 73.50633352382425
m:1024, n:2048, k:2048 TFLOP/s: 104.32016319633199
m:1024, n:2048, k:4096 TFLOP/s: 131.37290416527784
m:1024, n:2048, k:8192 TFLOP/s: 152.08780769805506
m:1024, n:2048, k:9376 TFLOP/s: 154.93898780286096
m:1024, n:2048, k:16384 TFLOP/s: 165.13973167154688
m:1024, n:2048, k:32768 TFLOP/s: 160.62065020500813
```

CC @ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101145
Approved by: https://github.com/ngimel
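As a usage sketch, the new environment variable can be set before launching a PyTorch process. The commit message does not state the unit; PyTorch's handle-pool code interprets the value in KiB (so 32768 corresponds to the 32MiB configuration benchmarked above), but verify this against your PyTorch version:

```shell
# Hypothetical usage sketch: request a 32 MiB cuBLASLt workspace for
# subsequently launched PyTorch processes. The value is assumed to be in KiB
# (32768 KiB = 32 MiB); confirm the unit for your PyTorch version.
export CUBLASLT_WORKSPACE_SIZE=32768
echo "CUBLASLT_WORKSPACE_SIZE=${CUBLASLT_WORKSPACE_SIZE}"
```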
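For reference, the TFLOP/s figure printed by the benchmark follows from the standard GEMM FLOP count: an (m x k) by (k x n) matmul performs m*n*k multiply-adds, i.e. 2*m*n*k floating-point operations. A minimal sketch of the arithmetic, with a made-up elapsed time rather than a real measurement:

```python
# Sketch of the throughput formula used in the benchmark above.
# An (m x k) @ (k x n) matmul does m*n*k fused multiply-adds = 2*m*n*k FLOPs.
m, n, k = 1024, 2048, 4096
iters = 200
elapsed = 0.35  # placeholder seconds for `iters` matmuls, not a measurement

tflops = (2 * m * n * k) * iters / elapsed / 1e12
print(f"m:{m}, n:{n}, k:{k} TFLOP/s: {tflops:.2f}")
```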
Author: eqy