Memory optimization refactor and refinement (#17481)
### Memory optimization refactor and refinement
Currently the memory optimizer runs graph transformations and prints
recompute opportunities at INFO level, but the ORT backend emits a large
number of INFO-level logs, making it hard for users to find this
information. So we are adding a Python binding API to retrieve the memory
optimization opportunities instead of depending on the MemoryOptimizer's
default logging.
With this information we can then print ORTModule feature statistics.
Also, with such an API, we can run the analysis after an ORT session is
created (once the allocation plan is done), so buffer reuse is taken into
account. This avoids suggesting recomputation of subgraphs whose outputs
reuse other subgraphs' buffers.
Check
https://github.com/microsoft/onnxruntime/blob/pengwa/add_devinfo_level/docs/Memory_Optimizer.md
for the new flow using `MemoryOptimizer`.
This pull request makes the following refactorings:
1. Print the log in the ORTModule Python script, along with the ORTModule
feature-enablement stats. This is implemented by exposing an API
`get_serialized_ortmodule_memory_stat` to retrieve the memory
optimization opportunities.
2. Analyze memory optimization opportunities with ORT memory planning
taken into account. This is done by first creating the execution graph
without enabling MemoryOptimizer, then calling
`execution_agent.get_serialized_ortmodule_memory_stat`, which internally
consults the session's memory allocation planner when analyzing memory
optimization opportunities (see the sketch after this list). As a direct
result, the reported opportunities can include stashed activations that
are reusing other buffers.
3. Move recompute analysis logic from memory_optimizer.h/cc to
recompute_analysis.h/cc.
4. Abstract each optimization strategy into its own implementation. This
will make it easier to introduce new strategies (for example,
compression/decompression).
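
A minimal sketch of how points 1 and 2 fit together, assuming the exposed
method takes the user config string and probe level and returns a serialized
summary (the exact signature and return shape may differ from the real
binding):

```python
import logging

logger = logging.getLogger("orttraining.rank-0")


def log_memory_optimization_opportunities(execution_agent, user_config: str, probe_level: int) -> None:
    # The execution agent wraps the ORT session, so the allocation plan has
    # already been computed; the returned opportunities therefore reflect
    # buffer reuse and will not suggest recomputing subgraphs whose outputs
    # reuse other subgraphs' buffers.
    serialized_stat = execution_agent.get_serialized_ortmodule_memory_stat(
        user_config, probe_level
    )
    # Emit from the ORTModule Python side (rather than MemoryOptimizer's
    # backend logging) so the summary is not buried in ORT INFO logs.
    logger.warning(serialized_stat)
```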
New logging output (shown at INFO level; at WARNING level the per-plan
details below will NOT be shown):
```
2023-09-13 13:25:09,249 orttraining.rank-0 [WARNING] -
***** ONNX Runtime Training (ORTModule) is accelerating your model *****
ORTModule is enabled with following features ON/OFF for [training] mode:
ATen Executor : ON : Dispatch ATen operators to ORT's ATen executor
Cast Propagation : ON : Level 1 enabled
Custom Function : ON : Support custom torch.autograd.Function export and execution
Memory Optimizer : ON : RecomputeConfig: Reshape+Where+BiasSoftmax+:1:-1,Cast+:1:-1, ProbeLevel: 1, available configs:
Config Freq Saving(B) Saving Symbolic(Bytes)
- Plan 1 : ON : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : ON : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0)
- Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 5 : OFF : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
Compute Optimizer : ON : Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0
- FLOPReduction : ON : Reduce FLOPs by upstreaming shrinking-sized ops
Auto Fallback : ON : Fallback to PyTorch when encountering unsupported ops
TritonOp Enabled : OFF : ORT will switch to Triton for executing some ops to further accelerate training.
ZeRO Stage3 Support : OFF : Enable/Disable with env ORTMODULE_ENABLE_ZERO_STAGE3=1/0
Total ORT initialization overhead is 10.73s where export takes 8.39s.
Other overhead details: graph builder init takes 0.06s, runtime detection takes 0.01s, graph building takes 0.31s, session creation takes 1.96s
Versions: ONNX Runtime - 1.16.0+cu118, ONNX - 1.11.0
Note 1: use comma to enable multiple plans at the same time.
export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...
Note 2: saving is calculated based on the 1st batch symbolic dim values:
inputs_input_ids_dim0=1,
inputs_input_ids_dim1=1024,
inputs_attention_mask_dim0=1,
inputs_attention_mask_dim1=1024,
inputs_labels_dim0=1,
inputs_labels_dim1=1024,
************************************************************************
```
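
As Note 1 in the summary says, multiple plans can be enabled at once by
joining their config strings with commas. For example, using the Plan 1 and
Plan 2 strings from the table above (set the variable before constructing
ORTModule):

```python
import os

# Enable Plan 1 and Plan 2 from the table above at the same time.
os.environ["ORTMODULE_MEMORY_OPT_CONFIG"] = (
    "Reshape+Where+BiasSoftmax+:1:-1,Cast+:1:-1"
)
```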
If DEVINFO level is enabled, more details about the memory optimizations
are printed, as in the output that follows the sketch below.
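
A hedged example of raising the ORTModule log level: `DebugOptions` and
`LogLevel` are the existing ORTModule knobs, while the `DEVINFO` member is
assumed to be the new level added by this change (the exact name may differ):

```python
import torch
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel

pytorch_model = torch.nn.Linear(4, 4)  # any torch.nn.Module

# LogLevel.DEVINFO is assumed here; use the dev-info level exposed by this PR.
model = ORTModule(pytorch_model, DebugOptions(log_level=LogLevel.DEVINFO))
```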
```
MemoryInsight Summary - User config: BiasGelu+:1:-1,Cast+:2:-1
==========================================================================================================================================
|Freq | Memory Optimization Opportunities (Clustered by node-level activation patterns) |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|3 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+Add+Reshape+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1 |
| | Stashed Activations: |
| | - ReuseFreq : Output 0(3), |
| | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 32 x 240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+:1:-1 |
| | Stashed Activations: |
| | - ReuseFreq : Output 0(2), |
| | - Output 0 : [ x 2560 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Cast+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+Where+BiasSoftmax+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+BiasSoftmax+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph BiasGelu+ |
| | Status : Enabled, requested count=-1, actual applied count=2 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+Add+FusedMatMul+Add+Add+Add+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 2560 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+Where+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Cast+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved |
| | |
| |>>Option 2 : RecomputeWithCompromise subgraph Cast+ |
| | Status : Enabled, requested count=-1, actual applied count=1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 50% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph BiasSoftmax+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=BiasSoftmax+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph BiasGelu+ |
| | Status : Enabled, requested count=-1, actual applied count=1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Add+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Add+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 2560 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
==========================================================================================================================================
Note: use comma as a separator for enabling more than one subgraphs.
************************************************************************
```