263a15c5 - [tensorexpr] Add PYTORCH_TENSOREXPR_DONT_FUSE env variable to disable fusion on specified operators - fixed #50757 (#55650)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55650

Test Plan: Imported from OSS

$ python local/fusion.py
```
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %c.1 : Tensor):
  %33 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %34 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %35 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %36 : bool = prim::TypeCheck[types=[Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)]](%c.1, %a.1, %b.1)
  %37 : Tensor = prim::If(%36)
    block0():
      %18 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = prim::TensorExprGroup_0(%33, %34, %35)
      -> (%18)
    block1():
      %44 : Function = prim::Constant[name="fallback_function", fallback=1]()
      %45 : (Tensor) = prim::CallFunction(%44, %c.1, %a.1, %b.1)
      %46 : Tensor = prim::TupleUnpack(%45)
      -> (%46)
  return (%37)
with prim::TensorExprGroup_0 = graph(%c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %10 : int = prim::Constant[value=1]()
  %11 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %10) # local/fusion.py:13:15
  %9 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%a.1, %b.1) # local/fusion.py:13:19
  %6 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%9, %c.1) # local/fusion.py:13:19
  %3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%11, %6, %10) # local/fusion.py:13:15
  return (%3)
```

$ PYTORCH_TENSOREXPR_DONT_FUSE="add" python local/fusion.py
```
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %c.1 : Tensor):
  %3 : int = prim::Constant[value=1]()
  %6 : Tensor = aten::add(%a.1, %b.1, %3) # local/fusion.py:13:15
  %27 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %28 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %29 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), %30 : bool = prim::TypeCheck[types=[Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)]](%c.1, %a.1, %b.1)
  %31 : Tensor = prim::If(%30)
    block0():
      %18 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = prim::TensorExprGroup_0(%27, %28, %29)
      -> (%18)
    block1():
      %35 : Function = prim::Constant[name="fallback_function", fallback=1]()
      %36 : (Tensor) = prim::CallFunction(%35, %c.1, %a.1, %b.1)
      %37 : Tensor = prim::TupleUnpack(%36)
      -> (%37)
  %15 : Tensor = aten::add(%6, %31, %3) # local/fusion.py:13:15
  return (%15)
with prim::TensorExprGroup_0 = graph(%c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %5 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%a.1, %b.1) # local/fusion.py:13:19
  %2 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%5, %c.1) # local/fusion.py:13:19
  return (%2)
```

Reviewed By: navahgar

Differential Revision: D27667232

Pulled By: huiguoo

fbshipit-source-id: 002ddbb49760b42d52e0605ca3967f4fa36f4e3f
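
The test script local/fusion.py referenced in the test plan is not included in the commit. Judging from the dumped IR (an add/mul expression over three 4x4 float CPU tensors, compiled from line 13 of the script), a minimal reconstruction might look like the sketch below; the function name, tensor shapes, warm-up count, and exact expression are assumptions inferred from the graph dumps, not taken from the commit.

```python
# Hypothetical reconstruction of local/fusion.py (the script itself is not in the commit).
import torch

@torch.jit.script
def f(a, b, c):
    # Mirrors the dumped IR: aten::add(a, b) added to aten::mul(aten::mul(a, b), c).
    return a + b + a * b * c

a, b, c = (torch.rand(4, 4) for _ in range(3))

# Run a few times so the profiling executor specializes the graph and the
# TensorExpr fuser gets a chance to form prim::TensorExprGroup nodes.
for _ in range(10):
    f(a, b, c)

print(torch.jit.last_executed_optimized_graph())
```

Run as-is, this should produce a graph like the first dump, with the whole expression fused into one prim::TensorExprGroup; with PYTORCH_TENSOREXPR_DONT_FUSE="add" set, both aten::add calls stay outside the fusion group and only the two aten::mul calls are fused, as in the second dump.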