[tensorexpr] Add PYTORCH_TENSOREXPR_DONT_FUSE env variable to disable fusion on specified operators - fixed #50757 (#55650)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55650
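Adds a `PYTORCH_TENSOREXPR_DONT_FUSE` environment variable naming operators that the TensorExpr fuser should skip: nodes for the listed operators stay in the outer graph instead of being pulled into a `prim::TensorExprGroup` (see the before/after graph dumps in the test plan below).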
Test Plan:
Imported from OSS
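The test plan exercises a small script, `local/fusion.py`, which is not part of this commit. A minimal sketch of what it might look like, reconstructed from the `aten` ops and tensor shapes in the dumps below (the function body, argument names, and warm-up loop are assumptions):
```python
import torch

# TensorExpr fusion on CPU is off by default in OSS builds; this testing hook
# enables it so the fuser pass actually runs.
torch._C._jit_override_can_fuse_on_cpu(True)

@torch.jit.script
def f(a, b, c):
    # Reconstructed from the dumps: two pointwise adds and two pointwise muls.
    return a + b + a * b * c

a = torch.rand(4, 4)
b = torch.rand(4, 4)
c = torch.rand(4, 4)

# Warm up the profiling executor so the profiling and fusion passes run,
# then print the optimized graph that was last executed.
for _ in range(10):
    f(a, b, c)
print(torch.jit.last_executed_optimized_graph())
```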
$ python local/fusion.py
```
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %c.1 : Tensor):
  %33 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %34 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %35 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %36 : bool = prim::TypeCheck[types=[Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)]](%c.1, %a.1, %b.1)
  %37 : Tensor = prim::If(%36)
    block0():
      %18 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = prim::TensorExprGroup_0(%33, %34, %35)
      -> (%18)
    block1():
      %44 : Function = prim::Constant[name="fallback_function", fallback=1]()
      %45 : (Tensor) = prim::CallFunction(%44, %c.1, %a.1, %b.1)
      %46 : Tensor = prim::TupleUnpack(%45)
      -> (%46)
  return (%37)
with prim::TensorExprGroup_0 = graph(%c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %10 : int = prim::Constant[value=1]()
  %11 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %10) # local/fusion.py:13:15
  %9 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%a.1, %b.1) # local/fusion.py:13:19
  %6 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%9, %c.1) # local/fusion.py:13:19
  %3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%11, %6, %10) # local/fusion.py:13:15
  return (%3)
```
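Without the env variable, all four pointwise ops (both aten::adds and both aten::muls) are pulled into a single prim::TensorExprGroup, guarded by a prim::TypeCheck with an unfused fallback path.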
$ PYTORCH_TENSOREXPR_DONT_FUSE="add" python local/fusion.py
```
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %c.1 : Tensor):
  %3 : int = prim::Constant[value=1]()
  %6 : Tensor = aten::add(%a.1, %b.1, %3) # local/fusion.py:13:15
  %27 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %28 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %29 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %30 : bool = prim::TypeCheck[types=[Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)]](%c.1, %a.1, %b.1)
  %31 : Tensor = prim::If(%30)
    block0():
      %18 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = prim::TensorExprGroup_0(%27, %28, %29)
      -> (%18)
    block1():
      %35 : Function = prim::Constant[name="fallback_function", fallback=1]()
      %36 : (Tensor) = prim::CallFunction(%35, %c.1, %a.1, %b.1)
      %37 : Tensor = prim::TupleUnpack(%36)
      -> (%37)
  %15 : Tensor = aten::add(%6, %31, %3) # local/fusion.py:13:15
  return (%15)
with prim::TensorExprGroup_0 = graph(%c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %5 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%a.1, %b.1) # local/fusion.py:13:19
  %2 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%5, %c.1) # local/fusion.py:13:19
  return (%2)
```
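With "add" on the deny list, both aten::adds now stay in the outer graph as unfused nodes, and only the two aten::muls form the prim::TensorExprGroup.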
Reviewed By: navahgar
Differential Revision: D27667232
Pulled By: huiguoo
fbshipit-source-id: 002ddbb49760b42d52e0605ca3967f4fa36f4e3f