Support torch.onnx.dynamo_export within FakeTensorMode (#105477)
Currently, exporting a model to ONNX with fake tensor mode requires the
user to load data and model within `torch.onnx.enable_fake_mode` context,
but the actual call to `torch.onnx.dynamo_export` is done outside such
context.
With this PR, we enable `torch.onnx.dynamo_export` to be called either
within `torch.onnx.enable_fake_mode` or outside of it. This feature
required changes to the core PyTorch Dynamo, which were greatly
supported by @ezyang
In future steps we will determine which scenario we are going to
support, but for now we can use either to explore different options and
scenarios and asses their pros and cons.
This PR also creates a separate suite of tests for fake mode specific
scenarios (`TestFxToOnnxFakeTensorWithOnnxRuntime`).
It was done separately to decrease the test time, but we
could merge it with the default `TestFxToOnnxWithOnnxRuntime`. The
additional parameters are `load_checkpoint_during_init` and
`export_within_fake_mode`
With the newly added supported of nested export within fake mode, the
following scenarios are now supported:
```python
import torch
with torch.onnx.enable_fake_mode() as fake_context:
fake_args = create_args()
fake_kwargs = create_kwargs()
fake_model = create_model()
fake_model.load_state_dict(torch.load(tmp_checkpoint_file.name))
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
# `torch.onnx.dynamo_export` called WITHIN `torch.onnx.enable_fake_mode`
export_output = torch.onnx.dynamo_export(
fake_model,
*fake_args,
**fake_kwargs,
export_options=export_options,
)
export_output.save("/path/to/model.onnx", model_state_dict=create_model())
```
If we decide to only support scenarios in which `torch._dynamo.export` is called within `FakeTensorMode`, then we can remove `fake_mode` argument from `torch._dynamo.export` as a follow-up task
ps: This PR is mostly Edward's https://github.com/pytorch/pytorch/pull/105468 + unit tests after an offline discussion
ps: https://github.com/pytorch/pytorch/issues/105464 tracks pending tasks/limitations from this PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105477
Approved by: https://github.com/ezyang, https://github.com/BowenBao