pytorch · commit 8383b5c4

Improve `bsr @ strided` performance in `baddmm` for `bfloat16/half` with Triton kernels. (#88078)
As per title. Additionally, we introduce support for:

- Rectangular block sizes which are powers of 2 and at least 16 (Triton's `dot` limitation).
- Batch support with broadcasting for either of the arguments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch
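For context, a minimal sketch of the pattern this commit speeds up: a BSR (block sparse row) matrix multiplied by a strided (dense) tensor in reduced precision. The shapes, block size, and device below are assumptions for illustration; whether `@` actually routes to the Triton-backed kernel depends on the build (CUDA with Triton available) and PyTorch version.

```python
import torch

# Minimal sketch (assumed shapes/blocksize): multiply a BSR sparse matrix
# by a strided (dense) matrix in bfloat16, the case this commit targets.
lhs = torch.randn(64, 64, dtype=torch.bfloat16, device="cuda")

# Convert to BSR; the commit adds support for rectangular block sizes
# that are powers of 2 and at least 16 (Triton's `dot` limitation).
bsr = lhs.to_sparse_bsr(blocksize=(16, 32))

rhs = torch.randn(64, 48, dtype=torch.bfloat16, device="cuda")

# On a CUDA build with Triton available, this matmul can dispatch to the
# Triton kernel path (an assumption about dispatch in your version).
out = bsr @ rhs
print(out.shape)  # torch.Size([64, 48])
```

Per the commit message, the same kernels also handle batched inputs, broadcasting the batch dimensions of either argument.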