onnxruntime
654c335c - Fix 3d attention mask broadcasting in MHA (#27464)

Commit
93 days ago
Fix 3d attention mask broadcasting in MHA (#27464) (1) Fix 3d attention mask broadcasting in MHA (2) Refactor attention python tests of LLM (add MHA) --- This pull request includes a minor fix to the attention mask broadcasting logic in the CUDA attention kernel, as well as the addition of a missing license header in a test file. Improvements to attention mask broadcasting logic: * Updated the logic in `attention.cc` to clarify and correct how broadcasting is determined for 3D attention masks, ensuring the batch dimension always broadcasts and the heads dimension broadcasts only if its size is 1. This improves correctness and clarity for different mask shapes. Documentation and compliance: * Added the standard Microsoft MIT license header to the `test_onnx_attention/__init__.py` file to ensure proper licensing information is included. --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Author
Parents
Loading