[WebNN EP] Support GroupQueryAttention(GQA) (#23416)
### Description
Adds support for GroupQueryAttention (GQA) by decomposing it into WebNN matmul, transpose,
reshape, and other operations, following the logic of the GQA subgraph below.
```
Abbreviations: B is batch_size, S is sequence_length, W is hidden_size, P is past_sequence_length,
               N is number of attention heads, H is head size, W = N*H, h = Sqrt(H), G is group size.
GQA inputs: query, key, value, past_key, past_value, seqlens_k, total_sequence_length
Notes: If the data type of the inputs (qkv and past kv) is float16, we cast them to float32 to preserve precision.
         query              key                  value
           |                 |                     |
        Reshape           Reshape               Reshape  (B,S,H,N)                 seqlens_k
           |                 |                     |                           /       |
           |                 |       past_value    |       (scatter_indices*)          |
      q_Transpose            |                \    |     /                             |
       (0,2,1,3)             |     past_key    ScatterND-------------------------------|------> present_value
            \                |     /               |                                   |
present_key<-\-----------ScatterND             Expand(G)   (attention_bias, one/finfo_min mask*)
              \              |                     |         /
               |         Expand(G)                 |        /
               |             |                     |       /
               |        k_Transpose                |      /
               |         (0,1,3,2)                 |     /
               |             |                     |    /
              +-----------------------------------------+
              |       ScaledDotProductAttention        |
              +-----------------------------------------+
                                   |
                                output
```
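For reference, the KV-cache update (ScatterND) and the Expand(G) step boil down to the following NumPy sketch; the shapes, the decode-time scatter position, and the variable names are illustrative assumptions rather than the EP's actual tensors:
```python
# Minimal NumPy sketch (not the EP code) of the ScatterND cache update and the
# Expand(G) step that repeats each KV head so it lines up with the query heads.
import numpy as np

B, S, P = 1, 1, 8           # batch, new tokens, total (== past) sequence length
N, kv_N, H = 8, 2, 4        # query heads, key/value heads, head size
G = N // kv_N               # group size: each KV head serves G query heads

past_key = np.zeros((B, kv_N, P, H), dtype=np.float32)       # KV cache: (B, kv_N, P, H)
new_key = np.random.rand(B, S, kv_N, H).astype(np.float32)   # reshaped key: (B, S, kv_N, H)
scatter_pos = 5             # 0 for prefill, seqlens_k for decode (see scatter_indices below)

# ScatterND: write the S new tokens into the cache at scatter_pos along the sequence axis.
present_key = past_key.copy()
present_key[:, :, scatter_pos:scatter_pos + S, :] = new_key.transpose(0, 2, 1, 3)

# Expand(G): repeat each KV head G times so the expanded key matches the N query heads.
expanded_key = np.repeat(present_key, G, axis=1)              # (B, N, P, H)
print(expanded_key.shape)                                      # -> (1, 8, 8, 4)
```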
The ScaledDotProductAttention logic is:
```
ScaledDotProductAttention Subgraph: The basis for MultiHeadAttention and GroupQueryAttention
inputs: query, key, value, scale, attention mask, and reshape_output_shape (for reshape)
Abbreviations: B is batch_size, S is query sequence_length, kv_S is key/value sequence length,
               N is number of attention heads, H is head size, W is hidden_size
            query         key
              |            |
              +---matmul---+           scale
                    |                    |
                    +--------div---------+         attn_mask
                              |                        |
                              +----------add-----------+         value
                                          |                        |
                                       Softmax                     |
                                          |                        |
                                          +---------matmul---------+
                                                       |
                                             (0,2,1,3) transpose   B,H,S,N -> B,S,H,N
                                                       |
                                                    Reshape        B,S,H,N -> B,S,W
                                                       |
                                                    output
```
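Equivalently, in NumPy terms (tensor layouts here are illustrative, and the softmax placement follows the standard scaled-dot-product-attention formula):
```python
# NumPy sketch of the ScaledDotProductAttention chain above (illustrative only).
import numpy as np

def sdpa(query, key_t, value, attn_mask, reshape_output_shape):
    # query: (B, heads, S, H), key_t: (B, heads, H, kv_S) i.e. already k_Transposed,
    # value: (B, heads, kv_S, H), attn_mask: broadcastable to (B, heads, S, kv_S).
    H = query.shape[-1]
    scores = query @ key_t / np.sqrt(H) + attn_mask    # matmul, div by scale h, add mask
    scores -= scores.max(axis=-1, keepdims=True)       # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    out = probs @ value                                 # (B, heads, S, H)
    out = out.transpose(0, 2, 1, 3)                     # -> (B, S, heads, H)
    return out.reshape(reshape_output_shape)            # -> (B, S, W)

B, heads, S, kv_S, H = 1, 8, 1, 8, 4
q = np.random.rand(B, heads, S, H).astype(np.float32)
k_t = np.random.rand(B, heads, H, kv_S).astype(np.float32)
v = np.random.rand(B, heads, kv_S, H).astype(np.float32)
mask = np.zeros((B, 1, S, kv_S), dtype=np.float32)
print(sdpa(q, k_t, v, mask, (B, S, heads * H)).shape)   # -> (1, 1, 32)
```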
The scatter_indices calculation:
```
                                                                      if_prefill (0/1 constant)
                                                                                  |
  scatter_indices_left_constant        scatter_indices_right_constant    0 ---> Where <--- Cast <--- seqlens_k
                |                                    |                            |
                |                                   Add <-------------------------- scatter_pos*
                |                                    |
                +------------------+-----------------+
                                   |
                            scatter_indices
```
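Conceptually, this produces one scatter position per new token: 0 for the prefill pass and seqlens_k for the decode pass, added to a precomputed index template. Below is a NumPy sketch under the assumption that the indices are [batch, kv_head, sequence] triples for ScatterND; the exact contents of the left/right constants are not spelled out above, so this layout is illustrative:
```python
# Illustrative NumPy sketch of the scatter_indices branch (assumed index layout:
# one [batch, kv_head, sequence] triple per updated cache row).
import numpy as np

def build_scatter_indices(if_prefill, seqlens_k, B, kv_N, S):
    # scatter_pos = Where(if_prefill, 0, Cast(seqlens_k)): per-batch start position.
    scatter_pos = np.where(if_prefill, 0, seqlens_k.astype(np.int64))        # (B,)

    # "Left constant": the batch and head components of every index.
    b_idx, n_idx, s_idx = np.meshgrid(
        np.arange(B), np.arange(kv_N), np.arange(S), indexing="ij")          # (B, kv_N, S)

    # "Right constant" + scatter_pos: the sequence component, shifted per batch.
    seq_component = s_idx + scatter_pos[:, None, None]                        # (B, kv_N, S)

    # Combine the components into (B, kv_N, S, 3) ScatterND-style indices.
    return np.stack([b_idx, n_idx, seq_component], axis=-1)

# Decode step: one new token per batch, written at position seqlens_k.
idx = build_scatter_indices(if_prefill=False, seqlens_k=np.array([5]), B=1, kv_N=2, S=1)
print(idx.shape, idx[0, :, 0])   # (1, 2, 1, 3)  [[0 0 5] [0 1 5]]
```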
The attention_bias calculation:
```
  ones_array (shape=B,N,S,P)                          range_of_qkv_sequence_length_constant (0,1,2,...) (shape=S)
       |                                                               |
    CumSum (axis=3, exclusive=true, reversed=false)                   Add <--- scatter_pos
       |                                                               |
       |                                                            Expand (shape=P,S)
       |                                                               |
       +--------------------------------> Lesser <----------------Transpose (1,0)
                                             |
                                    1 ---> Where <--- finfo_min (minimum value of FP32)
                                             |
                                      attention_bias
```
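In plain terms, this branch builds an additive causal mask from scatter_pos. A hedged NumPy sketch of the same computation follows; the comparison direction and offsets here follow standard causal-attention semantics and are my reading of the graph, not a copy of the EP's operators:
```python
# Illustrative NumPy sketch of the attention_bias branch: allowed positions keep a
# pass-through constant, all other positions get finfo_min so softmax zeroes them out.
import numpy as np

def build_attention_bias(B, N, S, P, scatter_pos):
    # CumSum(ones, axis=3, exclusive=true): KV position ids 0..P-1 along the last axis.
    kv_positions = np.cumsum(np.ones((B, N, S, P), dtype=np.float32), axis=3) - 1.0

    # range(S) + scatter_pos, laid out against the KV axis (the Expand/Transpose steps
    # in the graph above): the absolute position of each query token.
    query_positions = (np.arange(S, dtype=np.float32) + scatter_pos)[:, None]  # (S, 1)

    # Assumed causal rule: each query token attends to KV positions up to its own.
    allowed = kv_positions <= query_positions                                  # (B, N, S, P)

    # Where(allowed, 1, finfo_min). Adding the same constant to every unmasked logit
    # leaves the softmax unchanged, so the 1 here behaves like 0.
    finfo_min = np.finfo(np.float32).min
    return np.where(allowed, np.float32(1), finfo_min)

# Decode step: one query token at absolute position 5 in a cache of length P=8.
bias = build_attention_bias(B=1, N=8, S=1, P=8, scatter_pos=5)
print(bias[0, 0, 0])   # positions 0..5 pass, positions 6..7 get finfo_min
```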
*Note: Currently, only `past_sequence_length == total_sequence_length` is
supported for GQA.*
### Motivation and Context