pytorch
e5664c65 - [ONNX] Support aten::scaled_dot_product_attention in torchscript exporter (#99658)

Commit
1 year ago
[ONNX] Support aten::scaled_dot_product_attention in torchscript exporter (#99658) Fixes #97262 <!-- copilot:all --> ### <samp>🤖 Generated by Copilot at d06d195</samp> ### Summary 🆕🚀📝 <!-- 1. 🆕 for adding tests and annotations for a new operator. 2. 🚀 for adding support for exporting a new operator to ONNX. 3. 📝 for fixing a minor formatting issue. --> This pull request adds ONNX opset 14 support for the `nn.functional.scaled_dot_product_attention` operator, which is used for self-attention in transformer models. It does so by adding tests and annotations in `test/onnx/test_op_consistency.py`, and by adding a symbolic function in `torch/onnx/symbolic_opset14.py` that reuses an existing implementation. > _To export `scaled_dot_product_attention`_ > _To ONNX opset 14, we need some extension_ > _We import some modules and types_ > _And add a symbolic that pipes_ > _The existing code with some annotation_ ### Walkthrough * Implement the `nn.functional.scaled_dot_product_attention` operator for ONNX opset 14 ([link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-244955d820ec138d5ddffb20ee6f517cc4c5d281f19ccb53d8db47043b5ac46fR122-R292)) * Add imports for modules and types needed for the operator implementation ([link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-244955d820ec138d5ddffb20ee6f517cc4c5d281f19ccb53d8db47043b5ac46fL17-R23)) * Add a command to run the pytest module for testing the operator consistency ([link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753R13)) * Add the operator to the list of operators tested for consistency ([link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753R311)) * Add annotations to indicate the operator's limitations and issues ([link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L333-R339), [link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753R354-R358)) * Remove an empty line at the end of `test/onnx/test_op_consistency.py` ([link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L441)) Pull Request resolved: https://github.com/pytorch/pytorch/pull/99658 Approved by: https://github.com/justinchuby
Author
Committer
Parents
Loading