as_strided batching rule (#47364)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47364
This PR adds a batching rule for as_strided. `as_strided` is a really weird
operation and I hope that users don't use it very much.
Motivation
----------
The motivation for adding a batching rule for as_strided is for
batched gradient computation.
AsStridedBackward appears in PyTorch when handling view+in-place
operations and calls `as_strided`. AsStridedBackward calls as_strided on
a fresh tensor with storage_offset equal to 0. We would like to be able
to vmap through the backward graph of view+in-place operations
for batched gradient computation, especially because internally we have
a number of functions that are implemented as a view+in-place.
Alternatives
------------
If we think that as_strided is too crazy to have a batching rule, we
could either:
- have a flag that controls the autograd view+in-place
behavior
- require that the input tensor's storage offset be equal to 0,
  which would make the behavior easier to reason about.
I think the batching rule makes sense, so I didn't pursue the
alternatives.
The batching rule
-----------------
```
y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
```
The result of the above should be "equivalent" to:
- Assume that each xs[i] has storage offset equal to xs.storage_offset()
  (call that S).
- Calling as_strided with (sizes, strides, offset + xs[i].storage_offset() - S) on each xs[i].
More concretely,
this returns a view on `xs`, such that each y[i] has:
- sizes: `sizes`
- strides: `strides`
- storage_offset: offset + i * xs.stride(batch_dim)
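The offset arithmetic above can be sketched in plain Python. This is a toy model of strided storage, not PyTorch's implementation: `as_strided_1d` is a hypothetical helper for 1-D views, and we assume a contiguous batched tensor `xs` of shape (B, N) backed by a flat list with storage offset 0 (so S == 0 and xs.stride(batch_dim) == N).

```python
B, N = 3, 5
storage = list(range(B * N))     # flat storage backing the batched "xs"
batch_stride = N                 # xs.stride(batch_dim) for a contiguous xs

def as_strided_1d(storage, size, stride, storage_offset):
    """Toy per-sample as_strided for a 1-D view: gather `size` elements
    from `storage`, starting at `storage_offset`, stepping by `stride`."""
    return [storage[storage_offset + j * stride] for j in range(size)]

sizes, strides, offset = 2, 2, 1  # the arguments passed to as_strided

# Batching rule: y[i] views the shared storage at offset + i * batch_stride.
y = [as_strided_1d(storage, sizes, strides, offset + i * batch_stride)
     for i in range(B)]

# Per-sample reference: slice out each sample's storage and call
# as_strided with the unadjusted offset (S == 0 here, so the
# "offset + xs[i].storage_offset() - S" correction is a no-op per slice).
ref = [as_strided_1d(storage[i * N:(i + 1) * N], sizes, strides, offset)
       for i in range(B)]

assert y == ref  # the batched rule matches the per-sample computation
```

The equality holds because each sample's slice of the shared storage begins exactly at i * batch_stride.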
Why the behavior can be weird
-----------------------------
The behavior of the batching rule may be different from actually running
as_strided in a for-loop, because `as_strided` takes in `offset` as an
"absolute offset" into the underlying storage. As an example, consider
```
>>> x = torch.tensor([0., 1., 2., 3., 4.])
>>> z = [x[i].as_strided([1], [1], 0) for i in range(5)]
```
Each z[i] is actually the same view on x (z[i] == torch.tensor([0.]))!
However, we consider the above list comprehension to be a user error:
a user who wanted to use as_strided in a per-sample way should have
written the following instead:
```
>>> z = [x[i].as_strided([1], [1], 0 + x[i].storage_offset()) for i in range(5)]
```
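Putting the two list comprehensions side by side makes the difference concrete. A small sketch, assuming a recent `torch` build (for a contiguous 1-D tensor, `x[i].storage_offset()` is `i`):

```python
import torch

x = torch.tensor([0., 1., 2., 3., 4.])

# Absolute offset 0: every z[i] aliases the start of x's storage,
# so each z[i] is the same view, torch.tensor([0.]).
z_wrong = [x[i].as_strided([1], [1], 0) for i in range(5)]

# Adding each sample's storage_offset recovers genuine per-sample views.
z_right = [x[i].as_strided([1], [1], 0 + x[i].storage_offset())
           for i in range(5)]

assert all(z.item() == 0. for z in z_wrong)
assert [z.item() for z in z_right] == [0., 1., 2., 3., 4.]
```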
Test Plan
---------
- Added some tests that compare vmap+as_strided to vmap+(the equivalent operator)
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D24741685
Pulled By: zou3519
fbshipit-source-id: c1429caff43bfa33661a80bffc0daf2c0eea5564