onnxruntime
6d34aba9 - Introducing BF16 Pointwise NCHWc Convolution for Arm64 (#26838)

Introducing BF16 Pointwise NCHWc Convolution for Arm64 (#26838)

## Description

This PR adds a BF16 (bfloat16) pointwise convolution kernel for the ARM64 NCHWc format, leveraging the existing SBGEMM infrastructure. When the `mlas.enable_gemm_fastmath_arm64_bfloat16` session option is enabled on supported ARM64 Linux hardware, pointwise Conv is rerouted to this BF16 implementation. This is an opt-in feature, just as BF16 MatMul is opt-in.

Added a bool `ZeroMode` field to `MLAS_SBGEMM_DATA_PARAMS` (default `true` for backward compatibility) to enable per-batch control over output accumulation. This mirrors the beta parameter in FP32's `MlasGemmBatch` and is required for pointwise convolutions with more than 128 input channels, where multiple GEMM calls must accumulate into the same output buffer (see the accumulation sketch below).

## Motivation and Context

The existing `mlas.enable_gemm_fastmath_arm64_bfloat16` session option accelerates MatMul operations on ARM64 processors with BF16 support, but convolution operations did not benefit from this optimization. Pointwise convolutions (1x1 kernels) are essentially batched matrix multiplications (see the mapping sketch below), so this change extends the BF16 fastmath optimization to pointwise NCHWc convolutions, reusing the same session option.

The implementation mirrors the FP32 pointwise kernel structure while delegating the actual computation to SBGEMM, ensuring correctness and maintainability.

## Performance improvement

Measured a 15-20% throughput gain on MobileNet inference on an AWS Graviton4 instance.

Before (FP32):
```
./build/Linux/Release/onnxruntime_perf_test -C "mlas.enable_gemm_fastmath_arm64_bfloat16|0" -x 32 -I -m times -r 2000 ~/scripts/mobilenet.onnx
Number of inferences per second: 559.154
```

After (BF16):
```
./build/Linux/Release/onnxruntime_perf_test -C "mlas.enable_gemm_fastmath_arm64_bfloat16|1" -x 32 -I -m times -r 2000 ~/scripts/mobilenet.onnx
Number of inferences per second: 651.221
```
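To make the "pointwise conv is a GEMM" mapping concrete, here is a minimal reference sketch (not the MLAS kernel; function and variable names are illustrative). For one image in channel-major layout, a 1x1 convolution computes `output[co][p] = sum_ci W[co][ci] * input[ci][p]` over the `H*W` spatial positions `p`, i.e. a GEMM of the `[Cout x Cin]` weight matrix with the `[Cin x (H*W)]` input matrix:

```cpp
#include <cstddef>
#include <vector>

// Reference pointwise (1x1) convolution for a single image:
// input is Cin x (H*W), weight is Cout x Cin, output is Cout x (H*W).
// The triple loop below is exactly a naive GEMM.
std::vector<float> pointwise_conv(const std::vector<float>& input,
                                  const std::vector<float>& weight,
                                  size_t cin, size_t cout, size_t hw) {
  std::vector<float> output(cout * hw, 0.0f);
  for (size_t co = 0; co < cout; ++co)
    for (size_t ci = 0; ci < cin; ++ci)
      for (size_t p = 0; p < hw; ++p)
        output[co * hw + p] += weight[co * cin + ci] * input[ci * hw + p];
  return output;
}
```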
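The sketch below illustrates the accumulation semantics that the new `ZeroMode` field controls; it is not MLAS code, and the helper names are hypothetical. `ZeroMode=true` behaves like beta=0 (overwrite the output), `ZeroMode=false` like beta=1 (accumulate). When the input-channel dimension is split into chunks (e.g. above 128 channels), the first GEMM call must overwrite the output buffer and every later call must accumulate into it:

```cpp
#include <algorithm>
#include <cstddef>

// Reference GEMM with leading dimensions:
// C = A*B when zero_mode, C += A*B otherwise.
void gemm(const float* A, size_t lda, const float* B, size_t ldb,
          float* C, size_t ldc, size_t M, size_t N, size_t K, bool zero_mode) {
  for (size_t m = 0; m < M; ++m)
    for (size_t n = 0; n < N; ++n) {
      float acc = zero_mode ? 0.0f : C[m * ldc + n];
      for (size_t k = 0; k < K; ++k)
        acc += A[m * lda + k] * B[k * ldb + n];
      C[m * ldc + n] = acc;
    }
}

// Split the K (input-channel) dimension into 128-wide chunks.
// Only the first chunk zeroes the output; the rest accumulate,
// which is what the per-batch ZeroMode flag makes possible.
void chunked_gemm(const float* A, const float* B, float* C,
                  size_t M, size_t N, size_t K) {
  const size_t kChunk = 128;
  for (size_t k0 = 0; k0 < K; k0 += kChunk) {
    const size_t kc = std::min(kChunk, K - k0);
    gemm(A + k0, /*lda=*/K, B + k0 * N, /*ldb=*/N, C, /*ldc=*/N,
         M, N, kc, /*zero_mode=*/k0 == 0);
  }
}
```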
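For reference, the same session option shown with `-C` in the perf_test commands above can be set programmatically. A minimal sketch using the ONNX Runtime C++ API (the config key is taken from this PR; the model path and session setup are illustrative):

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "bf16-demo");
  Ort::SessionOptions session_options;
  // Opt in to the ARM64 BF16 fastmath path (eligible MatMul and, with
  // this PR, pointwise NCHWc Conv are routed through SBGEMM).
  session_options.AddConfigEntry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1");
  Ort::Session session(env, "mobilenet.onnx", session_options);
  return 0;
}
```

On hardware without BF16 support, the option is expected to have no effect and the FP32 kernels are used as before.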