onnxruntime
58c29d34 - WIP: Dp4MatMulNBits accuracy level 4 matmul for WebGPU EP (#23365)

Committed 1 year ago
WIP: Dp4MatMulNBits accuracy level 4 matmul for WebGPU EP (#23365)

### Description
This change implements accuracy level 4 (quantize A to int8) matmul for the WebGPU EP. The matmul kernel uses DP4A for the matrix multiplication; to keep the DP4A units fed, a co-operative matrix multiplication is implemented that preloads the row/col into local variables before the multiplication operation. Credits to @qjia7 for help with the quantizer shader.

Performance metrics on Intel ADL/TGL GPUs:
```
PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500
Batch size: 1, prompt tokens: 501, tokens to generate: 128

Prompt processing (time to first token):
    avg (us):       2.76762e+06
    avg (tokens/s): 181.022    <<< Prefill speed
    p50 (us):       2.74843e+06
    stddev (us):    41756.4
    n:              5 * 501 token(s)
Token generation:
    avg (us):       81500.7
    avg (tokens/s): 12.2698
    p50 (us):       81104.1
    stddev (us):    2961.31
    n:              635 * 1 token(s)
Token sampling:
    avg (us):       13.1836
    avg (tokens/s): 75851.9
    p50 (us):       12
    stddev (us):    6.47085
    n:              640 * 1 token(s)
E2E generation (entire generation loop):
    avg (ms):       13120
    p50 (ms):       13081.6
    stddev (ms):    114.689
    n:              5
Peak working set size (bytes): 5467533312
WebGPU device lost (2): Device was destroyed.
```
This kernel is 2.10x faster than its F16 counterpart for a 500-token prefill; the previous prefill record was 86 tokens/s. To support devices with subgroup sizes 8/32, a no-subgroup version of the same shader is included. Its performance is slower than the subgroup version on ADL:
```
PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500
Batch size: 1, prompt tokens: 501, tokens to generate: 128

Prompt processing (time to first token):
    avg (us):       4.11989e+06
    avg (tokens/s): 121.605
    p50 (us):       4.11847e+06
    stddev (us):    2147.48
    n:              5 * 501 token(s)
Token generation:
    avg (us):       81174.9
    avg (tokens/s): 12.3191
    p50 (us):       81301.1
    stddev (us):    2177.2
    n:              635 * 1 token(s)
Token sampling:
    avg (us):       14.7998
    avg (tokens/s): 67568.3
    p50 (us):       12.3
    stddev (us):    11.5481
    n:              640 * 1 token(s)
E2E generation (entire generation loop):
    avg (ms):       14431.1
    p50 (ms):       14433.8
    stddev (ms):    5.02473
    n:              5
Peak working set size (bytes): 5466480640
WebGPU device lost (2): Device was destroyed.
```
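The accuracy-level-4 idea can be illustrated with a small Python sketch. This is not the actual WGSL shader from this change; the helper names (`quantize_row_int8`, `dp4a`, `matmul_element`) are hypothetical, and for simplicity both operands are quantized here, whereas the real kernel multiplies quantized A against the already-quantized NBits weights. The point is the shape of the computation: A is quantized to int8 with a symmetric per-row scale, the inner product is accumulated in int32 in 4-element chunks (as a DP4A instruction would), and the result is dequantized once at the end.

```python
# Sketch of an accuracy-level-4 (A quantized to int8) matmul element.
# Hypothetical helpers for illustration only, not the shipped WGSL kernel.

def quantize_row_int8(row):
    """Symmetric per-row quantization: returns (int8 values, float scale)."""
    amax = max(abs(v) for v in row) or 1.0
    scale = amax / 127.0
    q = [max(-127, min(127, round(v / scale))) for v in row]
    return q, scale

def dp4a(a4, b4, acc):
    """Emulate DP4A: dot product of two 4-element int8 vectors, added to an int32 accumulator."""
    return acc + sum(x * y for x, y in zip(a4, b4))

def matmul_element(a_row, b_col):
    """One output element: int8 x int8 accumulation, dequantized at the end."""
    qa, sa = quantize_row_int8(a_row)
    qb, sb = quantize_row_int8(b_col)
    acc = 0
    for i in range(0, len(qa), 4):  # walk the K dimension four int8s at a time
        acc = dp4a(qa[i:i + 4], qb[i:i + 4], acc)
    return acc * sa * sb  # single dequantization of the int32 accumulator

# The int8 result closely tracks the float dot product:
print(matmul_element([1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 1.0, 1.0]))  # ~10.0
```

Keeping the accumulator in int32 and dequantizing once per output element is what makes a single DP4A issue per 4 elements possible; the co-operative part of the real kernel is then about staging the row/col tiles in workgroup-local storage so those DP4A issues are never starved.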