Check native_function's outputs' TensorImpl and StorageImpl (#60286)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/25927
Does some checks described in https://github.com/pytorch/pytorch/issues/25927#issuecomment-589354373:
If the function does not modify its inputs (i.e., it is not in-place and has no `out` argument):
- Check that the returned TensorImpl has a use_count of 1. (This should make us aware of functions that return `self`.)
- If the function is a view function, check that the output's StorageImpl is the same as that of the aliased input; otherwise, check that the StorageImpl's use_count is 1.
Detected a couple of functions that failed the check that the returned TensorImpl should have a use_count of 1: 'native_batch_norm', 'native_batch_norm_backward', '_embedding_bag'. (Filing issues).
Examples of generated code:
We did not update checks for in-place ops (this includes in-place views).
Example of a view:
- Check that the StorageImpl of the output `result` is the same as that of `self`.
- Check that the TensorImpl has a use_count of 1.
```cpp
at::Tensor as_strided(c10::DispatchKeySet ks, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) {
auto& self_ = unpack(self, "self", 0);
auto _any_requires_grad = compute_requires_grad( self );
(void)_any_requires_grad;
std::shared_ptr<AsStridedBackward> grad_fn;
if (_any_requires_grad) {
grad_fn = std::shared_ptr<AsStridedBackward>(new AsStridedBackward(), deleteNode);
grad_fn->set_next_edges(collect_next_edges( self ));
grad_fn->self_geometry = TensorGeometry(self);
grad_fn->size = size.vec();
grad_fn->stride = stride.vec();
grad_fn->storage_offset = storage_offset;
}
#ifndef NDEBUG
c10::optional<Storage> self__storage_saved =
self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
c10::intrusive_ptr<TensorImpl> self__impl_saved;
if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
#endif
auto _tmp = ([&]() {
at::AutoDispatchBelowAutograd guard;
return at::redispatch::as_strided(ks & c10::after_autograd_keyset, self_, size, stride, storage_offset);
})();
auto result = std::move(_tmp);
#ifndef NDEBUG
if (self__storage_saved.has_value())
AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
if (self__storage_saved.has_value())
AT_ASSERT(self__storage_saved.value().is_alias_of(result.storage())); <<<<<<<<<<<<<<<<<<<<<<<<
AT_ASSERT(result.use_count() == 1); <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
#endif
if (grad_fn) {
set_history(flatten_tensor_args( result ), grad_fn);
}
TORCH_CHECK_NOT_IMPLEMENTED(!(isFwGradDefined(self)), "Trying to use forward AD with as_strided that does not support it.");
return result;
}
```
Example of non-view:
- Check that output's StorageImpl has use_count of 1.
- Check that output's TensorImpl has use_count of 1.
```cpp
at::Tensor asin(c10::DispatchKeySet ks, const at::Tensor & self) {
auto& self_ = unpack(self, "self", 0);
auto _any_requires_grad = compute_requires_grad( self );
(void)_any_requires_grad;
std::shared_ptr<AsinBackward> grad_fn;
if (_any_requires_grad) {
grad_fn = std::shared_ptr<AsinBackward>(new AsinBackward(), deleteNode);
grad_fn->set_next_edges(collect_next_edges( self ));
grad_fn->self_ = SavedVariable(self, false);
}
#ifndef NDEBUG
c10::optional<Storage> self__storage_saved =
self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
c10::intrusive_ptr<TensorImpl> self__impl_saved;
if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
#endif
auto _tmp = ([&]() {
at::AutoDispatchBelowADInplaceOrView guard;
return at::redispatch::asin(ks & c10::after_autograd_keyset, self_);
})();
auto result = std::move(_tmp);
#ifndef NDEBUG
if (self__storage_saved.has_value())
AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
if (result.has_storage()) AT_ASSERT(result.storage().use_count() == 1); <<<<<<<<<<<<<<<<<<<<<<<<<<
AT_ASSERT(result.use_count() == 1); <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
#endif
if (grad_fn) {
set_history(flatten_tensor_args( result ), grad_fn);
}
if (isFwGradDefined(self)) {
auto self_t_raw = toNonOptFwGrad(self);
auto self_t = self_t_raw.defined() ? self_t_raw : at::zeros_like(toNonOptTensor(self));
auto self_p = toNonOptPrimal(self);
auto result_new_fw_grad = (self_t.conj() * (-self_p * self_p + 1).rsqrt().conj()).conj();
if (result_new_fw_grad.defined()) {
// The hardcoded 0 here will need to be updated once we support multiple levels.
result._set_fw_grad(result_new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
}
}
return result;
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60286
Reviewed By: jbschlosser
Differential Revision: D29402253
Pulled By: soulitzer
fbshipit-source-id: b90f34c455b8767f95a52c329db351dbbb495397