per-Tensor `grad_fn` for in-place foreach functions (#96405)
Generate a separate `grad_fn` for each `Tensor` (or tuple of `Tensor`s sharing the same index) passed to an in-place foreach function `_foreach_foo_`; each `grad_fn` is the corresponding `FooBackward` node.
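As a hedged, illustrative sketch of the user-visible effect (the harness below is hypothetical and not part of this PR): after this change, an in-place foreach op installs one backward node per list element rather than none.
```c++
// Illustrative only: assumes this PR's per-tensor grad_fn generation is in place.
#include <torch/torch.h>
#include <iostream>

int main() {
  // In-place ops are not allowed on leaf tensors that require grad,
  // so build non-leaf tensors by going through a differentiable op first.
  std::vector<torch::Tensor> xs;
  for (int i = 0; i < 3; ++i) {
    xs.push_back(torch::rand({2}, torch::requires_grad()) * 1.0);
  }
  at::_foreach_exp_(xs);  // in-place foreach op
  for (const auto& x : xs) {
    // Each element is expected to carry its own ExpBackward0 node
    // rather than one node shared by the whole list.
    std::cout << (x.grad_fn() ? x.grad_fn()->name() : "<no grad_fn>") << '\n';
  }
  return 0;
}
```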
For the record, the current status of backward support for foreach functions:
- out-of-place: implemented, but without optimized implementations like their forward paths have
- in-place: not implemented. I think this check https://github.com/pytorch/pytorch/blob/7eaaefafb3ce0e4a0a9e1eb647e340711973ec12/torchgen/api/autograd.py#L309-L311 is partly responsible, but the signature difference between the out-of-place and in-place variants (see https://github.com/pytorch/pytorch/pull/96405#discussion_r1154690940) would also prevent in-place from reusing the out-of-place derivatives (the relevant logic is around https://github.com/pytorch/pytorch/blob/7eaaefafb3ce0e4a0a9e1eb647e340711973ec12/torchgen/api/autograd.py#L495-L500). For example, the currently generated code for `_foreach_abs_` sets up no `grad_fn` at all:
```c++
void _foreach_abs_(c10::DispatchKeySet ks, at::TensorList self) {
auto self_ = unpack(self, "self", 0);
#ifndef NDEBUG
std::vector<c10::optional<Storage>> self__storage_saved(self_.size());
for (const Tensor& tensor : self_)
self__storage_saved.push_back(
tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
std::vector<c10::intrusive_ptr<TensorImpl>> self__impl_saved(self_.size());
for (size_t i=0; i<self_.size(); i++)
if (self_[i].defined()) self__impl_saved[i] = self_[i].getIntrusivePtr();
#endif
{
at::AutoDispatchBelowAutograd guard;
at::redispatch::_foreach_abs_(ks & c10::after_autograd_keyset, self_);
}
#ifndef NDEBUG
for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (self__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(self_))
AT_ASSERT(self__storage_saved[i].value().is_alias_of(self_[i].storage()));
}
for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (self__impl_saved[i] && !at::impl::tensorlist_has_dispatch(self_))
AT_ASSERT(self__impl_saved[i] == self_[i].getIntrusivePtr());
}
#endif
}
```
Related:
- #95431
- #95765 for the multiple-`grad_fn`s logic
---
Examples: the generated code for `_foreach_addcmul_.ScalarList`, `_foreach_add_.List`, and `_foreach_exp_`
```c++
void _foreach_addcmul__ScalarList(c10::DispatchKeySet ks, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars) {
auto self_ = unpack(self, "self", 0);
auto tensor1_ = unpack(tensor1, "tensor1", 1);
auto tensor2_ = unpack(tensor2, "tensor2", 2);
auto _any_requires_grad = compute_requires_grad( self, tensor1, tensor2 );
(void)_any_requires_grad;
std::vector<c10::optional<at::Tensor>> original_selfs(self.size());
std::vector<std::shared_ptr<AddcmulBackward0>> grad_fns;
if (_any_requires_grad) {
for (const auto& i : c10::irange( self.size() )) {
const auto ith_requires_grad = compute_requires_grad(self[i], tensor1[i], tensor2[i]);
check_inplace(self[i], ith_requires_grad);
grad_fns.push_back([&]() -> std::shared_ptr<AddcmulBackward0> {
if (!ith_requires_grad) {
return nullptr;
} else {
auto grad_fn = std::shared_ptr<AddcmulBackward0>(new AddcmulBackward0(), deleteNode);
grad_fn->set_next_edges(collect_next_edges( self[i], tensor1[i], tensor2[i] ));
return grad_fn;
}
}());
}
if (!grad_fns.empty()) {
for (const auto& i : c10::irange(grad_fns.size())) {
auto grad_fn = grad_fns[i];
if (grad_fn != nullptr) {
grad_fn->self_scalar_type = self[i].scalar_type();
grad_fn->tensor1_scalar_type = tensor1[i].scalar_type();
if (grad_fn->should_compute_output(1)) {
grad_fn->tensor2_ = SavedVariable(tensor2[i], false);
}
grad_fn->value = scalars[i];
if (grad_fn->should_compute_output(2)) {
grad_fn->tensor1_ = SavedVariable(tensor1[i], false);
}
grad_fn->tensor2_scalar_type = tensor2[i].scalar_type();
}
}
}
}
#ifndef NDEBUG
std::vector<c10::optional<Storage>> self__storage_saved(self_.size());
for (const Tensor& tensor : self_)
self__storage_saved.push_back(
tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
std::vector<c10::intrusive_ptr<TensorImpl>> self__impl_saved(self_.size());
for (size_t i=0; i<self_.size(); i++)
if (self_[i].defined()) self__impl_saved[i] = self_[i].getIntrusivePtr();
std::vector<c10::optional<Storage>> tensor1__storage_saved(tensor1_.size());
for (const Tensor& tensor : tensor1_)
tensor1__storage_saved.push_back(
tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
std::vector<c10::intrusive_ptr<TensorImpl>> tensor1__impl_saved(tensor1_.size());
for (size_t i=0; i<tensor1_.size(); i++)
if (tensor1_[i].defined()) tensor1__impl_saved[i] = tensor1_[i].getIntrusivePtr();
std::vector<c10::optional<Storage>> tensor2__storage_saved(tensor2_.size());
for (const Tensor& tensor : tensor2_)
tensor2__storage_saved.push_back(
tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
std::vector<c10::intrusive_ptr<TensorImpl>> tensor2__impl_saved(tensor2_.size());
for (size_t i=0; i<tensor2_.size(); i++)
if (tensor2_[i].defined()) tensor2__impl_saved[i] = tensor2_[i].getIntrusivePtr();
#endif
{
at::AutoDispatchBelowAutograd guard;
at::redispatch::_foreach_addcmul_(ks & c10::after_autograd_keyset, self_, tensor1_, tensor2_, scalars);
}
#ifndef NDEBUG
for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (self__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(self_))
TORCH_INTERNAL_ASSERT(self__storage_saved[i].value().is_alias_of(self_[i].storage()));
}
for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (self__impl_saved[i] && !at::impl::tensorlist_has_dispatch(self_))
TORCH_INTERNAL_ASSERT(self__impl_saved[i] == self_[i].getIntrusivePtr());
}
for (size_t i=0; i<tensor1_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (tensor1__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(tensor1_))
TORCH_INTERNAL_ASSERT(tensor1__storage_saved[i].value().is_alias_of(tensor1_[i].storage()));
}
for (size_t i=0; i<tensor1_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (tensor1__impl_saved[i] && !at::impl::tensorlist_has_dispatch(tensor1_))
TORCH_INTERNAL_ASSERT(tensor1__impl_saved[i] == tensor1_[i].getIntrusivePtr());
}
for (size_t i=0; i<tensor2_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (tensor2__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(tensor2_))
TORCH_INTERNAL_ASSERT(tensor2__storage_saved[i].value().is_alias_of(tensor2_[i].storage()));
}
for (size_t i=0; i<tensor2_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (tensor2__impl_saved[i] && !at::impl::tensorlist_has_dispatch(tensor2_))
TORCH_INTERNAL_ASSERT(tensor2__impl_saved[i] == tensor2_[i].getIntrusivePtr());
}
#endif
if (!grad_fns.empty()) {
auto differentiable_outputs = flatten_tensor_args( self );
TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());
for (const auto& i : c10::irange(grad_fns.size())) {
auto grad_fn = grad_fns[i];
if (grad_fn != nullptr) {
rebase_history(differentiable_outputs[i], grad_fns[i]);
}
}
}
}
```
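One consequence of computing `ith_requires_grad` per index, sketched below (a hypothetical harness, not code from this PR): a single call can mix differentiable and non-differentiable entries, and only the former get an `AddcmulBackward0` node while the other slots in `grad_fns` stay `nullptr`.
```c++
// Hypothetical harness illustrating the per-index requires_grad handling above.
#include <torch/torch.h>
#include <iostream>

int main() {
  std::vector<torch::Tensor> self, tensor1, tensor2;
  std::vector<at::Scalar> scalars;
  for (int i = 0; i < 2; ++i) {
    // Only index 0 participates in autograd; "* 1.0" keeps it a non-leaf
    // so the in-place op is allowed on it.
    auto opts = (i == 0) ? torch::requires_grad() : torch::TensorOptions();
    self.push_back(torch::rand({2}, opts) * 1.0);
    tensor1.push_back(torch::rand({2}));
    tensor2.push_back(torch::rand({2}));
    scalars.emplace_back(0.5);
  }
  at::_foreach_addcmul_(self, tensor1, tensor2, scalars);
  std::cout << std::boolalpha
            << "self[0] has grad_fn: " << (self[0].grad_fn() != nullptr) << '\n'   // true
            << "self[1] has grad_fn: " << (self[1].grad_fn() != nullptr) << '\n';  // false
  return 0;
}
```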
```c++
void _foreach_add__List(c10::DispatchKeySet ks, at::TensorList self, at::TensorList other, const at::Scalar & alpha) {
auto self_ = unpack(self, "self", 0);
auto other_ = unpack(other, "other", 1);
auto _any_requires_grad = compute_requires_grad( self, other );
(void)_any_requires_grad;
std::vector<c10::optional<at::Tensor>> original_selfs(self.size());
std::vector<std::shared_ptr<AddBackward0>> grad_fns;
if (_any_requires_grad) {
for (const auto& i : c10::irange( self.size() )) {
const auto ith_requires_grad = compute_requires_grad(self[i], other[i]);
check_inplace(self[i], ith_requires_grad);
grad_fns.push_back([&]() -> std::shared_ptr<AddBackward0> {
if (!ith_requires_grad) {
return nullptr;
} else {
auto grad_fn = std::shared_ptr<AddBackward0>(new AddBackward0(), deleteNode);
grad_fn->set_next_edges(collect_next_edges( self[i], other[i] ));
return grad_fn;
}
}());
}
if (!grad_fns.empty()) {
for (const auto& i : c10::irange(grad_fns.size())) {
auto grad_fn = grad_fns[i];
if (grad_fn != nullptr) {
grad_fn->other_scalar_type = other[i].scalar_type();
grad_fn->alpha = alpha;
grad_fn->self_scalar_type = self[i].scalar_type();
}
}
}
}
#ifndef NDEBUG
std::vector<c10::optional<Storage>> self__storage_saved(self_.size());
for (const Tensor& tensor : self_)
self__storage_saved.push_back(
tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
std::vector<c10::intrusive_ptr<TensorImpl>> self__impl_saved(self_.size());
for (size_t i=0; i<self_.size(); i++)
if (self_[i].defined()) self__impl_saved[i] = self_[i].getIntrusivePtr();
std::vector<c10::optional<Storage>> other__storage_saved(other_.size());
for (const Tensor& tensor : other_)
other__storage_saved.push_back(
tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
std::vector<c10::intrusive_ptr<TensorImpl>> other__impl_saved(other_.size());
for (size_t i=0; i<other_.size(); i++)
if (other_[i].defined()) other__impl_saved[i] = other_[i].getIntrusivePtr();
#endif
{
at::AutoDispatchBelowAutograd guard;
at::redispatch::_foreach_add_(ks & c10::after_autograd_keyset, self_, other_, alpha);
}
#ifndef NDEBUG
for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (self__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(self_))
TORCH_INTERNAL_ASSERT(self__storage_saved[i].value().is_alias_of(self_[i].storage()));
}
for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (self__impl_saved[i] && !at::impl::tensorlist_has_dispatch(self_))
TORCH_INTERNAL_ASSERT(self__impl_saved[i] == self_[i].getIntrusivePtr());
}
for (size_t i=0; i<other_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (other__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(other_))
TORCH_INTERNAL_ASSERT(other__storage_saved[i].value().is_alias_of(other_[i].storage()));
}
for (size_t i=0; i<other_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (other__impl_saved[i] && !at::impl::tensorlist_has_dispatch(other_))
TORCH_INTERNAL_ASSERT(other__impl_saved[i] == other_[i].getIntrusivePtr());
}
#endif
if (!grad_fns.empty()) {
auto differentiable_outputs = flatten_tensor_args( self );
TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());
for (const auto& i : c10::irange(grad_fns.size())) {
auto grad_fn = grad_fns[i];
if (grad_fn != nullptr) {
rebase_history(differentiable_outputs[i], grad_fns[i]);
}
}
}
}
...
void _foreach_exp_(c10::DispatchKeySet ks, at::TensorList self) {
auto self_ = unpack(self, "self", 0);
auto _any_requires_grad = compute_requires_grad( self );
(void)_any_requires_grad;
std::vector<c10::optional<at::Tensor>> original_selfs(self.size());
std::vector<std::shared_ptr<ExpBackward0>> grad_fns;
if (_any_requires_grad) {
for (const auto& i : c10::irange( self.size() )) {
const auto ith_requires_grad = compute_requires_grad(self[i]);
check_inplace(self[i], ith_requires_grad);
grad_fns.push_back([&]() -> std::shared_ptr<ExpBackward0> {
if (!ith_requires_grad) {
return nullptr;
} else {
auto grad_fn = std::shared_ptr<ExpBackward0>(new ExpBackward0(), deleteNode);
grad_fn->set_next_edges(collect_next_edges( self[i] ));
return grad_fn;
}
}());
}
}
#ifndef NDEBUG
std::vector<c10::optional<Storage>> self__storage_saved(self_.size());
for (const Tensor& tensor : self_)
self__storage_saved.push_back(
tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
std::vector<c10::intrusive_ptr<TensorImpl>> self__impl_saved(self_.size());
for (size_t i=0; i<self_.size(); i++)
if (self_[i].defined()) self__impl_saved[i] = self_[i].getIntrusivePtr();
#endif
{
at::AutoDispatchBelowAutograd guard;
at::redispatch::_foreach_exp_(ks & c10::after_autograd_keyset, self_);
}
#ifndef NDEBUG
for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (self__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(self_))
TORCH_INTERNAL_ASSERT(self__storage_saved[i].value().is_alias_of(self_[i].storage()));
}
for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (self__impl_saved[i] && !at::impl::tensorlist_has_dispatch(self_))
TORCH_INTERNAL_ASSERT(self__impl_saved[i] == self_[i].getIntrusivePtr());
}
#endif
if (!grad_fns.empty()) {
auto differentiable_outputs = flatten_tensor_args( self );
TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());
for (const auto& i : c10::irange(grad_fns.size())) {
auto grad_fn = grad_fns[i];
if (grad_fn != nullptr) {
rebase_history(differentiable_outputs[i], grad_fns[i]);
}
}
}
if (!grad_fns.empty()) {
for (const auto& i : c10::irange(grad_fns.size())) {
auto grad_fn = grad_fns[i];
if (grad_fn != nullptr) {
grad_fn->result_ = SavedVariable(self[i], true, self[i].is_view());
}
}
}
}
```
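A hedged sanity-check sketch of the resulting backward behavior (a hypothetical harness, not code from this PR): since `ExpBackward0` saves the already-overwritten result per element, the gradient flowing back to each original input should equal that element's exponentiated value.
```c++
// Hypothetical check: gradients through the in-place foreach op, element by element.
#include <torch/torch.h>

int main() {
  std::vector<torch::Tensor> leaves, xs;
  for (int i = 0; i < 3; ++i) {
    leaves.push_back(torch::rand({2}, torch::requires_grad()));
    xs.push_back(leaves.back() * 1.0);  // non-leaf copy so the in-place op is allowed
  }
  at::_foreach_exp_(xs);              // xs[i] now holds exp(leaves[i])
  torch::stack(xs).sum().backward();  // each element's ExpBackward0 runs independently
  for (size_t i = 0; i < leaves.size(); ++i) {
    // d/dx exp(x) = exp(x); each leaf's grad should match its exponentiated value,
    // which ExpBackward0 reads back from the saved (mutated) result.
    TORCH_CHECK(torch::allclose(leaves[i].grad(), xs[i].detach()));
  }
  return 0;
}
```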
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96405
Approved by: https://github.com/soulitzer