onnxruntime
7942fa7a - Whisper Redesigned Solution (#23549)

Commit
283 days ago
Whisper Redesigned Solution (#23549) ### Description This PR re-designs how Whisper is created and supported in ONNX Runtime. The new solution leverages [previous optimization work](https://github.com/microsoft/onnxruntime/pull/15473), and it is designed to be used in conjunction with [this work](https://github.com/microsoft/onnxruntime-genai/pull/1229) in ONNX Runtime GenAI. Some of the added changes include: - Re-designed export that creates new ONNX models without needing a `WhisperBeamSearch` op - Creates one encoder model that also pre-computes the cross-attention KV caches (since they only need to be calculated once) - Creates one decoder model that can be used during pre-fill and token generation - Creates one jump-times model that can be used for word-level timestamps - Removes need for a `WhisperBeamSearch` op to chain the encoder and decoder subgraphs - Removes need to duplicate decoder's weights in memory - Previous solution with the `WhisperBeamSearch` op created an encoder-decoder-init model and decoder-with-past model. The decoder was duplicated twice, one in each. - Removes need for separate logic to export the PyTorch model coming from OpenAI vs. the PyTorch model coming from Hugging Face - Re-factors common parameters and logic used in CPU and CUDA attention kernels - Adds `DUMP_STRING` to enable easy logging of intermediate information when running in debug mode to debug a problem. This info is not printed in release mode so it will not impact performance. - Integrates `DecoderMaskedMultiHeadAttention` into `MultiHeadAttention` - Enables past-present buffer sharing in the `MultiHeadAttention` op for improved performance - Adds `cache_indirection` and `past_sequence_length` as new optional inputs to `MultiHeadAttention` - Adds `output_qk` as new optional output to `MultiHeadAttention` - Enables calculating `output_qk` tensor with FP16 or FP32 precision, regardless of the model's precision - CI tests that run end-to-end across various flag combinations that are used by many customers internally and externally The existing solutions are still available if desired. ### Known Issues - The FP32 CPU model with the `WhisperBeamSearch` op and output QK is currently disabled. This is because ONNX Runtime doesn't currently support output QK kernels on CPU, only on CUDA. - The `DecoderMaskedMultiHeadAttention` CPU kernel has a parity mismatch with the `DecoderMaskedMultiHeadAttention` CUDA kernel. - Using `DecoderMaskedMultiHeadAttention` for the FP32 CPU model is not enabled. Currently, it uses `MultiHeadAttention` to avoid the parity mismatch issue. ### Motivation and Context Using the beam search op has made it more difficult to debug and fix errors that are encountered. This new approach is more flexible and more customizable for users (e.g. by running with ONNX Runtime GenAI). It also helps [this issue](https://github.com/microsoft/onnxruntime/issues/18216). --------- Co-authored-by: mindest <linminuser@gmail.com>
Parents
Loading