pytorch
ece0002c - [ONNX] Disable autocast cache in exporter (#84219)

Commit
3 years ago
[ONNX] Disable autocast cache in exporter (#84219) This PR provides a temporary fix on #84092 in exporter to avoid more cases falling into this bug. A long-term fix will be provided later. A simple repro with torch.onnx.export is still under investigation, as torch.jit.trace() is not the API we call inside torch.onnx.export, and it may introduce the difference. Therefore, a test case is provided here only. A specific test one can use, ```python import torch import onnxruntime from onnxruntime.training.ortmodule import DebugOptions, LogLevel from onnxruntime.training.ortmodule import ORTModule class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.cv1 = torch.nn.Conv2d(3, 3, 5, 2, 1) def forward(self, x): x = self.cv1(x) return x x = torch.randn(10, 3, 20, 20) * 2 m = MyModule().eval() x = x.cuda() m = m.cuda() debug_options = DebugOptions(log_level=LogLevel.VERBOSE, save_onnx=True, onnx_prefix="ViT-B") m = ORTModule(m, debug_options=debug_options) with torch.cuda.amp.autocast(dtype=torch.float16, cache_enabled=True): loss = m(x) ``` AND make assertion fail in ORTModule https://github.com/microsoft/onnxruntime/blob/17ccd6fa02877a1c8d3201344137b1ca105b681d/orttraining/orttraining/python/training/ortmodule/_io.py#L578-L581 Without the fix, the user will see the weight/bias of Conv node becomes constant. Pull Request resolved: https://github.com/pytorch/pytorch/pull/84219 Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
Author
Committer
Parents
Loading