pytorch
25b42aef - [Inductor] Using PythonPrinter for SymInt arguments codegen for FallbackKernal (#100606)

Commit
2 years ago
[Inductor] Using PythonPrinter for SymInt arguments codegen for FallbackKernal (#100606) Fixes Meta internal user case. Repro: ``` import torch import torch._inductor torch._inductor.config.disable_cpp_codegen = True @torch.compile(backend="inductor", dynamic=True) def func(input: torch.Tensor) -> torch.Tensor: n = input.size(-1) output = input + int(n * 0.2) + 1 return output, input + 1 print(func(torch.rand(5, device="cpu"))) print(func(torch.rand(10, device="cpu"))) ``` Error: ``` Traceback (most recent call last): File "/scratch/ybliang/work/repos/debug/debug7.py", line 20, in <module> print(func(torch.rand(10, device="cpu"))) File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/eval_frame.py", line 280, in _fn return fn(*args, **kwargs) File "/scratch/ybliang/work/repos/debug/debug7.py", line 12, in func @torch.compile(backend="inductor", dynamic=True) File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/eval_frame.py", line 280, in _fn return fn(*args, **kwargs) File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/external_utils.py", line 17, in inner return fn(*args, **kwargs) File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 3346, in forward return compiled_fn(full_args) File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 1260, in g return f(*args) File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 2210, in runtime_wrapper all_outs = call_func_with_args( File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 1285, in call_func_with_args out = normalize_as_list(f(args)) File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 1372, in rng_functionalization_wrapper return compiled_fw(args) File "/tmp/torchinductor_ybliang/od/codk4bo4oqmjiec35zlz2rsildcix33lsxpdcy7pi6p4nvdrofpu.py", line 27, in call buf0 = torch.ops.aten.add.Tensor(arg1_1, floor(0.2*s0)) NameError: name 'floor' is not defined ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/100606 Approved by: https://github.com/xw285cornell
Author
Committer
Parents
Loading