pytorch/pytorch @ 0bb2b015 - Add forward mode AD to in-place foreach functions (#100695)

Awkwardly implement forward-mode AD by:
- adding a few `CodeTemplate`s
- allowing for the cases where a variable is initialized with the i-th element of a TensorList

<!--
### TODOs:
- [x] ~~remove the first `_any_has_forward_grad_self`~~ make it a vector of bool
- [ ] clean up mapping of names from reference impl to foreach impl
- [x] add tests
-->

### Rel:
- #58833
- #96405

---

`_foreach_addcmul_.ScalarList` from `VariableType`:

```c++
void _foreach_addcmul__ScalarList(c10::DispatchKeySet ks, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars) {
  auto self_ = unpack(self, "self", 0);
  auto tensor1_ = unpack(tensor1, "tensor1", 1);
  auto tensor2_ = unpack(tensor2, "tensor2", 2);
  [[maybe_unused]] auto _any_requires_grad = compute_requires_grad( self, tensor1, tensor2 );

  std::vector<bool> _any_has_forward_grad_self(self.size());
  for (const auto& i : c10::irange(self.size())) {
    _any_has_forward_grad_self[i] = isFwGradDefined(self[i]) || isFwGradDefined(tensor1[i]) || isFwGradDefined(tensor2[i]);
  }
  std::vector<c10::optional<at::Tensor>> original_selfs(self.size());
  std::vector<std::shared_ptr<AddcmulBackward0>> grad_fns;
  if (_any_requires_grad) {
    for (const auto& i : c10::irange( self.size() )) {
      const auto ith_requires_grad = compute_requires_grad(self[i], tensor1[i], tensor2[i]);
      check_inplace(self[i], ith_requires_grad);
      grad_fns.push_back([&]() -> std::shared_ptr<AddcmulBackward0> {
        if (!ith_requires_grad) {
          return nullptr;
        } else {
          auto grad_fn = std::shared_ptr<AddcmulBackward0>(new AddcmulBackward0(), deleteNode);
          grad_fn->set_next_edges(collect_next_edges( self[i], tensor1[i], tensor2[i] ));
          return grad_fn;
        }
      }());
    }
    if (!grad_fns.empty()) {
      for (const auto& i : c10::irange(grad_fns.size())) {
        auto grad_fn = grad_fns[i];
        if (grad_fn != nullptr) {
          grad_fn->self_scalar_type = self[i].scalar_type();
          grad_fn->tensor1_scalar_type = tensor1[i].scalar_type();
          if (grad_fn->should_compute_output(1)) {
            grad_fn->tensor2_ = SavedVariable(tensor2[i], false);
          }
          grad_fn->value = scalars[i];
          if (grad_fn->should_compute_output(2)) {
            grad_fn->tensor1_ = SavedVariable(tensor1[i], false);
          }
          grad_fn->tensor2_scalar_type = tensor2[i].scalar_type();
        }
      }
    }
  }
  #ifndef NDEBUG
  std::vector<c10::optional<Storage>> self__storage_saved(self_.size());
  for (const Tensor& tensor : self_)
    self__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> self__impl_saved(self_.size());
  for (size_t i=0; i<self_.size(); i++)
    if (self_[i].defined()) self__impl_saved[i] = self_[i].getIntrusivePtr();
  std::vector<c10::optional<Storage>> tensor1__storage_saved(tensor1_.size());
  for (const Tensor& tensor : tensor1_)
    tensor1__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> tensor1__impl_saved(tensor1_.size());
  for (size_t i=0; i<tensor1_.size(); i++)
    if (tensor1_[i].defined()) tensor1__impl_saved[i] = tensor1_[i].getIntrusivePtr();
  std::vector<c10::optional<Storage>> tensor2__storage_saved(tensor2_.size());
  for (const Tensor& tensor : tensor2_)
    tensor2__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> tensor2__impl_saved(tensor2_.size());
  for (size_t i=0; i<tensor2_.size(); i++)
    if (tensor2_[i].defined()) tensor2__impl_saved[i] = tensor2_[i].getIntrusivePtr();
  #endif
  {
    at::AutoDispatchBelowAutograd guard;
    at::redispatch::_foreach_addcmul_(ks & c10::after_autograd_keyset, self_, tensor1_, tensor2_, scalars);
  }
  #ifndef NDEBUG
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(self_))
      TORCH_INTERNAL_ASSERT(self__storage_saved[i].value().is_alias_of(self_[i].storage()));
  }
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__impl_saved[i] && !at::impl::tensorlist_has_dispatch(self_))
      TORCH_INTERNAL_ASSERT(self__impl_saved[i] == self_[i].getIntrusivePtr());
  }
  for (size_t i=0; i<tensor1_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (tensor1__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(tensor1_))
      TORCH_INTERNAL_ASSERT(tensor1__storage_saved[i].value().is_alias_of(tensor1_[i].storage()));
  }
  for (size_t i=0; i<tensor1_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (tensor1__impl_saved[i] && !at::impl::tensorlist_has_dispatch(tensor1_))
      TORCH_INTERNAL_ASSERT(tensor1__impl_saved[i] == tensor1_[i].getIntrusivePtr());
  }
  for (size_t i=0; i<tensor2_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (tensor2__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(tensor2_))
      TORCH_INTERNAL_ASSERT(tensor2__storage_saved[i].value().is_alias_of(tensor2_[i].storage()));
  }
  for (size_t i=0; i<tensor2_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (tensor2__impl_saved[i] && !at::impl::tensorlist_has_dispatch(tensor2_))
      TORCH_INTERNAL_ASSERT(tensor2__impl_saved[i] == tensor2_[i].getIntrusivePtr());
  }
  #endif
  if (!grad_fns.empty()) {
    auto differentiable_outputs = flatten_tensor_args( self );
    TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());
    for (const auto& i : c10::irange(grad_fns.size())) {
      auto grad_fn = grad_fns[i];
      if (grad_fn != nullptr) {
        rebase_history(differentiable_outputs[i], grad_fns[i]);
      }
    }
  }
  std::vector<c10::optional<at::Tensor>> self_new_fw_grad_opts(self.size(), c10::nullopt);
  for (const auto& i : c10::irange(self_new_fw_grad_opts.size())) {
    if (_any_has_forward_grad_self[i]) {
      auto self_t_raw = toNonOptFwGrad(self[i]);
      auto self_tensor = toNonOptTensor(self[i]);
      auto self_t = (self_t_raw.defined() || !self_tensor.defined())
          ? self_t_raw : at::zeros(self_tensor.sizes(), self_tensor.options());
      auto tensor1_t_raw = toNonOptFwGrad(tensor1[i]);
      auto tensor1_tensor = toNonOptTensor(tensor1[i]);
      auto tensor1_t = (tensor1_t_raw.defined() || !tensor1_tensor.defined())
          ? tensor1_t_raw : at::_efficientzerotensor(tensor1_tensor.sizes(), tensor1_tensor.options());
      auto tensor1_p = toNonOptPrimal(tensor1[i]);
      auto tensor2_t_raw = toNonOptFwGrad(tensor2[i]);
      auto tensor2_tensor = toNonOptTensor(tensor2[i]);
      auto tensor2_t = (tensor2_t_raw.defined() || !tensor2_tensor.defined())
          ? tensor2_t_raw : at::_efficientzerotensor(tensor2_tensor.sizes(), tensor2_tensor.options());
      auto tensor2_p = toNonOptPrimal(tensor2[i]);
      self_t = GradMode::is_enabled() ? self_t.clone() : self_t;
      self_new_fw_grad_opts[i] = self_t_raw.defined()
          ? self_t_raw.copy_(self_t + maybe_multiply(tensor1_t * tensor2_p, scalars[i]) + maybe_multiply(tensor2_t * tensor1_p, scalars[i]))
          : self_t + maybe_multiply(tensor1_t * tensor2_p, scalars[i]) + maybe_multiply(tensor2_t * tensor1_p, scalars[i]);
    }
  }
  for (const auto& i : c10::irange(self_new_fw_grad_opts.size())) {
    auto& self_new_fw_grad_opt = self_new_fw_grad_opts[i];
    if (self_new_fw_grad_opt.has_value() && self_new_fw_grad_opt.value().defined() && self[i].defined()) {
      // The hardcoded 0 here will need to be updated once we support multiple levels.
      self[i]._set_fw_grad(self_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ true);
    }
  }
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100695
Approved by: https://github.com/soulitzer