onnxruntime
340b188c - Fusing Initializers with Graph Transforms (#24726)

### Description

Added a graph transform for mixed-precision graphs when FP16 compute is unavailable. At session creation, this graph transform converts FP16 initializers (_which had been converted into FP16-to-FP32 Cast nodes_) to FP32 initializers and fuses them with their next FP32 nodes.

- Behavior before this change: "fp16 initializers -> cast_from_fp16_to_fp32 -> fp32 node/s"
- Behavior after this change: "fp16 initializers converted to fp32 initializers, then fused with fp32 node/s"

### Motivation and Context

This change aims to run FP16 models without repeatedly casting FP16 initializers to FP32 initializers, by fusing the resulting FP32 initializers with their next nodes when FP16 compute is not available.

> For naming purposes, the newly added graph transform is called "Fused Initializers Graph Transforms" in long form and "FIGT" in short form.

### Working

Currently, the Fuse Initializers Graph Transform fuses Cast nodes that cast from FP16 to FP32 back into their next/output nodes. Below is an explanation of how this transform works. It depends on `InsertCastTransforms` to produce the intermediate representation, from which it fuses the initializers (the Cast nodes with zero inputs, one initializer, and one output) back into the next/output node. After fusion, the link/edge between such a Cast node and the next/output node is removed, and the Cast node is removed as well.

```
"Input Graph"

      --------
     | X_Fp16 |
      --------
         |
         V
  ----------------
 | Conv_Fp16      |
 |   --W_Fp16--   |
 |   --B_Fp16--   |
  ----------------
         |
         V
      --------
     | Y_Fp16 |
      --------


"Intermediate Representation"

  --------    --------    --------
 | X_Fp16 |  | W_Fp16 |  | B_Fp16 |
  --------    --------    --------
     |           |           |
     V           V           V
  | Cast |    | Cast |    | Cast |
  | Fp16 |    | Fp16 |    | Fp16 |
  | To   |    | To   |    | To   |
  | Fp32 |    | Fp32 |    | Fp32 |
     |           |           |
     V           V           V
  -----------------------------------
 |                                   |
 |             Conv_Fp32             |
 |                                   |
  -----------------------------------
                   |
                   V
                | Cast |
                | Fp32 |
                | To   |
                | Fp16 |
                   |
                   V
                --------
               | Y_Fp16 |
                --------


"FIGT Transforms"

      --------
     | X_Fp16 |
      --------
         |
         V
      | Cast |
      | Fp16 |
      | To   |
      | Fp32 |
         |
         V
  ----------------
 | Conv_Fp32      |
 |   --W_Fp32--   |
 |   --B_Fp32--   |
  ----------------
         |
         V
      | Cast |
      | Fp32 |
      | To   |
      | Fp16 |
         |
         V
      --------
     | Y_Fp16 |
      --------
```

The newly added graph transform performs the following actions (a rough offline sketch follows the loop diagram below):

* Detect Cast node/s with a single FP16 initializer casting to FP32.
* Convert all such FP16 initializer/s to FP32 initializer/s.
* Fuse the newly created FP32 initializer/s into the respective FP32 node/s.
* Remove the FP16-to-FP32 Cast node/s.

This is run in a loop as follows. It excludes Level 1 and partitioning optimizations.

```
Level 2 --> Level 3 --> InsertCastTransforms --> FIGT
   ^                                              |
   |                     "LOOP"                   |
   |                                              |
   ------------------------------------------------
```
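To make the fusion concrete, here is a minimal offline sketch of an equivalent rewrite using the `onnx` Python helpers. It is not the C++ transform added in this PR; the function name `fuse_fp16_initializer_casts` and the `_fp32` naming are assumptions for illustration only.

```python
# Rough offline illustration of the FIGT rewrite, NOT the ORT C++ transform:
# find Cast(FP16 -> FP32) nodes fed by a single FP16 initializer, replace the
# initializer with an FP32 copy, rewire consumers, and drop the Cast node.
import numpy as np
from onnx import ModelProto, TensorProto, numpy_helper


def fuse_fp16_initializer_casts(model: ModelProto) -> ModelProto:
    graph = model.graph
    inits = {init.name: init for init in graph.initializer}
    cast_nodes = []   # Cast(FP16 -> FP32) nodes to drop
    rename = {}       # Cast output name -> new FP32 initializer name

    for node in graph.node:
        if (node.op_type == "Cast"
                and len(node.input) == 1
                and node.input[0] in inits
                and inits[node.input[0]].data_type == TensorProto.FLOAT16
                and any(a.name == "to" and a.i == TensorProto.FLOAT
                        for a in node.attribute)):
            # Convert the FP16 initializer to an FP32 initializer.
            fp16_init = inits[node.input[0]]
            fp32_array = numpy_helper.to_array(fp16_init).astype(np.float32)
            fp32_init = numpy_helper.from_array(fp32_array, fp16_init.name + "_fp32")
            graph.initializer.append(fp32_init)
            rename[node.output[0]] = fp32_init.name
            cast_nodes.append(node)

    # "Fuse": make every consumer of a Cast output read the FP32 initializer
    # directly, then remove the now-dangling Cast nodes. (The original FP16
    # initializers could also be pruned afterwards if nothing else uses them.)
    for node in graph.node:
        for i, name in enumerate(node.input):
            if name in rename:
                node.input[i] = rename[name]
    for node in cast_nodes:
        graph.node.remove(node)
    return model
```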
### Adding FIGT as a Level-4 Graph Transform

This has the following benefits:

1. Ability to turn off any or all of the Level 4 optimizations. We can use the `disable optimizers` functionality to turn off one of these optimizations during testing, or use the `-o` switch to turn off all Level 4 optimizations when executing a model from the command line or from Python (or any other) scripts.
2. The ability to rerun Level 2 and Level 3 optimizations remains intact after Level 4 optimizations are applied. Adding Level 4 ensures that FIGT (or any similar optimization) always runs after `InsertCastTransforms`.
3. It keeps the current graph manipulations untouched and gives us more flexibility to add future optimizations, such as an `Int8 to Int32` upconvert or an `FP8 to FP16` upconvert, under Level 4. Level 4 can, as of now, act as a placeholder for any other such upcoming graph optimizations.

```
Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
   ^                                                |
   |                      "LOOP"                    |
   |                                                |
   --------------------------------------------------
```

> A placeholder for Level 4 was added to the graph transforms utils under orttraining. This helps resolve any exceptions that may be encountered during training sessions.

#### Re-running Level 2+ optimizations after Level 4 / FIGT

The idea behind re-running Level 2+ graph transforms is that, after the fusion of initializers with their respective nodes, the nodes are now in a form that might be supported by graph transforms that were previously skipped. Hence, some transformations that previously could not be applied are now valid and can be applied to create a more optimal graph for execution.

### Added a new session option "kOrtSessionOptionsGraphOptimizationsLoopLevel" to handle the graph optimization loop

* When set to 2 or above, the loop runs until no more optimizations are applied at any level from Level 2 upward.

```
Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
   ^                                                |
   |                      "Loop"                    |
   |                                                |
   --------------------------------------------------
```

* When set to 1 (the default), the loop runs until no more Level 4 optimizations are applied.

```
Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
   ^                                                |
   |         "Loop only depending on Level 4"       |
   |                                                |
   --------------------------------------------------
```

* When set to 0, the loop is disabled.

```
Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
   ^                                                |
   |                     "No Loop"                  |
   |                                                |
   X  xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx  X
```

A short usage sketch of this option from Python is included at the end of this description.

### Documentation

We have not yet added details about Level 4 to the [Graph Optimizations in ONNX Runtime](https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html) documentation.

### OLD PR

This PR was created following a thorough discussion on the [OLD PR](https://github.com/microsoft/onnxruntime/pull/24175).

Signed-off-by: Sunny Shukla <sunny.shukla@intel.com>
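As referenced above, here is a minimal sketch of setting the new loop-level option from Python. The config-key string `"session.graph_optimizations_loop_level"` is an assumption (use the actual string value that `kOrtSessionOptionsGraphOptimizationsLoopLevel` is defined to in the ONNX Runtime session-options config-keys header), and `model_fp16.onnx` is a placeholder path.

```python
# Illustrative use of the new loop-level session option from Python.
# NOTE: the config-key string below is an assumption; check the value of
# kOrtSessionOptionsGraphOptimizationsLoopLevel in the ORT headers.
import onnxruntime as ort

so = ort.SessionOptions()

# 0 = no loop, 1 (default) = loop driven only by Level 4, 2+ = keep looping
# while any Level 2+ optimization still applies (see the diagrams above).
so.add_session_config_entry("session.graph_optimizations_loop_level", "2")

sess = ort.InferenceSession("model_fp16.onnx", so, providers=["CPUExecutionProvider"])
```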