[torch] do not fold bmm into mm when tensor1 dim==3 but not contiguous (#73115)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73115
matmul for [B, M, K] x [K, N] was mapped to mm by folding the first two dims of tensor1, turning the problem into [BxM, K] x [K, N]. But when M and K are transposed (i.e., tensor1 is a 3-dim non-contiguous view), the fold forces a contiguous copy, so it's better to use bmm and avoid the data movement.
We could generalize the condition under which we don't fold (see the comment for more details), but we're being conservative here to avoid potential unintended regressions.
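A minimal NumPy sketch of why the fold costs a copy (this mirrors the stride semantics; it is not PyTorch's actual dispatch code): reshaping a transposed batch to [BxM, K] cannot be expressed as a view, so the mm path materializes a copy, while a batched matmul consumes the strided view directly.

```python
import numpy as np

B, M, K, N = 4, 8, 16, 32
t1 = np.random.rand(B, K, M).transpose(0, 2, 1)  # [B, M, K], non-contiguous
t2 = np.random.rand(K, N)

# mm-style path: fold tensor1 to [B*M, K]. Because t1's strides are not
# row-major, this reshape must copy the data.
folded = t1.reshape(B * M, K)
y_mm = (folded @ t2).reshape(B, M, N)

# bmm-style path: the batched matmul handles the transposed strides
# directly, with no copy of t1.
y_bmm = t1 @ t2

assert not np.shares_memory(t1, folded)  # the fold copied
np.testing.assert_allclose(y_mm, y_bmm)
```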
Test Plan:
In the following simple test case, before this diff:
0.00652953577041626 0.003044447898864746
The permutation takes about as long as the GEMM itself.
After this diff:
0.002983328104019165 0.0030336639881134034
The permutation overhead is essentially gone.
```
import torch
import torch.nn.functional as F

B = 128
M = 1024
N = 128
K = 1024

X = torch.rand(B, K, M).cuda()
b = torch.rand(N).cuda()
W = torch.rand(N, K).cuda()

# Transposed view: logically [B, M, K], but non-contiguous.
X = X.permute(0, 2, 1)
Y = F.linear(X, W, b)

X_contiguous = X.contiguous()
Y_ref = F.linear(X_contiguous, W, b)
torch.testing.assert_close(Y, Y_ref)

# benchmark_torch_function is a local benchmarking helper.
t1, _ = benchmark_torch_function(F.linear, X, W, b, 0)
t2, _ = benchmark_torch_function(F.linear, X_contiguous, W, b, 0)
print(t1, t2)
```
Reviewed By: ngimel
Differential Revision: D34350990
fbshipit-source-id: 73e99f785a405cf7a92b909b16f2022b48b1660f
(cherry picked from commit bec995b899710991bb2a304a8009a67f38244114)