pytorch
913ac271 - Fixes forward AD codegen for multiple formulas (#68535)

Fixes forward AD codegen for multiple formulas (#68535)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/67367

- Adds a check to make sure a forward grad does not itself have a forward grad at the same level
- Verify with `python test/test_ops.py -k test_forward_mode_AD_linalg_eigh_cpu_float64` that the op fails the check before the codegen update but passes after

Before:
```
if (_any_has_forward_grad_eigenvalues) {
  auto self_t_raw = toNonOptFwGrad(self);
  auto self_t = self_t_raw.defined() ? self_t_raw : at::zeros_like(toNonOptTensor(self));
  auto eigenvalues_new_fw_grad = eigh_jvp_eigenvalues(self_t, eigenvalues, eigenvectors);
  if (eigenvalues_new_fw_grad.defined()) {
    // The hardcoded 0 here will need to be updated once we support multiple levels.
    eigenvalues._set_fw_grad(eigenvalues_new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
  }
}
if (_any_has_forward_grad_eigenvectors) {
  auto self_t_raw = toNonOptFwGrad(self);
  auto self_t = self_t_raw.defined() ? self_t_raw : at::zeros_like(toNonOptTensor(self));
  auto eigenvectors_new_fw_grad = eigh_jvp_eigenvectors(self_t, eigenvalues, eigenvectors);
  if (eigenvectors_new_fw_grad.defined()) {
    // The hardcoded 0 here will need to be updated once we support multiple levels.
    eigenvectors._set_fw_grad(eigenvectors_new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
  }
}
```

After:
```
c10::optional<at::Tensor> eigenvalues_new_fw_grad_opt = c10::nullopt;
if (_any_has_forward_grad_eigenvalues) {
  auto self_t_raw = toNonOptFwGrad(self);
  auto self_t = self_t_raw.defined() ? self_t_raw : at::zeros_like(toNonOptTensor(self));
  eigenvalues_new_fw_grad_opt = eigh_jvp_eigenvalues(self_t, eigenvalues, eigenvectors);
}
c10::optional<at::Tensor> eigenvectors_new_fw_grad_opt = c10::nullopt;
if (_any_has_forward_grad_eigenvectors) {
  auto self_t_raw = toNonOptFwGrad(self);
  auto self_t = self_t_raw.defined() ? self_t_raw : at::zeros_like(toNonOptTensor(self));
  eigenvectors_new_fw_grad_opt = eigh_jvp_eigenvectors(self_t, eigenvalues, eigenvectors);
}
if (eigenvalues_new_fw_grad_opt.has_value() && eigenvalues_new_fw_grad_opt.value().defined()) {
  // The hardcoded 0 here will need to be updated once we support multiple levels.
  eigenvalues._set_fw_grad(eigenvalues_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ false);
}
if (eigenvectors_new_fw_grad_opt.has_value() && eigenvectors_new_fw_grad_opt.value().defined()) {
  // The hardcoded 0 here will need to be updated once we support multiple levels.
  eigenvectors._set_fw_grad(eigenvectors_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ false);
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68535

Reviewed By: ngimel

Differential Revision: D32536089

Pulled By: soulitzer

fbshipit-source-id: a3f288540e2d78a4a9ec4bd66d2c0f0e65dd72cd
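For context on the ordering hazard visible in the diff: `eigh` has two forward-AD formulas, and the second one (`eigh_jvp_eigenvectors`) takes `eigenvalues` as an input. In the Before code, `eigenvalues._set_fw_grad` ran before the eigenvectors formula, so that formula could compute a tangent that itself carried a tangent at the same dual level. The After code computes every formula's result into a `c10::optional<at::Tensor>` first and only then attaches the forward grads, so no formula can observe an output that already has one. Below is a minimal Python sketch (not part of the commit) exercising the invariant that the new check enforces, using the same op as the linked test; the symmetrization and the final assertions are illustrative assumptions, not code from the PR:

```python
# Minimal sketch: a forward-mode JVP through torch.linalg.eigh, then a check
# that neither output's tangent itself carries a tangent at the same level.
import torch
import torch.autograd.forward_ad as fwAD

a = torch.randn(3, 3, dtype=torch.float64)
primal = a + a.t()   # eigh expects a symmetric (Hermitian) input
b = torch.randn(3, 3, dtype=torch.float64)
tangent = b + b.t()  # keep the perturbation symmetric as well

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    eigenvalues, eigenvectors = torch.linalg.eigh(dual)

    # unpack_dual yields (primal, tangent); these tangents are what the
    # generated eigh_jvp_eigenvalues / eigh_jvp_eigenvectors calls produce.
    eigenvalues_t = fwAD.unpack_dual(eigenvalues).tangent
    eigenvectors_t = fwAD.unpack_dual(eigenvectors).tangent

    # The invariant the codegen check enforces: a tangent is a plain tensor
    # at this level. Before the fix, the eigenvectors tangent could itself
    # hold a tangent, because eigenvalues already had its forward grad set
    # when the eigenvectors formula consumed it.
    assert fwAD.unpack_dual(eigenvalues_t).tangent is None
    assert fwAD.unpack_dual(eigenvectors_t).tangent is None
```

The deferred-attach structure in the After code is what makes this hold for any number of formulas: tangents for all outputs are computed against grad-free outputs first, and `_set_fw_grad` runs only once nothing downstream will read them again.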