pytorch
753536b7 - BlasKernel: Improve gemm's inner dot product when a is transposed (#80977)

BlasKernel: Improve gemm's inner dot product when a is transposed (#80977)

`gemm_transab_` accumulates the sum in the output, despite the inner loop being over a single output element. This changes it to accumulate in a register, which also avoids early truncation for bfloat16.

I've also factored out a generic `sum` function that can be shared with `gemm_transa_` to handle unrolling and multiple accumulators.

I have benchmarked addmm for bfloat16 with shapes (320,600) X (600,320) and for both layouts I see a significant speedup.

| layout  | Before (ms) | After (ms) |
|---------|-------------|------------|
| transa  | 71.5        | 31         |
| transab | 249         | 35         |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80977
Approved by: https://github.com/ngimel
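To illustrate the idea, here is a minimal sketch (not the actual ATen/BlasKernel code; the function name `dot_accumulate` and its signature are hypothetical) of accumulating the inner dot product in a local variable, widened to `float` for reduced-precision inputs, with two accumulators to expose some instruction-level parallelism:

```cpp
#include <cstdint>

// Hypothetical sketch: accumulate the dot product in registers instead of
// repeatedly reading/writing the output element. For types like bfloat16,
// widening the accumulator to float avoids truncating each partial sum
// back to 16 bits on every iteration.
template <typename scalar_t, typename acc_t = float>
acc_t dot_accumulate(const scalar_t* a, const scalar_t* b, int64_t n) {
  acc_t sum0 = 0, sum1 = 0;  // two accumulators, summed once at the end
  int64_t i = 0;
  for (; i + 1 < n; i += 2) {
    sum0 += static_cast<acc_t>(a[i]) * static_cast<acc_t>(b[i]);
    sum1 += static_cast<acc_t>(a[i + 1]) * static_cast<acc_t>(b[i + 1]);
  }
  for (; i < n; ++i) {  // tail element when n is odd
    sum0 += static_cast<acc_t>(a[i]) * static_cast<acc_t>(b[i]);
  }
  return sum0 + sum1;
}

// A transposed-A GEMM inner loop could then update each output element once:
//   c[j * ldc + i] = beta * c[j * ldc + i]
//                  + alpha * dot_accumulate(a + i * lda, b + j * ldb, k);
// rather than reading and writing c[j * ldc + i] inside the k-loop.
```

This mirrors the change the commit describes: the per-element reduction stays in registers until it is complete, and the same helper can be reused by both the `transa` and `transab` paths.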