ensure version_counter gets incremented for non-differentiable outputs (#20612)
Summary:
Issue: https://github.com/pytorch/pytorch/issues/14571
To reproduce the issue, I:
1) added these lines to derivatives.yaml:
```yaml
- name: add_(Tensor self, Scalar other, Scalar alpha)
output_differentiability: [False, False, False]
- name: add_(Tensor self, Tensor other, Scalar alpha)
output_differentiability: [False, False, False]
```
2) Ran this code:
```python
import torch

scalar = torch.tensor(5)
var1 = torch.randn(4, 2, requires_grad=True)
var2 = var1.detach().requires_grad_()
output1 = var1 * scalar
output2 = var2 * scalar
output1.sum().backward()
scalar.add_(5, 1)  # in-place update: scalar becomes 10, and should bump its version counter
output2.sum().backward()
print(var1.grad)
print(var2.grad)
```
Observed that var2.grad was silently computed with the modified scalar value (10 instead of 5):
```
tensor([[5., 5.],
[5., 5.],
[5., 5.],
[5., 5.]])
tensor([[10., 10.],
[10., 10.],
[10., 10.],
[10., 10.]])
```
After making this change, re-running the above code produces the expected error:
```
Traceback (most recent call last):
File "test.py", line 18, in <module>
output2.sum().backward()
File "/home/bvaughan/anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 107, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/bvaughan/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.LongTensor []] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
```
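The error above comes from the per-tensor in-place version counter, which saved variables are checked against at backward time. It can be inspected through the private `_version` attribute (an implementation detail, shown here only for illustration):

```python
import torch

t = torch.tensor(5.0)
v_before = t._version  # a freshly created tensor starts at version 0
t.add_(1.0)            # each in-place op increments the counter
v_after = t._version
```

This PR ensures the increment also happens when the in-place op's outputs are declared non-differentiable via `output_differentiability`.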
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20612
Differential Revision: D15661958
Pulled By: nairbv
fbshipit-source-id: af3373135a1a589a635b49e0ff62622a210258e6