1af40d51 - [cublas][cublasLt] Fall back to unfused `addmm` for 2-byte-aligned inputs (#92201)

Fix for an issue surfaced on the discuss forum: https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214

Note that PyTorch builds before #71200 are not affected, as they had no `cublasLt` dispatch path. Additionally, the provided repro has the quirk of using a 3D input, which means it does not dispatch to `cublasLt`-backed `addmm` until builds that include #72728. Changing the input to 2D by trivially removing the size-`1` dimension surfaces the failure on builds after #71200.

Interestingly, the case where _all_ inputs are 2-byte aligned is supported (it runs without crashing), but the case where some inputs are more than 2-byte aligned and others are exactly 2-byte aligned is not. This suggests that the `cublasLt` heuristics are incorrect, since the heuristic function has visibility of the raw pointer values via the descriptors when it is called. We will follow up with the `cublasLt` team, but this fix is needed to prevent unnecessary crashes for now.

CC @ptrblck @ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92201
Approved by: https://github.com/ngimel
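The following is a minimal repro sketch, not code from the PR: the shapes, variable names, and the `byte_alignment` helper are illustrative. It constructs a 2-byte-aligned fp16 input by slicing one element off a well-aligned base allocation, then feeds it to `torch.nn.functional.linear`, which routes through `addmm`:

```python
import torch
import torch.nn.functional as F

def byte_alignment(t: torch.Tensor) -> int:
    # Largest power-of-two alignment (in bytes) of the tensor's data pointer.
    ptr = t.data_ptr()
    return ptr & -ptr  # lowest set bit of the address

# The CUDA caching allocator hands back well-aligned base pointers, so
# slicing off one fp16 element yields a view whose data pointer is offset
# by 2 bytes, i.e. only 2-byte aligned.
base = torch.randn(64 * 64 + 1, dtype=torch.half, device="cuda")
x = base[1:].view(64, 64)                                  # 2-byte aligned
w = torch.randn(64, 64, dtype=torch.half, device="cuda")   # well aligned
b = torch.randn(64, dtype=torch.half, device="cuda")       # well aligned

print(byte_alignment(x), byte_alignment(w), byte_alignment(b))

# Mixed alignments (x only 2-byte aligned, w/b more) hit the failing
# cublasLt path on affected builds.
out = F.linear(x, w, b)
```

On affected builds the mixed-alignment call fails with `CUBLAS_STATUS_NOT_SUPPORTED`; with this fix, `addmm` completes via the unfused fallback instead.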
Author: eqy

Files changed:
  • aten/src/ATen/native/cuda/Blas.cpp (fallback gate; idea sketched below)
  • test/test_matmul_cuda.py
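The fallback gate itself lives in the C++ of `Blas.cpp`; the sketch below is only a Python rendering of the idea. Both the `inputs_ok_for_cublaslt` name and the 4-byte threshold are assumed placeholders for illustration, not the values the PR actually uses:

```python
import torch

def inputs_ok_for_cublaslt(*tensors: torch.Tensor,
                           required_alignment: int = 4) -> bool:
    # Only attempt the fused cublasLt addmm when every operand's data
    # pointer meets the alignment requirement; otherwise the caller
    # falls back to the unfused addmm (plain GEMM plus a separate bias add).
    # NOTE: required_alignment=4 is an assumption for illustration only.
    return all(t.data_ptr() % required_alignment == 0 for t in tensors)
```

A caller would use it as a dispatch guard, e.g. `if not inputs_ok_for_cublaslt(x, w, b): <take the unfused path>`. Checking all operands together mirrors the observation above: the failure appears only when alignments are mixed, so the gate must consider every pointer, not just the output.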