[onnxrt, dynamo] Enable custom ONNX model transforms in `onnxrt` dynamo backend (#120854)
A global transform list is created. All backend instances call the transform functions in that list sequentially to modify the exported ONNX model before sending the model to the ORT session. For example, `record_onnx_model_transform` below is a no-op transform that only records the ONNX graphs sent to ONNX Runtime.
```python
import torch

recorded_models = []

def record_onnx_model_transform(onnx_model):
    # Record the ONNX model seen by the transform.
    recorded_models.append(onnx_model)

from torch.onnx import (
    register_backend_graph_transform,
    unregister_backend_graph_transform,
)

# Register so that the `onnxrt` backend calls it to modify the ONNX model.
register_backend_graph_transform(record_onnx_model_transform)

def example_model(x: torch.Tensor):
    y = torch.sigmoid(x)
    z = x + y
    return z

# During compilation, the exported ONNX model is modified by calling
# `record_onnx_model_transform` before the model is sent to
# `onnxruntime.InferenceSession`.
compiled_model = torch.compile(
    example_model,
    backend="onnxrt",
    dynamic=True,
)

# Compilation is lazy; the first call triggers export and the transform.
compiled_model(torch.rand(4))

# Now, `recorded_models` should contain one `onnx.ModelProto` representing
# `example_model(x: torch.Tensor)`.

# Remove the transform when it is no longer needed. If
# `record_onnx_model_transform` is not unregistered, it will be applied to
# all models compiled with `backend="onnxrt"`.
unregister_backend_graph_transform(record_onnx_model_transform)
```
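Transforms are not limited to observation; a registered callable may also mutate the `onnx.ModelProto` in place before it reaches the ORT session. Below is a minimal sketch of such a mutating transform, assuming the callable receives a `ModelProto`-like object as in the example above. The transform strips per-node `doc_string` fields to shrink the serialized model; the `SimpleNamespace` stand-ins are hypothetical and only mimic the protobuf fields touched, so the sketch runs without `onnx` installed.

```python
from types import SimpleNamespace

def strip_doc_strings_transform(onnx_model):
    # Clear debug doc strings from every node in the graph, in place.
    # The `onnxrt` backend would call this with the exported ModelProto.
    for node in onnx_model.graph.node:
        node.doc_string = ""

# Hypothetical stand-ins mimicking the ModelProto fields used above.
node = SimpleNamespace(doc_string="Sigmoid produced by export")
model = SimpleNamespace(graph=SimpleNamespace(node=[node]))

strip_doc_strings_transform(model)
# After the call, every node's doc_string is empty.
```

Such a transform would be registered and unregistered exactly like `record_onnx_model_transform` above.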
In the future, we plan to use this mechanism to register all graph transforms, such as graph fusion and general ONNX optimizations, for `onnxrt`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120854
Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi