Improve fake mode support by adding fake_context to ExportOutput (#105247)
Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the initial model state dict (including non persistent buffers) would not be reused by `ExportOutput.save` during ONNX proto creation.
That is not necessarily a bug because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. However, it can be a hassle because if the user doesn't provide a full state, including non-persistent buffers, the resulting ONNX graph would require the missing buffers to be specified as input during execution.
With this PR, the `enable_fake_mode` is improved to capture the initial model state including any non-persistent buffer. This reference (not actual data) is persisted within `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode`. The result is an ONNX graph with all model state without user having to specify the non-persistent buffers.
This helps addressing https://github.com/pytorch/pytorch/issues/105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured.
ps: https://github.com/pytorch/pytorch/issues/105464 tracks pending tasks/limitations from this PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105247
Approved by: https://github.com/BowenBao