onnxruntime
df581e1c - [WebNN EP] Support MultiHeadAttention(MHA) (#24079)

### Description

Adds support for MultiHeadAttention via WebNN matmul, transpose, reshape, and other operations, following the logic of the MHA subgraph below.

```
Abbreviations: B is batch_size, S is sequence_length, W is hidden_size,
               P is past_sequence_length, N is number of attention heads,
               H is head size, and W=N*H, h=Sqrt(H)

Notes: If the datatype of the inputs (qkv and past kv) is float16,
       we cast them to float32 to ensure data precision.

     query        key         value
       |           |            |
   q_Reshape   k_Reshape    v_Reshape   (shape=B,S,H,N)
       |           |            |
  q_Transpose  k_Transpose  v_Transpose (perm=0,2,1,3)
       |           |            |
       |        Concat <-------|------- past_key
       |           |            |
       |      present_key       |
       |           |            |
       |    opt_k_transpose     |
       |      (0,1,3,2)         |
        \         /             |
         qk_MatMul          Concat <--- past_value
             |                  |
  scale---qk_Div           present_value
             |                  |
  attention_bias--->Add         |
             |                  |
          Softmax               |
              \                /
               \              /
                 qkv_MatMul
                     |
                 Transpose (perm=0,2,1,3)
                     |
                 Reshape---(shape=B,P,W)
                     |
                  output
```

### Motivation and Context
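The subgraph above can be sketched numerically. The following is a minimal NumPy illustration of the same dataflow (reshape/transpose into heads, concat past KV, scaled QK matmul, bias, softmax, QKV matmul, merge heads); the function name `mha` and its signature are hypothetical, not the WebNN EP implementation.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the last axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mha(query, key, value, past_key=None, past_value=None,
        attention_bias=0.0, num_heads=8):
    B, S, W = query.shape
    H = W // num_heads  # head size, W = N*H

    # q/k/v_Reshape + q/k/v_Transpose: (B,S,W) -> (B,S,N,H) -> (B,N,S,H)
    def split_heads(x):
        return x.reshape(B, -1, num_heads, H).transpose(0, 2, 1, 3)

    q, k, v = split_heads(query), split_heads(key), split_heads(value)

    # Concat past kv along the sequence axis -> present_key / present_value
    if past_key is not None:
        k = np.concatenate([past_key, k], axis=2)
    if past_value is not None:
        v = np.concatenate([past_value, v], axis=2)

    # opt_k_transpose (0,1,3,2) + qk_MatMul, then qk_Div by h = Sqrt(H)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(H)

    # Add attention_bias, Softmax, qkv_MatMul
    probs = softmax(scores + attention_bias)
    out = probs @ v

    # Final Transpose (0,2,1,3) + Reshape back to (B, S, W)
    return out.transpose(0, 2, 1, 3).reshape(B, -1, W), k, v
```

Returning `k` and `v` alongside the output mirrors the `present_key`/`present_value` edges in the diagram, so they can be fed back as `past_key`/`past_value` on the next decoding step.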