67401890 - per-Tensor `grad_fn` for in-place foreach functions (#96405)

Generate a `grad_fn` for each (tuple of) `Tensor`(s) sharing the same index in `_foreach_foo_`; each `grad_fn` is `FooBackward`.

For the record, the current status of foreach functions' backward support:

- out-of-place: implemented, but without optimized implementations like their forward paths
- in-place: not implemented

I think this check https://github.com/pytorch/pytorch/blob/7eaaefafb3ce0e4a0a9e1eb647e340711973ec12/torchgen/api/autograd.py#L309-L311 is partly responsible, but the signature difference between the out-of-place and in-place variants (see https://github.com/pytorch/pytorch/pull/96405#discussion_r1154690940) would prevent the in-place functions from reusing the out-of-place derivatives anyway (the relevant logic is around https://github.com/pytorch/pytorch/blob/7eaaefafb3ce0e4a0a9e1eb647e340711973ec12/torchgen/api/autograd.py#L495-L500). For reference, this is what is currently generated for `_foreach_abs_` (no autograd node is created):

```c++
void _foreach_abs_(c10::DispatchKeySet ks, at::TensorList self) {
  auto self_ = unpack(self, "self", 0);
#ifndef NDEBUG
  std::vector<c10::optional<Storage>> self__storage_saved(self_.size());
  for (const Tensor& tensor : self_)
    self__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> self__impl_saved(self_.size());
  for (size_t i=0; i<self_.size(); i++)
    if (self_[i].defined()) self__impl_saved[i] = self_[i].getIntrusivePtr();
#endif
  {
    at::AutoDispatchBelowAutograd guard;
    at::redispatch::_foreach_abs_(ks & c10::after_autograd_keyset, self_);
  }
#ifndef NDEBUG
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(self_))
      AT_ASSERT(self__storage_saved[i].value().is_alias_of(self_[i].storage()));
  }
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__impl_saved[i] && !at::impl::tensorlist_has_dispatch(self_))
      AT_ASSERT(self__impl_saved[i] == self_[i].getIntrusivePtr());
  }
#endif
}
```

Related:
- #95431
- #95765 for multiple `grad_fn`s logic
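For context, here is a minimal Python sketch (not part of this PR) of the user-visible behavior this enables, assuming a PyTorch build that includes this change; `torch._foreach_*` are private APIs, so names and behavior may differ across versions:

```python
import torch

# Leaves hold the gradients; clone so the in-place op runs on non-leaf tensors
# (in-place ops on leaves that require grad are rejected by check_inplace).
leaves = [torch.randn(3, requires_grad=True) for _ in range(3)]
xs = [t.clone() for t in leaves]

torch._foreach_exp_(xs)  # in-place foreach op

# Each tensor of the TensorList now carries its own grad_fn (ExpBackward0),
# rather than one node shared by the whole list.
print([type(x.grad_fn).__name__ for x in xs])  # -> ['ExpBackward0', 'ExpBackward0', 'ExpBackward0'] (expected)

# Gradients flow per index, just like calling Tensor.exp_() on each tensor individually.
torch.autograd.backward([x.sum() for x in xs])
print([leaf.grad is not None for leaf in leaves])  # -> [True, True, True] (expected)
```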
---

Examples: outputs for `_foreach_add_.List`, `_foreach_addcmul_.ScalarList`, and `_foreach_exp_`

```c++
void _foreach_addcmul__ScalarList(c10::DispatchKeySet ks, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars) {
  auto self_ = unpack(self, "self", 0);
  auto tensor1_ = unpack(tensor1, "tensor1", 1);
  auto tensor2_ = unpack(tensor2, "tensor2", 2);
  auto _any_requires_grad = compute_requires_grad( self, tensor1, tensor2 );
  (void)_any_requires_grad;
  std::vector<c10::optional<at::Tensor>> original_selfs(self.size());
  std::vector<std::shared_ptr<AddcmulBackward0>> grad_fns;
  if (_any_requires_grad) {
    for (const auto& i : c10::irange( self.size() )) {
      const auto ith_requires_grad = compute_requires_grad(self[i], tensor1[i], tensor2[i]);
      check_inplace(self[i], ith_requires_grad);
      grad_fns.push_back([&]() -> std::shared_ptr<AddcmulBackward0> {
        if (!ith_requires_grad) {
          return nullptr;
        } else {
          auto grad_fn = std::shared_ptr<AddcmulBackward0>(new AddcmulBackward0(), deleteNode);
          grad_fn->set_next_edges(collect_next_edges( self[i], tensor1[i], tensor2[i] ));
          return grad_fn;
        }
      }());
    }
    if (!grad_fns.empty()) {
      for (const auto& i : c10::irange(grad_fns.size())) {
        auto grad_fn = grad_fns[i];
        if (grad_fn != nullptr) {
          grad_fn->self_scalar_type = self[i].scalar_type();
          grad_fn->tensor1_scalar_type = tensor1[i].scalar_type();
          if (grad_fn->should_compute_output(1)) {
            grad_fn->tensor2_ = SavedVariable(tensor2[i], false);
          }
          grad_fn->value = scalars[i];
          if (grad_fn->should_compute_output(2)) {
            grad_fn->tensor1_ = SavedVariable(tensor1[i], false);
          }
          grad_fn->tensor2_scalar_type = tensor2[i].scalar_type();
        }
      }
    }
  }
#ifndef NDEBUG
  std::vector<c10::optional<Storage>> self__storage_saved(self_.size());
  for (const Tensor& tensor : self_)
    self__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> self__impl_saved(self_.size());
  for (size_t i=0; i<self_.size(); i++)
    if (self_[i].defined()) self__impl_saved[i] = self_[i].getIntrusivePtr();
  std::vector<c10::optional<Storage>> tensor1__storage_saved(tensor1_.size());
  for (const Tensor& tensor : tensor1_)
    tensor1__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> tensor1__impl_saved(tensor1_.size());
  for (size_t i=0; i<tensor1_.size(); i++)
    if (tensor1_[i].defined()) tensor1__impl_saved[i] = tensor1_[i].getIntrusivePtr();
  std::vector<c10::optional<Storage>> tensor2__storage_saved(tensor2_.size());
  for (const Tensor& tensor : tensor2_)
    tensor2__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> tensor2__impl_saved(tensor2_.size());
  for (size_t i=0; i<tensor2_.size(); i++)
    if (tensor2_[i].defined()) tensor2__impl_saved[i] = tensor2_[i].getIntrusivePtr();
#endif
  {
    at::AutoDispatchBelowAutograd guard;
    at::redispatch::_foreach_addcmul_(ks & c10::after_autograd_keyset, self_, tensor1_, tensor2_, scalars);
  }
#ifndef NDEBUG
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(self_))
      TORCH_INTERNAL_ASSERT(self__storage_saved[i].value().is_alias_of(self_[i].storage()));
  }
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__impl_saved[i] && !at::impl::tensorlist_has_dispatch(self_))
      TORCH_INTERNAL_ASSERT(self__impl_saved[i] == self_[i].getIntrusivePtr());
  }
  for (size_t i=0; i<tensor1_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (tensor1__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(tensor1_))
      TORCH_INTERNAL_ASSERT(tensor1__storage_saved[i].value().is_alias_of(tensor1_[i].storage()));
  }
  for (size_t i=0; i<tensor1_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (tensor1__impl_saved[i] && !at::impl::tensorlist_has_dispatch(tensor1_))
      TORCH_INTERNAL_ASSERT(tensor1__impl_saved[i] == tensor1_[i].getIntrusivePtr());
  }
  for (size_t i=0; i<tensor2_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (tensor2__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(tensor2_))
      TORCH_INTERNAL_ASSERT(tensor2__storage_saved[i].value().is_alias_of(tensor2_[i].storage()));
  }
  for (size_t i=0; i<tensor2_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (tensor2__impl_saved[i] && !at::impl::tensorlist_has_dispatch(tensor2_))
      TORCH_INTERNAL_ASSERT(tensor2__impl_saved[i] == tensor2_[i].getIntrusivePtr());
  }
#endif
  if (!grad_fns.empty()) {
    auto differentiable_outputs = flatten_tensor_args( self );
    TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());
    for (const auto& i : c10::irange(grad_fns.size())) {
      auto grad_fn = grad_fns[i];
      if (grad_fn != nullptr) {
        rebase_history(differentiable_outputs[i], grad_fns[i]);
      }
    }
  }
}
```
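In the loop above, each `AddcmulBackward0` node saves its own `scalars[i]`, `tensor1[i]`, and `tensor2[i]`. A short Python sketch of what that means for gradients (hypothetical usage, not taken from the PR, under the same version caveat as above):

```python
import torch

selfs = [torch.randn(4, requires_grad=True).clone() for _ in range(2)]  # non-leaf, so in-place is allowed
t1 = [torch.randn(4, requires_grad=True) for _ in range(2)]
t2 = [torch.randn(4, requires_grad=True) for _ in range(2)]
scalars = [0.5, 2.0]

# selfs[i] += scalars[i] * t1[i] * t2[i], with one AddcmulBackward0 per index.
torch._foreach_addcmul_(selfs, t1, t2, scalars)
print([type(s.grad_fn).__name__ for s in selfs])  # -> ['AddcmulBackward0', 'AddcmulBackward0'] (expected)

# Each node saved its own scalar and operands, so index 0 backpropagates with scalars[0].
selfs[0].sum().backward()
print(torch.allclose(t1[0].grad, scalars[0] * t2[0]))  # -> True (expected)
```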
```c++
void _foreach_add__List(c10::DispatchKeySet ks, at::TensorList self, at::TensorList other, const at::Scalar & alpha) {
  auto self_ = unpack(self, "self", 0);
  auto other_ = unpack(other, "other", 1);
  auto _any_requires_grad = compute_requires_grad( self, other );
  (void)_any_requires_grad;
  std::vector<c10::optional<at::Tensor>> original_selfs(self.size());
  std::vector<std::shared_ptr<AddBackward0>> grad_fns;
  if (_any_requires_grad) {
    for (const auto& i : c10::irange( self.size() )) {
      const auto ith_requires_grad = compute_requires_grad(self[i], other[i]);
      check_inplace(self[i], ith_requires_grad);
      grad_fns.push_back([&]() -> std::shared_ptr<AddBackward0> {
        if (!ith_requires_grad) {
          return nullptr;
        } else {
          auto grad_fn = std::shared_ptr<AddBackward0>(new AddBackward0(), deleteNode);
          grad_fn->set_next_edges(collect_next_edges( self[i], other[i] ));
          return grad_fn;
        }
      }());
    }
    if (!grad_fns.empty()) {
      for (const auto& i : c10::irange(grad_fns.size())) {
        auto grad_fn = grad_fns[i];
        if (grad_fn != nullptr) {
          grad_fn->other_scalar_type = other[i].scalar_type();
          grad_fn->alpha = alpha;
          grad_fn->self_scalar_type = self[i].scalar_type();
        }
      }
    }
  }
#ifndef NDEBUG
  std::vector<c10::optional<Storage>> self__storage_saved(self_.size());
  for (const Tensor& tensor : self_)
    self__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> self__impl_saved(self_.size());
  for (size_t i=0; i<self_.size(); i++)
    if (self_[i].defined()) self__impl_saved[i] = self_[i].getIntrusivePtr();
  std::vector<c10::optional<Storage>> other__storage_saved(other_.size());
  for (const Tensor& tensor : other_)
    other__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> other__impl_saved(other_.size());
  for (size_t i=0; i<other_.size(); i++)
    if (other_[i].defined()) other__impl_saved[i] = other_[i].getIntrusivePtr();
#endif
  {
    at::AutoDispatchBelowAutograd guard;
    at::redispatch::_foreach_add_(ks & c10::after_autograd_keyset, self_, other_, alpha);
  }
#ifndef NDEBUG
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(self_))
      TORCH_INTERNAL_ASSERT(self__storage_saved[i].value().is_alias_of(self_[i].storage()));
  }
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__impl_saved[i] && !at::impl::tensorlist_has_dispatch(self_))
      TORCH_INTERNAL_ASSERT(self__impl_saved[i] == self_[i].getIntrusivePtr());
  }
  for (size_t i=0; i<other_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (other__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(other_))
      TORCH_INTERNAL_ASSERT(other__storage_saved[i].value().is_alias_of(other_[i].storage()));
  }
  for (size_t i=0; i<other_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (other__impl_saved[i] && !at::impl::tensorlist_has_dispatch(other_))
      TORCH_INTERNAL_ASSERT(other__impl_saved[i] == other_[i].getIntrusivePtr());
  }
#endif
  if (!grad_fns.empty()) {
    auto differentiable_outputs = flatten_tensor_args( self );
    TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());
    for (const auto& i : c10::irange(grad_fns.size())) {
      auto grad_fn = grad_fns[i];
      if (grad_fn != nullptr) {
        rebase_history(differentiable_outputs[i], grad_fns[i]);
      }
    }
  }
}
...
```
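Note the per-index `ith_requires_grad` check and the `nullptr` branch above: only indices whose inputs require grad get a node. A small hypothetical sketch of that behavior (same caveats as the earlier snippets):

```python
import torch

a = torch.randn(2, requires_grad=True).clone()  # index 0 requires grad (non-leaf)
b = torch.randn(2)                              # index 1 does not
others = [torch.randn(2), torch.randn(2)]

torch._foreach_add_([a, b], others, alpha=2.0)

print(type(a.grad_fn).__name__)  # -> 'AddBackward0' (per-tensor node with alpha saved), expected
print(b.grad_fn)                 # -> None; this index's grad_fn stays nullptr
```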
```c++
void _foreach_exp_(c10::DispatchKeySet ks, at::TensorList self) {
  auto self_ = unpack(self, "self", 0);
  auto _any_requires_grad = compute_requires_grad( self );
  (void)_any_requires_grad;
  std::vector<c10::optional<at::Tensor>> original_selfs(self.size());
  std::vector<std::shared_ptr<ExpBackward0>> grad_fns;
  if (_any_requires_grad) {
    for (const auto& i : c10::irange( self.size() )) {
      const auto ith_requires_grad = compute_requires_grad(self[i]);
      check_inplace(self[i], ith_requires_grad);
      grad_fns.push_back([&]() -> std::shared_ptr<ExpBackward0> {
        if (!ith_requires_grad) {
          return nullptr;
        } else {
          auto grad_fn = std::shared_ptr<ExpBackward0>(new ExpBackward0(), deleteNode);
          grad_fn->set_next_edges(collect_next_edges( self[i] ));
          return grad_fn;
        }
      }());
    }
  }
#ifndef NDEBUG
  std::vector<c10::optional<Storage>> self__storage_saved(self_.size());
  for (const Tensor& tensor : self_)
    self__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> self__impl_saved(self_.size());
  for (size_t i=0; i<self_.size(); i++)
    if (self_[i].defined()) self__impl_saved[i] = self_[i].getIntrusivePtr();
#endif
  {
    at::AutoDispatchBelowAutograd guard;
    at::redispatch::_foreach_exp_(ks & c10::after_autograd_keyset, self_);
  }
#ifndef NDEBUG
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(self_))
      TORCH_INTERNAL_ASSERT(self__storage_saved[i].value().is_alias_of(self_[i].storage()));
  }
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__impl_saved[i] && !at::impl::tensorlist_has_dispatch(self_))
      TORCH_INTERNAL_ASSERT(self__impl_saved[i] == self_[i].getIntrusivePtr());
  }
#endif
  if (!grad_fns.empty()) {
    auto differentiable_outputs = flatten_tensor_args( self );
    TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());
    for (const auto& i : c10::irange(grad_fns.size())) {
      auto grad_fn = grad_fns[i];
      if (grad_fn != nullptr) {
        rebase_history(differentiable_outputs[i], grad_fns[i]);
      }
    }
  }
  if (!grad_fns.empty()) {
    for (const auto& i : c10::irange(grad_fns.size())) {
      auto grad_fn = grad_fns[i];
      if (grad_fn != nullptr) {
        grad_fn->result_ = SavedVariable(self[i], true, self[i].is_view());
      }
    }
  }
}
```
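One closing note on the `_foreach_exp_` output above: `result_` is saved per tensor after the redispatch, since the derivative of `exp` is its own result, so each node backpropagates with its own in-place result. A hedged sketch under the same assumptions as the earlier snippets:

```python
import torch

leaves = [torch.randn(3, requires_grad=True) for _ in range(2)]
xs = [t.clone() for t in leaves]

torch._foreach_exp_(xs)
snapshots = [x.detach().clone() for x in xs]  # the values each ExpBackward0 saved

torch.autograd.backward([x.sum() for x in xs])
# d exp(t) / d t = exp(t), i.e. exactly the saved in-place result.
print(all(torch.allclose(l.grad, s) for l, s in zip(leaves, snapshots)))  # -> True (expected)
```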
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96405
Approved by: https://github.com/soulitzer