[pt][ATen] Optimize bmm (#49506)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49506
- Remove expensive helpers such as `TensorArg`, `checkBackend`, `checkSize`, and `TensorAccessor`.
- Add a `checkDim` that does not require creating a `TensorArg`, which incurs a refcount bump.
- Avoid unnecessary calls to `torch.select`, which goes through the dispatcher, in the cases we care about: mat1 and mat2 either not permuted or permuted with dims = [0, 2, 1]. The PyTorch version of bmm also supports unusual layouts, such as inputs permuted with dims = [1, 2, 0], which are uncommon in SparseNNs.
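
As a rough illustration of the layouts named in the last bullet, the pure-Python sketch below classifies a 3-D tensor from its sizes and strides alone. It is not the ATen code from this PR: `contiguous_strides`, `permute`, and `classify_layout` are hypothetical helpers, strides are counted in elements, and size-1 edge cases are not treated specially.

```python
def contiguous_strides(sizes):
    """Row-major strides (in elements) for a tensor with the given sizes."""
    strides = []
    running = 1
    for size in reversed(sizes):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def permute(sizes, strides, dims):
    """Sizes and strides of a view whose axes are reordered by `dims`."""
    return tuple(sizes[d] for d in dims), tuple(strides[d] for d in dims)


def classify_layout(sizes, strides):
    """Classify a 3-D layout as 'contiguous', 'transposed', or 'other'.

    'transposed' means the view equals a contiguous tensor permuted with
    dims = [0, 2, 1]; anything else (e.g. dims = [1, 2, 0]) is 'other'.
    """
    if len(sizes) != 3 or len(strides) != 3:
        raise ValueError("expected a 3-D tensor")
    sizes, strides = tuple(sizes), tuple(strides)
    if strides == contiguous_strides(sizes):
        return "contiguous"
    # Undo the [0, 2, 1] permutation and check that the base is contiguous.
    batch, rows, cols = sizes
    base_sizes = (batch, cols, rows)
    base_strides = contiguous_strides(base_sizes)
    if permute(base_sizes, base_strides, (0, 2, 1)) == (sizes, strides):
        return "transposed"
    return "other"


# Example: start from a contiguous 4 x 2 x 3 tensor and permute it.
base_sizes = (4, 2, 3)
base_strides = contiguous_strides(base_sizes)
plain = classify_layout(base_sizes, base_strides)
swapped = classify_layout(*permute(base_sizes, base_strides, (0, 2, 1)))
rotated = classify_layout(*permute(base_sizes, base_strides, (1, 2, 0)))
print(plain, swapped, rotated)
```

Under these assumptions, the first two layouts are the ones a fast path could handle directly, while the third would fall back to a more general route.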
Test Plan:
Unit test:
```
buck test //caffe2/test:linalg
```
Benchmark with the adindexer model:
```
Before:
I1216 14:02:24.155516 2595800 PyTorchPredictorBenchLib.cpp:209] PyTorch run finished. Milliseconds per iter: 0.0847197. Iters per second: 11803.6
After:
I1216 14:02:26.583878 2595939 PyTorchPredictorBenchLib.cpp:209] PyTorch run finished. Milliseconds per iter: 0.082051. Iters per second: 12187.5
```
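
For a rough sense of scale, the relative change can be computed from the two log lines above; it works out to roughly 3%. The variable names are illustrative, and the logs show one measurement per configuration, so the result is indicative only.

```python
# Numbers copied from the "Before" and "After" benchmark log lines.
before_ms_per_iter = 0.0847197
after_ms_per_iter = 0.082051
before_iters_per_sec = 11803.6
after_iters_per_sec = 12187.5

# Relative reduction in latency per iteration, in percent.
latency_reduction_pct = (
    (before_ms_per_iter - after_ms_per_iter) / before_ms_per_iter * 100
)

# Relative gain in throughput, in percent.
throughput_gain_pct = (
    (after_iters_per_sec - before_iters_per_sec) / before_iters_per_sec * 100
)

print(f"latency reduction: {latency_reduction_pct:.2f}%")
print(f"throughput gain:   {throughput_gain_pct:.2f}%")
```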
Reviewed By: bwasti
Differential Revision: D25577574
fbshipit-source-id: 8aba69b950e7b4d9d1b14ba837931695a908c068