9ad0de3c - Rework requires_grad on DifferentiableGraphOp (#57575)

Rework requires_grad on DifferentiableGraphOp (#57575)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57575

This PR does two things:
1. Reverts "Manual revert of D27369251 (https://github.com/pytorch/pytorch/commit/f88a3fff65b35cb6d4968fc54a9a0a1314a9a3b7) (#56080)" in commit 92a09fb87a567100122b872613344d3a422abc9f.
2. Fixes DifferentiableGraph outputs carrying the wrong requires_grad flag.

The proper requires_grad flag for an output of a DifferentiableGraph is retrieved from profiling information. Previously we only retrieved the profiling information from the first profile node among all of the output's uses. However, when control flow is present, we need to search iteratively for a profile node whose profiling information is available, in case the first use sits on an inactive code path. For example:

```
graph(%0 : Tensor, %1 : Bool):
  ..., %2 : Tensor = prim::DifferentiableGraph_0(%0)
  %3 : Tensor = prim::If(%1)
    block0():
      %4 : Tensor = prim::DifferentiableGraph_1(%2)
      -> (%4)
    block1():
      %5 : Tensor = prim::DifferentiableGraph_2(%2)
      -> (%5)
  -> (%3)
with prim::DifferentiableGraph_0 = graph(%0 : Tensor):
  ...
  %out : Tensor = aten::operation(...)
  ...
  return (..., %out)
with prim::DifferentiableGraph_1 = graph(%0 : Tensor):
  %temp : Tensor = prim::profile[profiled_type=Tensor](%0)
  ...
with prim::DifferentiableGraph_2 = graph(%0 : Tensor):
  %temp : Tensor = prim::profile[profiled_type=Float(...)](%0)
  ...
```

Here the profile node in prim::DifferentiableGraph_1 recorded only the unrefined Tensor type (its branch was not taken during profiling), while the one in prim::DifferentiableGraph_2 holds the refined Float(...) type that carries the requires_grad information we need.

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D29038773

Pulled By: Krovatkin

fbshipit-source-id: 6c0a851119f6b8f2f1afae5c74532407aae238fe
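The search described above can be sketched in Python. This is a hypothetical model, not the actual PyTorch C++ implementation: `ProfileUse`, `requires_grad_first_use_only`, and `requires_grad_search_all_uses` are illustrative names, and a missing `profiled_type` stands in for a profile node on a branch that was never executed during profiling.

```python
# Hypothetical sketch of the profile-node search described in the commit
# message. Each use of a DifferentiableGraph output may or may not carry
# refined profiling info; the old logic consulted only the first use,
# the fixed logic scans all uses until it finds one with info available.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ProfileUse:
    # profiled_type is None when this use's branch was never executed
    # during profiling (e.g. an inactive side of a prim::If), so only
    # the unrefined "Tensor" type was recorded there.
    profiled_type: Optional[str]
    requires_grad: Optional[bool]

def requires_grad_first_use_only(uses: List[ProfileUse]) -> Optional[bool]:
    """Old behavior: trust whatever the first profile node recorded."""
    return uses[0].requires_grad if uses else None

def requires_grad_search_all_uses(uses: List[ProfileUse]) -> Optional[bool]:
    """Fixed behavior: iterate until a use with profiling info is found."""
    for use in uses:
        if use.profiled_type is not None:
            return use.requires_grad
    return None  # no refined profiling info on any path

# Mirrors the example graph: DifferentiableGraph_1's profile node saw no
# execution (plain Tensor), while DifferentiableGraph_2 recorded the
# refined Float(...) type with requires_grad=True.
uses = [
    ProfileUse(profiled_type=None, requires_grad=None),         # inactive path
    ProfileUse(profiled_type="Float(...)", requires_grad=True), # active path
]

print(requires_grad_first_use_only(uses))   # None: flag is missing/wrong
print(requires_grad_search_all_uses(uses))  # True: flag recovered
```

Under this model, the old first-use lookup returns nothing useful when the first use happens to be on the inactive path, while scanning all uses recovers the flag from the path that actually executed.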