onnxruntime
Add fusions for re-designed Phi-3 vision and Phi-3.5 vision ONNX models
#22026
Merged

Add fusions for re-designed Phi-3 vision and Phi-3.5 vision ONNX models #22026

kunal-vaishnavi
kunal-vaishnavi283 days ago

Description

This PR adds the optimizer logic to fuse the newly designed exported ONNX models for Phi-3 vision and Phi-3.5 vision.

Motivation and Context

After the re-designed export of Phi-3 vision and Phi-3.5 vision, the ONNX models for the vision component and embedding component contain If and Loop ops to handle multi-image support.

kunal-vaishnavi Support updating graph when subgraphs exist
91bfb84c
kunal-vaishnavi Add changes suggested by linter
31104aee
yufenglee yufenglee requested a review from tianleiwu tianleiwu 282 days ago
kunal-vaishnavi Increase fused Gelu op count
7bebab25
tianleiwu
tianleiwu commented on 2024-09-09
Conversation is marked as resolved
Show resolved
onnxruntime/python/tools/transformers/fusion_skiplayernorm.py
5656 # Root Mean Square Layer Normalization
5757 simplified = node.op_type == "SimplifiedLayerNormalization"
5858
59
if self.shape_infer_helper is not None:
tianleiwu282 days ago (edited 282 days ago)

The shape check is still required. I think currently we support limited broadcast in cuda kernel, so this logic shall check whether the broadcast is allowed:

  • input tensor shape: (batch_size, sequence_length, hidden_size)
  • skip tensor with shape: (batch_size, sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size)

Currently SkipLayerNorm only allows 3D inputs, while LayerNorm allows 4D inputs etc.
We might also need add a check of dimensions like

if self.shape_infer_helper.get_edge_shape(add.input[0]) and len(self.shape_infer_helper.get_edge_shape(add.input[0])) != 3:
    return
kunal-vaishnavi282 days ago

This PR says broadcasting support for the above shapes has been added for CPU and CUDA. I've added a workaround to keep the shape check.

tianleiwu
tianleiwu commented on 2024-09-09
Conversation is marked as resolved
Show resolved
onnxruntime/python/tools/transformers/onnx_model.py
934949
950 if len(self.graphs()) > 1:
951 # Get input names for all nodes in all subgraphs
952
subgraph_nodes = list(filter(lambda node: node.op_type in {"Loop", "Scan", "If"}, self.model.graph.node))
tianleiwu282 days ago

other ops like BeamSearch also has subgraph. May add some TODO comments if we do not handle them right now.

kunal-vaishnavi282 days ago

Added a TODO comment and kept the return early logic

tianleiwu
tianleiwu commented on 2024-09-09
Conversation is marked as resolved
Show resolved
onnxruntime/python/tools/transformers/onnx_model.py
7474 input_name_to_nodes[input_name].append(node)
7575 return input_name_to_nodes
7676
77
def input_name_to_nodes_for_main_graph(self):
tianleiwu282 days ago

May use interface like input_name_to_nodes(self, exclude_subgraphs=False)

kunal-vaishnavi282 days ago

Fixed

tianleiwu
tianleiwu commented on 2024-09-09
Conversation is marked as resolved
Show resolved
onnxruntime/python/tools/transformers/onnx_model.py
954 for parent_node in subgraph_nodes:
955 for attr in parent_node.attribute:
956 if attr.type == AttributeProto.GRAPH:
957
child_nodes = attr.g.node
tianleiwu282 days ago (edited 282 days ago)

May add a help function of subgraphs() or subgraph_nodes() to simplify the code.

kunal-vaishnavi282 days ago

Added

kunal-vaishnavi Add changes from PR feedback
85fed782
github-advanced-security
github-advanced-security commented on 2024-09-10
onnxruntime/python/tools/transformers/onnx_model.py
906908 if len(unused_nodes) > 0:
907909 logger.debug(f"Removed unused constant nodes: {len(unused_nodes)}")
908910
911
def get_subgraph_nodes_and_inputs(self, ops_with_graph_attrs={"Loop", "Scan", "If"}):
github-advanced-security282 days ago

RUFF/B006

Do not use mutable data structures for argument defaults.
See https://docs.astral.sh/ruff/rules/mutable-argument-default

Show more details

kunal-vaishnavi Add changes suggested by linter
b016f011
kunal-vaishnavi Fix CI test
6a375c35
tianleiwu
tianleiwu commented on 2024-09-10
Conversation is marked as resolved
Show resolved
onnxruntime/python/tools/transformers/onnx_model.py
10001034 remaining_input_names = []
10011035 for node in graph.node:
10021036 if node.op_type in ["Loop", "Scan", "If"]:
1003 # TODO: handle inner graph
1004 logger.debug(f"Skip update_graph since graph has operator: {node.op_type}")
1005 return
1037
# Add input names of nodes in subgraphs
tianleiwu282 days ago

nit: use self.get_subgraph_nodes_and_inputs here?

kunal-vaishnavi281 days ago

Added

kunal-vaishnavi Refactor some logic
e4952cb6
tianleiwu
tianleiwu commented on 2024-09-10
Conversation is marked as resolved
Show resolved
onnxruntime/python/tools/transformers/onnx_model.py
916 for attr in node.attribute:
917 if attr.type == AttributeProto.GRAPH:
918 child_nodes = attr.g.node
919
for child_node in child_nodes:
tianleiwu281 days ago

nit: Add a comment that we do handle subgraph of child nodes. This function only handle one level subgraphs.

tianleiwu
tianleiwu dismissed these changes on 2024-09-10
kunal-vaishnavi Add note on function capabilities
2cada0bc
kunal-vaishnavi kunal-vaishnavi dismissed their stale review via 2cada0bc 281 days ago
tianleiwu
tianleiwu dismissed these changes on 2024-09-10
tianleiwu
tianleiwu commented on 2024-09-10
Conversation is marked as resolved
Show resolved
onnxruntime/python/tools/transformers/onnx_model.py
972
973 # Check if node output is an input of a subgraph node and not an input to a node in the main graph
974 for output in node.output:
975
if output in subgraph_nodes_inputs and output not in input_name_to_nodes_for_main_graph:
tianleiwu281 days ago (edited 281 days ago)

The logic of keep_outputs has assumption that Loop/Scan/If nodes will not be pruned. It might not be True.
If we do want to handle it right now, please add some comments about the assumption or TODO.

tianleiwu
tianleiwu commented on 2024-09-10
Conversation is marked as resolved
Show resolved
onnxruntime/python/tools/transformers/onnx_model.py
906908 if len(unused_nodes) > 0:
907909 logger.debug(f"Removed unused constant nodes: {len(unused_nodes)}")
908910
911
def get_subgraph_inputs_of_node(self, node):
tianleiwu281 days ago (edited 281 days ago)

Can we change this to private function (add _ prefix)? I think it is not ready as public API.
For example, attr.graphs also contains subgraphs but it is not handled here.

tianleiwu
tianleiwu commented on 2024-09-10
Conversation is marked as resolved
Show resolved
onnxruntime/python/tools/transformers/onnx_model.py
921 subgraph_nodes_inputs.update(child_node.input)
922 return subgraph_nodes_inputs
923
924
def get_subgraph_nodes_and_inputs(self, ops_with_graph_attrs):
tianleiwu281 days ago

Please change this to private function as well.

tianleiwu tianleiwu dismissed their stale review 281 days ago
new comments
kunal-vaishnavi Make functions private and add TODO comment
68704df5
kunal-vaishnavi Fix TODO comment
97a6d443
tianleiwu
tianleiwu approved these changes on 2024-09-10
kunal-vaishnavi kunal-vaishnavi merged c5418f35 into main 281 days ago

Login to write a write a comment.

Login via GitHub

Assignees
No one assigned
Labels
Milestone