BlasKernel: Improve gemm's inner dot product when a is transposed (#80977)
`gemm_transab_` accumulates the sum in the output, despite the inner
loop being over a single output element. This changes it to accumulate
in a register, which also avoids early truncation for bfloat16.
I've also factored out a generic `sum` function that can be shared
with `gemm_transa_` to handle unrolling and multiple accumulators.
I have benchmarked addmm for bfloat16 with shapes
(320,600) X (600,320) and for both layouts I see a significant
speedup.
| layout | Before (ms) | After (ms) |
|----------|-------------|------------|
| transa | 71.5 | 31 |
| transab | 249 | 35 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80977
Approved by: https://github.com/ngimel