Rework requires_grad on DifferentiableGraphOp (#57575)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57575
This PR does two things:
1. Reverts "Manual revert of D27369251 (https://github.com/pytorch/pytorch/commit/f88a3fff65b35cb6d4968fc54a9a0a1314a9a3b7) (#56080)" in commit 92a09fb87a567100122b872613344d3a422abc9f.
2. Fixes DifferentiableGraph outputs carrying the wrong requires_grad flag.
The proper requires_grad flag for a DifferentiableGraph output is retrieved
from profiling information. Previously we read the profiling information only
from the first profile node among all uses of the output. However, when
control flow is present, the first use may sit in a code path that was never
executed and therefore never profiled, so we instead need to iteratively
search the uses for a profile node that actually has profiling information
available.
e.g.
```
graph(%0 : Tensor,
      %1 : Bool):
  ..., %2 : Tensor = prim::DifferentiableGraph_0(%0)
  %3 : Tensor = prim::If(%1)
    block0():
      %4 : Tensor = prim::DifferentiableGraph_1(%2)
      -> (%4)
    block1():
      %5 : Tensor = prim::DifferentiableGraph_2(%2)
      -> (%5)
  -> (%3)
with prim::DifferentiableGraph_0 = graph(%0 : Tensor):
  ...
  %out : Tensor = aten::operation(...)
  ...
  return (..., %out)
with prim::DifferentiableGraph_1 = graph(%0 : Tensor):
  %temp : Tensor = prim::profile[profiled_type=Tensor](%0)
  ...
with prim::DifferentiableGraph_2 = graph(%0 : Tensor):
  %temp : Tensor = prim::profile[profiled_type=Float(...)](%0)
  ...
```
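For context, a minimal sketch of the iterative lookup, assuming the torch::jit IR API (`Value::uses`, `prim::profile`, `attr::profiled_type`). The helper name `findRequiresGrad` is hypothetical, not the code landed in this PR, and it is simplified to direct uses; the real search also has to follow values into subgraph blocks like `DifferentiableGraph_1`/`_2` above:
```cpp
#include <torch/csrc/jit/ir/ir.h>

using namespace torch::jit;

// Hypothetical helper: scan *all* uses of a DifferentiableGraph output,
// not just the first one, until a prim::profile node that actually
// recorded a type is found.
c10::optional<bool> findRequiresGrad(Value* output) {
  for (const Use& use : output->uses()) {
    Node* user = use.user;
    if (user->kind() != prim::profile) {
      continue;
    }
    // A profile node in a code path that never executed carries no
    // profiled type (it prints as profiled_type=Tensor), so requires_grad
    // cannot be read from it; keep searching the remaining uses.
    if (!user->hasAttribute(attr::profiled_type)) {
      continue;
    }
    auto profiled = user->ty(attr::profiled_type)->cast<TensorType>();
    if (profiled && profiled->requiresGrad().has_value()) {
      return profiled->requiresGrad();
    }
  }
  return c10::nullopt; // no executed profile node observed this value
}
```
In the example above, only `prim::DifferentiableGraph_2` recorded a concrete `profiled_type=Float(...)`; `prim::DifferentiableGraph_1` sits in the branch that never ran, so stopping at the first profile node would leave requires_grad undetermined or wrong.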
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D29038773
Pulled By: Krovatkin
fbshipit-source-id: 6c0a851119f6b8f2f1afae5c74532407aae238fe