Give at::cuda::blas::gemv<at::Half> parity with <float> and <double>. Nature is healing. (#37569)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/37157 on my machine.
This was annoying to track down. The essence is that cublas expects column-major inputs and PyTorch tensors are usually row-major. Cublas lets you request that it act on transposed data, and the erroring `gemv` calls in https://github.com/pytorch/pytorch/issues/37157 make that request. The problem is, [cublasSgemv and cublasDgemv](https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemv) (called by [`gemv<float>`](https://github.com/pytorch/pytorch/blob/091a1192d7c1013915b100dd1a4d00eecf6abe5e/aten/src/ATen/cuda/CUDABlas.cpp#L318) and `gemv<double>`) regard their `m, n` argument values as _pre_-transpose sizes, while [cublasGemmEx](https://docs.nvidia.com/cuda/cublas/index.html#cublas-GemmEx) (called by `gemv<at::Half>`, see [here](https://github.com/pytorch/pytorch/blob/091a1192d7c1013915b100dd1a4d00eecf6abe5e/aten/src/ATen/cuda/CUDABlas.cpp#L342) and [here](https://github.com/pytorch/pytorch/blob/091a1192d7c1013915b100dd1a4d00eecf6abe5e/aten/src/ATen/cuda/CUDABlas.cpp#L229)) regards its `m, k` argument values as _post_-transpose sizes. This is inconsistent. It turns out the `gemv<float>/<double>` calls are configured correctly and the `gemv<at::Half>` calls aren't.
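To make the size-convention difference concrete, here is a hedged sketch (not the ATen code) of how the same row-major `(a, b)` matrix-vector product would be expressed through the two cublas entry points. The function names are made up, pointers are assumed to be valid device allocations, the handle is assumed initialized, and error checking is omitted; on cuBLAS 11+ the compute-type argument of `cublasGemmEx` is spelled `CUBLAS_COMPUTE_32F` rather than `CUDA_R_32F`.
```
// Illustrative only: y = A x, where A is a row-major (a, b) matrix that cublas
// sees as a column-major (b, a) matrix, expressed through both entry points.
#include <cublas_v2.h>
#include <cuda_fp16.h>

void gemv_float_rowmajor(cublasHandle_t handle, int a, int b,
                         const float* A, const float* x, float* y) {
  const float alpha = 1.f, beta = 0.f;
  // cublasSgemv takes *pre*-transpose sizes: the matrix as stored is (b, a)
  // column-major, so m = b, n = a, and CUBLAS_OP_T asks for the internal transpose.
  cublasSgemv(handle, CUBLAS_OP_T,
              /*m=*/b, /*n=*/a, &alpha, A, /*lda=*/b,
              x, /*incx=*/1, &beta, y, /*incy=*/1);
}

void gemv_half_rowmajor(cublasHandle_t handle, int a, int b,
                        const __half* A, const __half* x, __half* y) {
  const float alpha = 1.f, beta = 0.f;
  // cublasGemmEx takes *post*-transpose sizes: op(A) is (a, b), so m = a, k = b,
  // and the vector is treated as a (b, 1) matrix, so n = 1.
  cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
               /*m=*/a, /*n=*/1, /*k=*/b,
               &alpha, A, CUDA_R_16F, /*lda=*/b,
               x, CUDA_R_16F, /*ldb=*/b,
               &beta, y, CUDA_R_16F, /*ldc=*/a,
               CUDA_R_32F,            // CUBLAS_COMPUTE_32F on cuBLAS 11+
               CUBLAS_GEMM_DEFAULT);
}
```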
Strikethrough text below is no longer accurate; ngimel suggested a better way to handle the gemv->gemm forwarding. [Comments in code](https://github.com/pytorch/pytorch/pull/37569/files#diff-686aa86335f96b4ecb9b37f562feed12R323-R348) provide an up-to-date explanation.
Keeping the out-of-date strikethrough text because I don't have the heart to delete it all, and because it captures an intermediate state of my brain that will help orient me if I ever have to fix this again.
~~To convince myself this PR keeps `at::cuda::blas::gemv`'s external API consistent across dtypes, I need to think through what happens when a PyTorch tensor input of size `(a,b)` multiplies a vector of size `(b,)` for 4 cases:~~
### ~~1. input is row-major (needs cublas internal transpose)~~
#### ~~1a. input is float or double~~
~~`gemv<float>/<double>` call `cublasS/Dgemv`, forwarding `trans`,** `m`, and `n` directly.~~
~~`cublasS/Dgemv` expects "a m × n matrix stored in column-major format" (so m is the input's fast dim). Input has size `(a, b)` in row-major format. We can reinterpret it as a column-major matrix with size `(b, a)` without any memory movement. So the gemv call should supply `m=b`, `n=a`. However, we're not trying to multiply a `(b, a)` matrix by a `(b,)` vector; we're trying to sum across `b` for both matrix and vector. So we also request that cublas transpose the matrix internally by supplying `trans='t'` to `blas::gemv`, which becomes `trans=CUBLAS_OP_T` for the `cublasS/Dgemv` call.~~
~~As long as the code calling `blas::gemv` thinks carefully and passes `trans='t'`, `m=b`, `n=a`, cublas carries out `(a, b) x (b,)` and all is well.~~
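For case 1a, a small host-side model (plain C++, no CUDA) of what a column-major gemv with an optional internal transpose computes may help check the argument choice above. It is a sketch with made-up names, not library code; with `trans='t'`, `m=b`, `n=a` it reproduces the row-major `(a, b) x (b,)` product.
```
#include <cstdio>
#include <vector>

// y = op(A) * x, where A is an m x n column-major matrix and op is identity ('n')
// or transpose ('t'); mirrors the cublasS/Dgemv convention (m, n are pre-op sizes).
std::vector<float> colmajor_gemv(char trans, int m, int n,
                                 const std::vector<float>& A,
                                 const std::vector<float>& x) {
  int out = (trans == 'n') ? m : n;
  int in  = (trans == 'n') ? n : m;
  std::vector<float> y(out, 0.f);
  for (int i = 0; i < out; ++i) {
    for (int j = 0; j < in; ++j) {
      // Column-major element (row, col) lives at col * m + row.
      float aij = (trans == 'n') ? A[j * m + i] : A[i * m + j];
      y[i] += aij * x[j];
    }
  }
  return y;
}

int main() {
  int a = 2, b = 3;
  // Row-major (a, b) = [[1, 2, 3], [4, 5, 6]]; the same buffer read column-major is (b, a).
  std::vector<float> A = {1, 2, 3, 4, 5, 6};
  std::vector<float> x = {1, 1, 1};
  // Caller convention from case 1a: trans='t', m=b, n=a.
  std::vector<float> y = colmajor_gemv('t', b, a, A, x);
  std::printf("%g %g\n", y[0], y[1]);  // prints 6 15, the row-major (a, b) x (b,) product
}
```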
#### ~~1b. input is half or bfloat16~~
~~`blas::gemv<at::Half>` takes a different code path, calling `gemm<at::Half>` which calls `cublasGemmEx`. The job of this PR is to make sure the exterior `blas::gemv` caller's carefully thought-out argument choices (`trans='t'`, `m=b`, `n=a`) remain correct.~~
~~`cublasGemmEx` takes args `transa, transb, m, n, k, ...` (plus others we don't care about) and carries out~~
```
C = α op(A) op(B) + β C
where α and β are scalars, and A, B and C are matrices stored in column-major format,
with dimensions op(A) m × k, op(B) k × n and C m × n. Also, for matrix A:
op(A) = A    if transa == CUBLAS_OP_N
        A^T  if transa == CUBLAS_OP_T ...
```
~~`gemv<at::Half>` hacks a gemv by calling gemm such that the raw gemm's `m` is the output dim, `k` is the summed dim, and `n=1`. Reasonable, as long as we get the values right, given that we also need to transpose the input.~~
~~To conform with the cublas docs we interpret the input as column-major with size `(b, a)`. As with the `<float>/<double>` gemv, we want cublas to carry out the input (interpreted as column-major), internally transposed, times the vector of size `(b,)`. In other words we want cublas to apply `op(A) x B`, where op is transpose and `A` is the input interpreted as column-major. The docs define `m` and `k` by saying `op(A)` has dims `m x k` **(`m` and `k` are _post_-`op` sizes)**. `A` was `(b, a)`, `op(A)` is `(a, b)`, so the correct thing is to supply `m=a`, `k=b` to the underlying gemm. **For the `<float>/<double>` gemv, we passed `m=b`, not `m=a`, to the raw `cublasS/Dgemv`.**~~
~~The exterior `blas::gemv` must have been called with `trans='t'`, `m=b`, `n=a` (as required by the `<float>/<double>` versions). So when gemv is about to call gemm, **we [swap](https://github.com/pytorch/pytorch/pull/37569/files#diff-686aa86335f96b4ecb9b37f562feed12R330) the local values of `m` and `n` so that `m=a`, `n=b`,** then put `m (=a)` in the gemm's `m` spot, 1 in the gemm's `n` spot, and `n (=b)` in the gemm's `k` spot. All is well (we made the right gemm call after ingesting the same arg values as `blas::gemv<float>/<double>`).~~
### ~~2. input is column-major (doesn't need cublas transpose)~~
#### ~~2a. input is float or double~~
~~input is `(a,b)`, already column-major with strides `(1,a)`. Code calling `blas::gemv` supplies `trans='n'` (which becomes `CUBLAS_OP_N`, no internal transpose), `m=a`, `n=b`.~~
#### ~~2b. input is half or bfloat16~~
~~`blas::gemv` should pass `transa='n'`, `m=a`, `n=1`, `k=b` to the underlying gemm. The exterior `blas::gemv` must have been called with `trans='n'`, `m=a`, `n=b` (as required by the `<float>/<double>` versions). So **in this case we _don't_ swap `blas::gemv`'s local values of `m` and `n`.** We directly put `m (=a)` in the gemm's `m` spot, 1 in the gemm's `n` spot, and `n (=b)` in the gemm's `k` spot. All is well (we made the right gemm call after ingesting the same arg values as `blas::gemv<float>/<double>`).~~
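Putting cases 1b and 2b together, here is a minimal sketch of the m/n bookkeeping the struck-through text describes when forwarding gemv arguments to a gemm. It mirrors the reasoning above, not the merged PR (which, per the note at the top, expresses the forwarding differently); the names are illustrative, not the ATen code.
```
#include <cstdio>
#include <utility>

// Arguments a hypothetical gemv<at::Half> would hand to the underlying gemm.
struct GemmArgs { char transa; int m, n, k; };

// Ingests the same (trans, m, n) a caller would pass to gemv<float>/<double>.
GemmArgs gemv_args_to_gemm_args(char trans, int m, int n) {
  if (trans == 't') {
    // Row-major input: the caller passed pre-transpose sizes (m=b, n=a); the gemm
    // wants post-transpose sizes, so swap before filling the slots.
    std::swap(m, n);
  }
  // Output length goes in the gemm's m slot, the vector is a 1-column matrix (n=1),
  // and the summed length goes in the gemm's k slot.
  return {trans, m, 1, n};
}

int main() {
  int a = 5, b = 7;
  GemmArgs t = gemv_args_to_gemm_args('t', /*m=*/b, /*n=*/a);  // row-major (a, b) input
  GemmArgs n = gemv_args_to_gemm_args('n', /*m=*/a, /*n=*/b);  // column-major (a, b) input
  std::printf("'t': m=%d n=%d k=%d\n", t.m, t.n, t.k);  // m=a, n=1, k=b
  std::printf("'n': m=%d n=%d k=%d\n", n.m, n.n, n.k);  // m=a, n=1, k=b
}
```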
~~** `trans` is a character `t` or `n` in the `at::cuda::blas::gemv` API, which gets [converted](https://github.com/pytorch/pytorch/blob/091a1192d7c1013915b100dd1a4d00eecf6abe5e/aten/src/ATen/cuda/CUDABlas.cpp#L314) to a corresponding cublas enum value `CUBLAS_OP_T` (do transpose internally) or `CUBLAS_OP_N` (don't transpose internally) just before the raw cublas call.~~
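For reference, an illustrative equivalent of the char-to-enum conversion the footnote describes (the real helper lives in CUDABlas.cpp at the link above; this sketch is not that code):
```
#include <cublas_v2.h>
#include <stdexcept>

cublasOperation_t op_from_char(char trans) {
  switch (trans) {
    case 'n': case 'N': return CUBLAS_OP_N;  // use the matrix as stored
    case 't': case 'T': return CUBLAS_OP_T;  // transpose internally
    case 'c': case 'C': return CUBLAS_OP_C;  // conjugate transpose
    default: throw std::runtime_error("invalid trans");
  }
}
```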
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37569
Differential Revision: D21405955
Pulled By: ngimel
fbshipit-source-id: e831414bbf54860fb7a4dd8d5666ef8081acd3ee