f88a3fff - Set requires_gradient to help autodiff to prune unneeded gradients (#54374)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/54040

`prim::RequiresGradCheck` guarantees that the requires_grad properties of the input tensors will match the profiled values; otherwise, a fallback path is triggered. This allows us to prune gradients in the backward graph for inputs that don't require them. We transfer the requires_grad properties from the inputs of the `prim::DifferentiableGraph` node onto the inputs of the differentiable subgraph; autodiff then inspects these properties and prunes the gradients that aren't required.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/54374

Reviewed By: H-Huang

Differential Revision: D27369251

Pulled By: Krovatkin

fbshipit-source-id: 2bce7a2d7f2ec091db9bf4c4b91d8b29edd5be11
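The sketch below (not part of the commit) illustrates the behavior this change targets: when a scripted function is run under the profiling graph executor, the recorded requires_grad properties of the inputs let autodiff prune gradient computation for inputs that don't need it. The function name `fn` and the warm-up loop count are illustrative; the pruning itself is an internal optimization of the differentiable graph's backward, so the user-visible semantics are unchanged.

```python
import torch

@torch.jit.script
def fn(x, y):
    return (x * y).sum()

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=False)  # no gradient needed for y

# Warm-up runs so the profiling executor can record the requires_grad
# properties of the inputs and specialize the differentiable graph.
for _ in range(3):
    out = fn(x, y)
    out.backward()
    x.grad = None

out = fn(x, y)
out.backward()
print(x.grad is not None)  # True: a gradient is still computed for x
print(y.grad)              # None: no gradient flows to y; internally, the
                           # backward graph can skip computing it entirely
```

If the requires_grad properties of the inputs later diverge from the profiled values, `prim::RequiresGradCheck` fails and execution takes the unoptimized fallback path, so correctness is preserved.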