onnxruntime
f12a89e9 - [WebNN EP] Support GroupQueryAttention(GQA) (#23416)

### Description

Adds support for GroupQueryAttention (GQA) via WebNN matmul, transpose, reshape, and other operations, following the logic of the GQA subgraph below.

```
Abbreviations: B is batch_size, S is sequence_length, W is hidden_size, P is past_sequence_length,
               N is number of attention heads, H is head size, W = N*H, h = Sqrt(H), G is group size.

GQA inputs: query, key, value, past_key, past_value, seqlens_k, total_sequence_length

Notes: If the datatype of the inputs (qkv and past kv) is float16, we cast them to float32
       to ensure data precision.

GQA subgraph dataflow:

  query -> Reshape (B,S,H,N) -> q_Transpose (0,2,1,3) -----------------------------------------> SDPA (query)
  key   -> Reshape (B,S,H,N) -> ScatterND (into past_key,  at scatter_indices*) -> present_key
           present_key   -> Expand(G) -> k_Transpose (0,1,3,2) --------------------------------> SDPA (key)
  value -> Reshape (B,S,H,N) -> ScatterND (into past_value, at scatter_indices*) -> present_value
           present_value -> Expand(G) ----------------------------------------------------------> SDPA (value)
  seqlens_k -> scatter_indices*  (see the scatter_indices calculation below)
  attention_bias (one/finfo_min mask*) ---------------------------------------------------------> SDPA (mask)

  SDPA = ScaledDotProductAttention -> output
```

The ScaledDotProductAttention logic is:

```
ScaledDotProductAttention Subgraph: the basis for MultiHeadAttention and GroupQueryAttention

inputs: query, key, value, scale, attention mask, and reshape_output_shape (for reshape)

Abbreviations: B is batch_size, S is query sequence_length, kv_S is key/value sequence length,
               N is number of attention heads, H is head size, W is hidden_size

  query   key
    |      |
    +-matmul-+   scale
        |          |
        +---div----+   attn_mask
              |            |
              +----add-----+   value
                     |           |
                     +--matmul---+
                           |
                      transpose (0,2,1,3)   B,H,S,N -> B,S,H,N
                           |
                        Reshape             B,S,H,N -> B,S,W
                           |
                         output
```

scatter_indices is calculated as:

```
  if_prefill (0/1 constant)
           |
  0 ---> Where <--- Cast <--- seqlens_k
           |
      scatter_pos*

  scatter_indices_left_constant     Add(scatter_indices_right_constant, scatter_pos*)
                |                                       |
                +------------------+--------------------+
                                   |
                           scatter_indices
```

attention_bias is calculated as:

```
  ones_array (shape=B,N,S,P)       range_of_qkv_sequence_length_constant (0,1,2,...) (shape=S)
        |                                        |
  CumSum (axis=3, exclusive=true,               Add <--- scatter_pos*
          reversed=false)                        |
        |                                     Expand (shape=P,S)
        |                                        |
        |                                    Transpose (1,0)
        |                                        |
        +--------------> Lesser <----------------+
                            |
              1 ---> Where <--- finfo_min (minimum value of FP32)
                            |
                     attention_bias
```

*Note: currently only `past_sequence_length == total_sequence_length` is supported for GQA.*
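For illustration, here is a minimal NumPy sketch of the KV-cache update (ScatterND) and group expansion (Expand(G)) steps from the GQA subgraph above. The function names, the `(B, N_kv, P, H)` cache layout, and the `scatter_pos` argument are assumptions made for the sketch, not the WebNN EP's actual code, which expresses the same logic as WebNN graph operations.

```python
import numpy as np

def update_kv_cache(past_kv, new_kv, scatter_pos):
    """Write the freshly computed key/value tensor into the static KV cache (stands in for ScatterND).

    past_kv:     (B, N_kv, P, H) preallocated cache, P == total_sequence_length
    new_kv:      (B, N_kv, S, H) key or value computed for the current step(s)
    scatter_pos: sequence offset to write at (0 for prefill, derived from seqlens_k for decode)
    """
    present_kv = past_kv.copy()
    s = new_kv.shape[2]
    present_kv[:, :, scatter_pos:scatter_pos + s, :] = new_kv
    return present_kv

def expand_groups(kv, group_size):
    """Repeat every KV head group_size times so the KV heads match the query heads (Expand(G))."""
    b, n_kv, p, h = kv.shape
    expanded = np.broadcast_to(kv[:, :, None, :, :], (b, n_kv, group_size, p, h))
    return expanded.reshape(b, n_kv * group_size, p, h)
```

In the graph itself the write offsets come from scatter_indices, but for the supported case (`past_sequence_length == total_sequence_length`) the slice assignment above captures the same update.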
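A corresponding NumPy sketch of the ScaledDotProductAttention subgraph, assuming `q` of shape (B, N, S, H), a pre-transposed `k_t` of shape (B, N, H, kv_S), and `v` of shape (B, N, kv_S, H). The diagram above does not draw the softmax over the scores; standard scaled dot-product attention applies it between the add and the second matmul, so the sketch includes it.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k_t, v, attn_mask, head_size):
    """q: (B, N, S, H), k_t: (B, N, H, kv_S), v: (B, N, kv_S, H), attn_mask broadcastable to (B, N, S, kv_S)."""
    scores = np.matmul(q, k_t) / np.sqrt(head_size)   # matmul, then div by h = Sqrt(H)
    scores = scores + attn_mask                       # add the attention_bias mask
    probs = softmax(scores, axis=-1)                  # not drawn in the diagram; see the note above
    out = np.matmul(probs, v)                         # (B, N, S, H)
    b, n, s, h = out.shape
    return out.transpose(0, 2, 1, 3).reshape(b, s, n * h)   # transpose (0,2,1,3), reshape to (B, S, W)
```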
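Finally, a NumPy sketch of the attention_bias construction following the CumSum/Lesser/Where chain above. NumPy's `cumsum` is inclusive, so 1 is subtracted to emulate the exclusive CumSum; all names and shapes here are illustrative.

```python
import numpy as np

def build_attention_bias(batch_size, num_heads, seq_len, total_seq_len, scatter_pos):
    """Build the (B, N, S, P) additive mask: 1 where a key position is visible, finfo_min elsewhere."""
    # Exclusive CumSum of a ones array along axis 3 gives key positions 0..P-1.
    key_pos = np.cumsum(
        np.ones((batch_size, num_heads, seq_len, total_seq_len), dtype=np.float32), axis=3) - 1.0

    # range(S) + scatter_pos, Expand to (P, S), Transpose (1, 0) -> (S, P).
    limit = np.arange(seq_len, dtype=np.float32) + scatter_pos      # shape (S,)
    limit = np.broadcast_to(limit, (total_seq_len, seq_len)).T      # limit[s, p] = s + scatter_pos

    visible = key_pos < limit   # Lesser: query row s sees key position p iff p < s + scatter_pos
    finfo_min = np.finfo(np.float32).min
    return np.where(visible, np.float32(1.0), finfo_min)            # Where(visible, 1, finfo_min)
```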