Fixes forward AD codegen for multiple formulas (#68535)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/67367
- Adds a check to make sure the forward grad itself does not have a forward grad at the same level
- Verified with `python test/test_ops.py -k test_forward_mode_AD_linalg_eigh_cpu_float64` that the test fails the check before the codegen update and passes after it
Before:
```
if (_any_has_forward_grad_eigenvalues) {
auto self_t_raw = toNonOptFwGrad(self);
auto self_t = self_t_raw.defined() ? self_t_raw : at::zeros_like(toNonOptTensor(self));
auto eigenvalues_new_fw_grad = eigh_jvp_eigenvalues(self_t, eigenvalues, eigenvectors);
if (eigenvalues_new_fw_grad.defined()) {
// The hardcoded 0 here will need to be updated once we support multiple levels.
eigenvalues._set_fw_grad(eigenvalues_new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
}
}
if (_any_has_forward_grad_eigenvectors) {
auto self_t_raw = toNonOptFwGrad(self);
auto self_t = self_t_raw.defined() ? self_t_raw : at::zeros_like(toNonOptTensor(self));
auto eigenvectors_new_fw_grad = eigh_jvp_eigenvectors(self_t, eigenvalues, eigenvectors);
if (eigenvectors_new_fw_grad.defined()) {
// The hardcoded 0 here will need to be updated once we support multiple levels.
eigenvectors._set_fw_grad(eigenvectors_new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
}
}
```
After:
```
c10::optional<at::Tensor> eigenvalues_new_fw_grad_opt = c10::nullopt;
if (_any_has_forward_grad_eigenvalues) {
auto self_t_raw = toNonOptFwGrad(self);
auto self_t = self_t_raw.defined() ? self_t_raw : at::zeros_like(toNonOptTensor(self));
eigenvalues_new_fw_grad_opt = eigh_jvp_eigenvalues(self_t, eigenvalues, eigenvectors);
}
c10::optional<at::Tensor> eigenvectors_new_fw_grad_opt = c10::nullopt;
if (_any_has_forward_grad_eigenvectors) {
auto self_t_raw = toNonOptFwGrad(self);
auto self_t = self_t_raw.defined() ? self_t_raw : at::zeros_like(toNonOptTensor(self));
eigenvectors_new_fw_grad_opt = eigh_jvp_eigenvectors(self_t, eigenvalues, eigenvectors);
}
if (eigenvalues_new_fw_grad_opt.has_value() && eigenvalues_new_fw_grad_opt.value().defined()) {
// The hardcoded 0 here will need to be updated once we support multiple levels.
eigenvalues._set_fw_grad(eigenvalues_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ false);
}
if (eigenvectors_new_fw_grad_opt.has_value() && eigenvectors_new_fw_grad_opt.value().defined()) {
// The hardcoded 0 here will need to be updated once we support multiple levels.
eigenvectors._set_fw_grad(eigenvectors_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ false);
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68535
Reviewed By: ngimel
Differential Revision: D32536089
Pulled By: soulitzer
fbshipit-source-id: a3f288540e2d78a4a9ec4bd66d2c0f0e65dd72cd