86b7aa26 - Fix FakeTensorProp on Module with Parameters or Buffers (#88700)

In `FakeTensorMode.__torch_dispatch__`, the output is not always computed by meta kernels:

```python
try:
    with in_kernel_invocation_manager(self):
        r = func(*args, **kwargs)  # <----- "r" can be a real tensor.
except NotImplementedError as not_implemented_error:
    # no meta kernel registered, fallback to kernel for the device
    if not self.allow_fallback_kernels:
        raise not_implemented_error
    return run_fallback_kernel(self, func, args, kwargs, not_implemented_error)
return self.wrap_meta_outputs_with_default_device_logic(r, func, args, kwargs)
```

For example, I observed that a real CPU tensor is generated when executing `aten.addmm` while running `FakeTensorProp`. Therefore, I'd like to allow `FakeTensorMode` to wrap a real tensor as a `FakeTensor` during the computation. Does this PR look like a good direction for fixing this problem? If yes, I can go ahead and add some tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88700
Approved by: https://github.com/eellison, https://github.com/ezyang