Improve `bsr @ strided` performance in `baddmm` for `bfloat16/half` with Triton kernels. (#88078)
As per title.
Additionally, we introduce support for:
- Rectangular block sizes whose dimensions are powers of 2 and at least 16 (a limitation of Triton's `dot`).
- Batched inputs, with broadcasting over the batch dimensions of either argument (see the usage sketch below).
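A minimal usage sketch, not taken from the PR itself: the shapes, block size, and the assumption that `bsr @ strided` matmul for `half/bfloat16` on CUDA is routed through the new Triton path are illustrative.

```python
import torch

dtype = torch.bfloat16
device = "cuda"

# Dense matrix converted to BSR format with a power-of-2 block size >= 16.
a = torch.randn(256, 512, dtype=dtype, device=device)
a_bsr = a.to_sparse_bsr(blocksize=(32, 32))

# Strided (dense) right-hand side.
b = torch.randn(512, 128, dtype=dtype, device=device)

# bsr @ strided matrix multiplication; assumed to hit the Triton kernel
# for half/bfloat16 inputs on CUDA.
out = a_bsr @ b
print(out.shape)  # torch.Size([256, 128])
```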
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch