Moves clamp from autodiff cpp to symbolic script (#23927)
Summary:
This PR:
- Moves clamp from autodiff cpp to symbolic script
- Adds an additional tuple lowering pass to the graph executor
- Updates clamp backwards to be maximally gradient preserving
Moving clamp to symbolic script presented two challenges:
- When the backward graph is defined, the branch taken in the conditional is already known, but communicating this information to the JIT is a little tricky. The JIT has a quirk where Optional values that are None at the time of graph instantiation are treated as constants, so testing min and max against None lets the JIT instantiate only one branch. It might be more natural to select different backward functions for these cases, but that is not yet supported (see the sketch after this list).
- Moving clamp to symbolic script introduced an extra tuple construction and immediate unpacking that prevented fusion. This was addressed by adding an additional tuple removal pass. The issue can appear whenever a symbolic script's return value is defined in an if statement, because the JIT then sees the unpacked tuple as being constructed from a prim::If, not a prim::TupleConstruct. The graph is optimized later, but tuple lowering was not performed again after those optimizations.
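For illustration, here is a rough eager-mode Python sketch of the shape such a formula takes. The name `clamp_formula` and the exact mask arithmetic are illustrative assumptions, not the precise definition this PR adds to the symbolic script; the sketch only mirrors the (output, backward closure) pattern and the None tests described above.
```python
from typing import Optional
import torch

def clamp_formula(self: torch.Tensor,
                  min: Optional[float] = None,
                  max: Optional[float] = None):
    # As with torch.clamp itself, at least one bound is expected.
    def backward(grad_output: torch.Tensor):
        if min is not None and max is not None:
            mask = ((self >= min) & (self <= max)).type_as(self)
        elif min is not None:
            mask = (self >= min).type_as(self)
        elif max is not None:
            mask = (self <= max).type_as(self)
        else:
            mask = torch.ones_like(self)
        # Gradients for (self, min, max); the scalar bounds get None.
        return grad_output * mask, None, None

    return torch.clamp(self, min, max), backward
```
Because the None-ness of min and max is fixed when the differentiated graph is instantiated, only one of these branches survives, which is why the fused backward graph below contains just a single aten::le / type_as / mul chain.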
Moving clamp to symbolic script also adds some explicit conversions to float in the graphs in which clamp appears, but these seem harmless.
If clamp were simply moved to symbolic script then its backward graphs would look like this:
`graph(%0 : Float(*, *),
      %1 : AutogradZeroTensor,
      %2 : Float(*, *),
      %3 : int[]?,
      %4 : Scalar?,
      %5 : int):
  %6 : None = prim::Constant() # <string>:5:31
  %7 : float = aten::Float(%5) # <string>:12:37
  %8 : Float(*, *) = prim::FusionGroup_0(%0, %2, %7)
  %9 : (Float(*, *), None, None) = prim::TupleConstruct(%8, %6, %6)
  %10 : Float(*, *), %11 : None, %12 : None = prim::TupleUnpack(%9)
  return (%10)
with prim::FusionGroup_0 = graph(%0 : Float(*, *),
      %1 : Float(*, *),
      %2 : float):
  %3 : Bool(*, *) = aten::le(%1, %2) # <string>:12:29
  %mask.5 : Float(*, *) = aten::type_as(%3, %1) # <string>:12:29
  %5 : Float(*, *) = aten::mul(%0, %mask.5) # <string>:13:28
  return (%5)`
Adding the additional pass to remove tuples eliminates the prim::TupleConstruct and prim::TupleUnpack. Previously, keeping them would cause test_fuser_iou to fail because multiple fusion groups would be created. That test is currently disabled (see https://github.com/pytorch/pytorch/issues/23372), but when enabled the relevant portion of its graph is now:
`%59 : float = aten::Float(%26) # <string>:314:38
%60 : float = aten::Float(%27) # <string>:314:61
%61 : int[] = aten::size(%14) # <string>:41:99
%62 : int[] = aten::size(%11) # <string>:42:100
%63 : int[] = aten::size(%15) # <string>:41:99
%64 : int[] = aten::size(%12) # <string>:42:100
%65 : Tensor, %66 : Tensor, %67 : Tensor, %68 : Tensor, %69 : Tensor, %70 : Tensor, %71 : Tensor, %72 : Tensor, %73 : Double(*, *) = prim::FusionGroup_0(%w.1, %13, %16, %23, %h.1, %54, %inter.1, %0, %12, %15, %18, %17, %29, %11, %14, %60, %59)
%74 : Tensor = aten::_grad_sum_to_size(%73, %53)
%75 : Tensor = aten::_grad_sum_to_size(%73, %52)
%grad_self.10 : Tensor = aten::_grad_sum_to_size(%65, %61) # <string>:41:30
%grad_other.10 : Tensor = aten::_grad_sum_to_size(%66, %62) # <string>:42:31
%78 : Tensor = prim::FusionGroup_1(%grad_self.10, %74, %36)
%79 : Tensor = prim::FusionGroup_2(%grad_other.10, %75, %44)
%grad_self.14 : Tensor = aten::_grad_sum_to_size(%67, %21) # <string>:33:30
%grad_other.14 : Tensor = aten::_grad_sum_to_size(%68, %22) # <string>:34:31
%grad_self.12 : Tensor = aten::_grad_sum_to_size(%69, %63) # <string>:41:30
%grad_other.12 : Tensor = aten::_grad_sum_to_size(%70, %64) # <string>:42:31
%grad_self.16 : Tensor = aten::_grad_sum_to_size(%71, %19) # <string>:33:30
%grad_other.16 : Tensor = aten::_grad_sum_to_size(%72, %20) # <string>:34:31
%86 : Tensor, %87 : Tensor = prim::FusionGroup_3(%grad_self.12, %grad_self.16, %74, %39)
%88 : Tensor, %89 : Tensor = prim::FusionGroup_4(%grad_other.12, %grad_other.16, %75, %47)
return (%79, %88, %89, %78, %86, %87, %grad_self.14, %grad_other.14)`
This is, I think, the expected/desired result.
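As a side note, tuple lowering can also be exercised from Python via the existing `torch._C._jit_pass_lower_all_tuples` binding. The snippet below is only a rough approximation of the effect: the pass this PR adds runs inside the graph executor after optimization, not on a user-visible graph like this.
```python
import torch

@torch.jit.script
def unpack_pair(x: torch.Tensor):
    t = (x * 2, x + 1)  # compiled to a prim::TupleConstruct
    a, b = t            # compiled to a prim::TupleUnpack
    return a + b

g = unpack_pair.graph
print(g)  # depending on which cleanup passes already ran, the pair may still be visible
torch._C._jit_pass_lower_all_tuples(g)
print(g)  # the TupleConstruct/TupleUnpack pair is gone
```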
Finally, this implementation of clamp backward is "maximally gradient preserving," which simply means that elements on the boundary now receive gradients. For example, if an element of a tensor is 5 and it is clamped to [2, 5], that element now receives a gradient. The prior implementation zeroed these gradients. See https://github.com/pytorch/pytorch/issues/7002 for a discussion of gradient preservation.
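As a concrete illustration of the boundary behavior (the tensor values and names here are just an example, not taken from a test):
```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 5.0, 6.0])
lo, hi = 2.0, 5.0
grad_output = torch.ones_like(x)

# Boundary-inclusive mask: elements equal to the clamp bounds (2 and 5)
# still receive gradient.
mask_inclusive = ((x >= lo) & (x <= hi)).type_as(x)  # [0., 1., 1., 1., 0.]

# A boundary-exclusive mask, by contrast, would zero those gradients.
mask_exclusive = ((x > lo) & (x < hi)).type_as(x)    # [0., 0., 1., 0., 0.]

grad_input = grad_output * mask_inclusive
```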
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23927
Test Plan: Existing tests provided sufficient coverage.
Differential Revision: D16739740
Pulled By: mruberry
fbshipit-source-id: c94291d20e1f3f25197afc7b74dc61aeb204b074