pytorch
8820dda9 - Revise def of contiguity in bmm (#110811)

Commit
1 year ago
Revise def of contiguity in bmm (#110811) Fixes #108754. `hf_T5_generate` would encounter a regression when calling `extern_kernels.bmm`, if one input is `reinterpret_tensor(buf2, (8, 1, 64), (64, 0, 1))` rather than `reinterpret_tensor(buf2, (8, 1, 64), (64, 512, 1), 0)`. As @jgong5 mentioned in comment, in fact the two tensors are equivalent: The stride doesn't matter when the corresponding size is 1. We revise the definition of contiguity in `bmm` to add the above situation as a contiguous case. Thus, when stride equals to 0, `extern_kernels.bmm` could still use `gemm` of MKL to gain the performance. Speedup of `hf_T5_generate` is **1.343x** now and **1.138x** before, with script `bash inductor_single_test.sh multiple inference performance torchbench hf_T5_generate float32 first dynamic default 0`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/110811 Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/Chillee
Author
Committer
Parents
Loading