fix(fx): make all `make_fx` invocations isolated (opaque to higher `make_fx` invocations) by default (#93290)
Fixes https://github.com/pytorch/pytorch/issues/88996#issuecomment-1409174554
Example code:
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx

@torch.fx.wrap
def func(a, b):
    return b.expand([1, a.shape[0], b.shape[-1]])

a = torch.randn(3, 4)
b = torch.randn(4)

class TestMode(torch.overrides.TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if torch.overrides.resolve_name(func) in ["torch.Tensor.expand"]:
            print(f"TestMode: {func} {args} {kwargs}")
            wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs)
            gm = make_fx(wrapped, tracing_mode="real")(all_args)
        return func(*args, **kwargs)

with TestMode():
    gm = make_fx(func, tracing_mode="symbolic")(a, b)
    gm.graph.print_tabular()
```
Before:
```
opcode         name        target               args                              kwargs
-------------  ----------  -------------------  --------------------------------  --------
placeholder    a_1         a_1                  ()                                {}
placeholder    b_1         b_1                  ()                                {}
call_function  detach      aten.detach.default  (b_1,)                            {}
call_function  detach_1    aten.detach.default  (detach,)                         {}
call_function  sym_size    aten.sym_size        (a_1, 0)                          {}
call_function  sym_size_1  aten.sym_size        (b_1, 0)                          {}
call_function  expand      aten.expand.default  (b_1, [1, sym_size, sym_size_1])  {}
call_function  detach_2    aten.detach.default  (expand,)                         {}
call_function  expand_1    aten.expand.default  (b_1, [1, sym_size, sym_size_1])  {}
output         output      output               (expand_1,)                       {}
```
After:
```
opcode         name        target               args                              kwargs
-------------  ----------  -------------------  --------------------------------  --------
placeholder    a_1         a_1                  ()                                {}
placeholder    b_1         b_1                  ()                                {}
call_function  sym_size    aten.sym_size        (a_1, 0)                          {}
call_function  sym_size_1  aten.sym_size        (b_1, 0)                          {}
call_function  expand      aten.expand.default  (b_1, [1, sym_size, sym_size_1])  {}
output         output      output               (expand,)                         {}
```
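With this change, a nested `make_fx` trace no longer records its operations into an enclosing `make_fx` trace. A minimal sketch of that isolation (the `inner`/`outer` functions here are illustrative, not part of this PR):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def inner(x):
    return x * 2

def outer(x):
    # Nested make_fx on a fresh real tensor; with this fix it is opaque
    # to the surrounding trace, so no `mul` node leaks into the outer graph.
    make_fx(inner, tracing_mode="real")(torch.randn(3))
    return x + 1

gm = make_fx(outer, tracing_mode="real")(torch.randn(3))
targets = [str(n.target) for n in gm.graph.nodes]
assert any("add" in t for t in targets)       # outer op is captured
assert not any("mul" in t for t in targets)   # inner trace stayed isolated
```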
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93290
Approved by: https://github.com/ezyang