MatMul heuristics for aarch64 (#107167)
This PR focuses on improving MatMul performance on aarch64. It introduces a lightweight heuristic that dispatches small or tall/flat MatMul operations to OpenBLAS and routes other shapes to MKLDNN/ACL.
On average, the proposed heuristic improves MatMul operator latency by 1.03x / 1.04x / 1.05x / 1.09x / 1.22x for 1 / 2 / 4 / 8 / 16 threads, respectively (the baseline uses ACL for all MatMuls on AWS Graviton c7g instances).
Fixes #107168
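Conceptually, the dispatch reduces to a simple shape check. The actual check is implemented in ATen's C++ matmul path; the Python sketch below is only an illustrative rendering, where `use_mkldnn_matmul` is a hypothetical helper and the thresholds mirror the `TORCH_MKLDNN_MATMUL_MIN_DIM` / `TORCH_MKLDNN_MATMUL_MIN_SIZE` environment variables used in the benchmark below (the exact comparison operators are an assumption):

```python
import os

# Hedged sketch only: the real dispatch check lives in ATen's C++ matmul
# path. Threshold names and comparison operators are assumptions based on
# the environment variables used in the benchmark script.
MIN_DIM = int(os.environ.get("TORCH_MKLDNN_MATMUL_MIN_DIM", "8"))
MIN_SIZE = int(os.environ.get("TORCH_MKLDNN_MATMUL_MIN_SIZE", "8192"))

def use_mkldnn_matmul(m: int, k: int, n: int) -> bool:
    # Route to MKLDNN/ACL only when every dimension clears MIN_DIM and the
    # total work m*k*n clears MIN_SIZE; smaller or tall/flat shapes fall
    # back to OpenBLAS.
    return min(m, k, n) > MIN_DIM and m * k * n > MIN_SIZE
```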
<details>
<summary>Full MatMul benchmark script and result</summary>
Run the following script `run.sh` with `ABt.py` in the same directory:
```shell
#!/bin/bash
script=ABt.py
# Baseline: thresholds of 0 disable the heuristic, so every MatMul goes to MKLDNN/ACL.
prefix=acl
OMP_NUM_THREADS=1 DNNL_DEFAULT_FPMATH_MODE=BF16 TORCH_MKLDNN_MATMUL_MIN_DIM=0 TORCH_MKLDNN_MATMUL_MIN_SIZE=0 python ${script} > ${prefix}_t1.csv
OMP_NUM_THREADS=2 DNNL_DEFAULT_FPMATH_MODE=BF16 TORCH_MKLDNN_MATMUL_MIN_DIM=0 TORCH_MKLDNN_MATMUL_MIN_SIZE=0 python ${script} > ${prefix}_t2.csv
OMP_NUM_THREADS=4 DNNL_DEFAULT_FPMATH_MODE=BF16 TORCH_MKLDNN_MATMUL_MIN_DIM=0 TORCH_MKLDNN_MATMUL_MIN_SIZE=0 python ${script} > ${prefix}_t4.csv
OMP_NUM_THREADS=8 DNNL_DEFAULT_FPMATH_MODE=BF16 TORCH_MKLDNN_MATMUL_MIN_DIM=0 TORCH_MKLDNN_MATMUL_MIN_SIZE=0 python ${script} > ${prefix}_t8.csv
OMP_NUM_THREADS=16 DNNL_DEFAULT_FPMATH_MODE=BF16 TORCH_MKLDNN_MATMUL_MIN_DIM=0 TORCH_MKLDNN_MATMUL_MIN_SIZE=0 python ${script} > ${prefix}_t16.csv
# Heuristic: small or tall/flat shapes fall back to OpenBLAS (see thresholds below).
prefix=heur
OMP_NUM_THREADS=1 DNNL_DEFAULT_FPMATH_MODE=BF16 TORCH_MKLDNN_MATMUL_MIN_DIM=8 TORCH_MKLDNN_MATMUL_MIN_SIZE=8192 python ${script} > ${prefix}_t1.csv
OMP_NUM_THREADS=2 DNNL_DEFAULT_FPMATH_MODE=BF16 TORCH_MKLDNN_MATMUL_MIN_DIM=8 TORCH_MKLDNN_MATMUL_MIN_SIZE=8192 python ${script} > ${prefix}_t2.csv
OMP_NUM_THREADS=4 DNNL_DEFAULT_FPMATH_MODE=BF16 TORCH_MKLDNN_MATMUL_MIN_DIM=8 TORCH_MKLDNN_MATMUL_MIN_SIZE=8192 python ${script} > ${prefix}_t4.csv
OMP_NUM_THREADS=8 DNNL_DEFAULT_FPMATH_MODE=BF16 TORCH_MKLDNN_MATMUL_MIN_DIM=8 TORCH_MKLDNN_MATMUL_MIN_SIZE=8192 python ${script} > ${prefix}_t8.csv
OMP_NUM_THREADS=16 DNNL_DEFAULT_FPMATH_MODE=BF16 TORCH_MKLDNN_MATMUL_MIN_DIM=8 TORCH_MKLDNN_MATMUL_MIN_SIZE=8192 python ${script} > ${prefix}_t16.csv
```
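Note that the comparison is only meaningful if the PyTorch build links both OpenBLAS and oneDNN/ACL; one way to sanity-check this is to print the compile-time configuration:

```python
import torch

# Prints the build configuration, including the BLAS library in use and
# whether MKLDNN (oneDNN) support was compiled in.
print(torch.__config__.show())
```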
`ABt.py`:
```python
import argparse
import timeit

import numpy as np
import torch

BATCH = 1
DIM_MIN = 8
DIM_MAX = 256
M_MIN = DIM_MIN
K_MIN = DIM_MIN
N_MIN = DIM_MIN
M_MAX = DIM_MAX
K_MAX = DIM_MAX
N_MAX = DIM_MAX
min_iterations = 1000
min_time = 0.1  # seconds; keep timing until at least this much wall time has passed


def get_stats(timings):
    # Convert per-iteration timings (seconds) to milliseconds and summarize.
    times = np.array(timings)
    time_avg = np.average(times) * 1000
    time_med = np.median(times) * 1000
    time_90th = np.percentile(times, 90) * 1000
    time_99th = np.percentile(times, 99) * 1000
    return time_avg, time_med, time_90th, time_99th


def bench(M, K, N, min_iterations, min_time):
    a = torch.randn(M, K, dtype=torch.float32)
    b = torch.randn(N, K, dtype=torch.float32)
    timings = []
    with torch.no_grad():
        # Warm-up iterations, excluded from the statistics.
        for _ in range(max(1, min_iterations // 100)):
            c = torch.matmul(a, b.transpose(0, 1))
        bench_time = timeit.default_timer()
        while True:
            for _ in range(min_iterations):
                start_time = timeit.default_timer()
                c = torch.matmul(a, b.transpose(0, 1))
                end_time = timeit.default_timer()
                timings.append(end_time - start_time)
            # Repeat batches of min_iterations until min_time has elapsed.
            if timeit.default_timer() - bench_time > min_time:
                break
    return get_stats(timings)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-l', '--loop', dest='loop', action='store_true')
    flags = parser.parse_args()
    if flags.loop:
        # Burn-in mode: loop forever over the upper half of the shape range.
        while True:
            for M in range(M_MAX // 2, M_MAX + 1, 8):
                for K in range(K_MAX // 2, K_MAX + 1, 8):
                    for N in range(N_MAX // 2, N_MAX + 1, 8):
                        stats = bench(M, K, N, min_iterations, min_time)
    else:
        torch.manual_seed(0)
        print("M, K, N, latency")
        for M in range(M_MIN, M_MAX + 1, 8):
            for K in range(K_MIN, K_MAX + 1, 8):
                for N in range(N_MIN, N_MAX + 1, 8):
                    stats = bench(M, K, N, min_iterations, min_time)
                    # stats[2] is the 90th-percentile latency in ms.
                    print(f"{M}, {K}, {N}, {stats[2]}")
```
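To compare the two runs, here is a minimal post-processing sketch (assuming pandas is available, and relying on both runs enumerating shapes in the same order, which the script above guarantees) that reports the geometric-mean speedup of the heuristic over the baseline per thread count:

```python
import numpy as np
import pandas as pd

for t in (1, 2, 4, 8, 16):
    # CSV filenames and the "latency" column follow run.sh and ABt.py above.
    acl = pd.read_csv(f"acl_t{t}.csv", skipinitialspace=True)
    heur = pd.read_csv(f"heur_t{t}.csv", skipinitialspace=True)
    # Per-shape speedup: baseline latency divided by heuristic latency.
    speedup = acl["latency"].to_numpy() / heur["latency"].to_numpy()
    print(f"{t} threads: geomean speedup {np.exp(np.log(speedup).mean()):.3f}x")
```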
Here are the benchmark results collected on an AWS Graviton c7g instance. Due to the size of the table, the full results are attached as a text file:
[table.txt](https://github.com/pytorch/pytorch/files/12374265/table.txt)
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107167
Approved by: https://github.com/malfet