Add batching rule for torch.clone(tensor, torch.contiguous_format) (#47365)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47365
I wanted to avoid defining vmap behavior over contiguous_format for as
long as possible. This is potentially ambiguous, consider the following:
```
>>> x = torch.randn(3, B0, 5)
>>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1,
out_dims=1)(x)
>>> y[:,0].is_contiguous() # ??
```
There are two possible ways to interpret this operation (if we choose to
allow it to succeed):
1. Each per-sample becomes contiguous, so y[:,0] is contiguous.
2. The output of vmap is contiguous (so y is contiguous, but y[:,0] is
not)
(1) makes more sense because vmap operates on a per-sample level.
This makes sense when combined with the vmap fallback:
- there are places in the codebase where we perform .contiguous() and
then pass the result to an operator `op` that only accepts contiguous
inputs.
- If we vmap over such code and don't have a batching rule implemented for
`op`, then we want the per-samples to be contiguous so that
when `op` goes through the vmap fallback, it receives contiguous
per-samples.
(1) is the approach we've selected for this PR.
Motivation
----------
To vmap over CopySlices, we have to vmap over a clone(contiguous_format)
call:
https://github.com/pytorch/pytorch/blob/e4bc785dd57b15ae091eb8e8ca71a604da9b3fb2/torch/csrc/autograd/functions/tensor.cpp#L93
Alternatives
------------
- Implementing (2) is difficult in the current design because vmap is
allowed to move batch dimensions to the front of the tensor. We would
need some global information about the in_dims and out_dims passed to
vmap.
- We could also error out if someone calls clone(contiguous_format) and
the batch dims are not at the front. This would resolve the ambiguity at
the cost of limiting what vmap can do.
Future Work
-----------
- Add to a "vmap gotchas" page the behavior of contiguous_format.
- Implement is_contiguous, Tensor.contiguous() with the same semantics.
Those currently error out.
Test Plan
---------
- new tests
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D24741683
Pulled By: zou3519
fbshipit-source-id: 3ef5ded1b646855f41d39dcefe81129176de8a70