Update T5 Onnx Export and Optimization (#23949)
Previously, the encoder onnx model included extra initialization logic for the decoder to generate the KV cache from the prompt. That is not necessary. Here we redesign the onnx export for the T5 model to output two separate models, one for the encoder and one for the decoder.
The Linear projections that generate cross-attention features from encoder_hidden_states are moved into the encoder onnx model. This way, the encoder does not need to output encoder_hidden_states; it only needs to output the features used by cross attention in the decoder (see the sketch below).
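Conceptually, the change amounts to something like the following minimal sketch. The wrapper class and its name are illustrative, not the actual export code; the module and projection names (`decoder.block`, `layer[1].EncDecAttention`, `k`/`v`) follow Hugging Face's T5 implementation.

```python
import torch
from transformers import T5ForConditionalGeneration


class T5EncoderWithCrossKV(torch.nn.Module):
    """Hypothetical wrapper: run the encoder, then apply the decoder's
    cross-attention k/v projections so the exported model can emit
    present_key_cross_i / present_value_cross_i directly."""

    def __init__(self, model: T5ForConditionalGeneration):
        super().__init__()
        self.encoder = model.encoder
        # In Hugging Face's T5, block.layer[1] is the cross attention of each decoder block.
        self.cross_attentions = torch.nn.ModuleList(
            block.layer[1].EncDecAttention for block in model.decoder.block
        )
        self.num_heads = model.config.num_heads
        self.head_size = model.config.d_kv

    def forward(self, encoder_input_ids, encoder_attention_mask):
        hidden = self.encoder(
            input_ids=encoder_input_ids, attention_mask=encoder_attention_mask
        )[0]
        batch, seq_len, _ = hidden.shape
        outputs = []
        for attn in self.cross_attentions:
            # Project once in the encoder; the decoder consumes these as past_*_cross_i.
            k = attn.k(hidden).view(batch, seq_len, self.num_heads, self.head_size).transpose(1, 2)
            v = attn.v(hidden).view(batch, seq_len, self.num_heads, self.head_size).transpose(1, 2)
            outputs += [k, v]  # each: (B, num_heads, encode_sequence_length, head_size)
        return outputs
```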
Major changes:
- [x] update the t5 onnx export script (an export sketch follows this list)
- [x] update the convert_generation script
- [x] update beam search to support the changes in inputs and outputs (details can be found below)
- [x] add a tiny t5 model, and enable the generation test for T5 in Linux CI pipelines
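For reference, a hand-rolled export along these lines could look like the following. This is a sketch using the hypothetical wrapper above, not the actual export script (which also handles fp16 conversion, int32 graph inputs, and name generation for arbitrary layer counts); int64 dummy inputs are used here only for tracing simplicity.

```python
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()
encoder = T5EncoderWithCrossKV(model)  # wrapper from the sketch above

dummy_ids = torch.ones((1, 8), dtype=torch.int64)
dummy_mask = torch.ones((1, 8), dtype=torch.int64)

output_names = []
for i in range(model.config.num_layers):
    output_names += [f"present_key_cross_{i}", f"present_value_cross_{i}"]

torch.onnx.export(
    encoder,
    (dummy_ids, dummy_mask),
    "t5_encoder.onnx",
    input_names=["encoder_input_ids", "encoder_attention_mask"],
    output_names=output_names,
    dynamic_axes={
        "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"},
        "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
        **{name: {0: "batch_size", 2: "encode_sequence_length"} for name in output_names},
    },
    opset_version=17,
)
```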
Example changes in inputs and outputs for a one-layer model (runnable sketches for driving the new encoder and decoder follow each pair of lists):
**Encoder Inputs**:
- encoder_input_ids: int32 (B, encode_sequence_length)
- encoder_attention_mask: int32 (B, encode_sequence_length)
- ~~decoder_input_ids: int32 (B, 1)~~
**Encoder Outputs**:
- ~~logits: (B, 1, vocab_size)~~
- ~~encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)~~
- ~~present_key_self_0: (B, num_heads, 1, head_size)~~
- ~~present_value_self_0: (B, num_heads, 1, head_size)~~
- present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
- present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
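With the new format, feeding the encoder is straightforward. The file name and feed shapes below are illustrative, assuming an encoder exported with the int32 inputs and tensor names listed above:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("t5_encoder.onnx", providers=["CPUExecutionProvider"])
cross_kv = session.run(
    None,
    {
        "encoder_input_ids": np.ones((1, 8), dtype=np.int32),
        "encoder_attention_mask": np.ones((1, 8), dtype=np.int32),
    },
)
# One (key, value) pair per layer, each (B, num_heads, encode_sequence_length, head_size).
print([o.name for o in session.get_outputs()])  # present_key_cross_0, present_value_cross_0, ...
```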
**Decoder Inputs**:
- input_ids: int32 (B, 1)
- ~~encoder_input_ids: int32 (B, encode_sequence_length) (optional for old format; removed in new format)~~
- encoder_attention_mask: int32 (B, encode_sequence_length)
- ~~encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional for old format; removed in new format)~~
- past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
- past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
- past_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
- past_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
**Decoder Outputs**:
- logits: (B, 1, vocab_size)
- present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
- present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
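A single greedy decoding step against this contract might look like the sketch below. Shapes are for a hypothetical one-layer model, `cross_kv` is the encoder output list from the snippet above, and feeding zero-length self-attention past at step 0 is an assumption of this sketch, not something guaranteed here:

```python
import numpy as np
import onnxruntime as ort

decoder = ort.InferenceSession("t5_decoder.onnx", providers=["CPUExecutionProvider"])

num_layers, num_heads, head_size = 1, 8, 64  # illustrative sizes
feeds = {
    "input_ids": np.zeros((1, 1), dtype=np.int32),  # T5 uses pad token id 0 as decoder start
    "encoder_attention_mask": np.ones((1, 8), dtype=np.int32),
}
for i in range(num_layers):
    # Step 0: no self-attention history yet (past_decode_sequence_length == 0).
    feeds[f"past_key_self_{i}"] = np.zeros((1, num_heads, 0, head_size), dtype=np.float32)
    feeds[f"past_value_self_{i}"] = np.zeros((1, num_heads, 0, head_size), dtype=np.float32)
    # Cross-attention KV comes straight from the encoder outputs.
    feeds[f"past_key_cross_{i}"] = cross_kv[2 * i]
    feeds[f"past_value_cross_{i}"] = cross_kv[2 * i + 1]

logits, *present_self = decoder.run(None, feeds)
next_token = int(np.argmax(logits[0, -1]))  # greedy pick, for illustration only
```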
Known issues:
- Some postprocessing (like converting to use decoder masked multi-head attention, and past/present buffer sharing) is not done yet. It could be a future work item to integrate with onnxruntime-genai.
### Motivation and Context
Make the encoder onnx model simpler and more efficient for inference (no need to output encoder_hidden_states).