Fix sum batching rule, add simple clone batching rule (#47189)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47189
PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
and instead returns a new copy of the original scalar_tensor. If we
end up vmapping over per-example scalar tensors, e.g.,
```
>>> x = torch.randn(B0) # the per-examples are all scalars
>>> vmap(partial(torch.sum, dim=0))(x)
```
then we should replicate the behavior of sum(scalar_tensor, dim=0) by
returning a clone of the input tensor.
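The behavior can be checked directly; a minimal sketch, assuming a recent PyTorch where `torch.vmap` is public (at the time of this PR, vmap lived in an internal prototype namespace):

```python
import torch
from functools import partial

# Special case: sum over dim=0 of a 0-dim tensor succeeds and
# returns a copy rather than raising a dimension error.
scalar = torch.tensor(3.0)
out = torch.sum(scalar, dim=0)
assert out.dim() == 0 and out.item() == 3.0

# Under vmap, each per-example slice of x is a scalar, so the
# batched sum should likewise return a clone of the input.
x = torch.randn(4)  # B0 = 4 per-example scalars
y = torch.vmap(partial(torch.sum, dim=0))(x)
assert torch.allclose(y, x)
```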
This PR also adds a batching rule for clone(Tensor, MemoryFormat). The
batching rule:
- if the MemoryFormat is torch.preserve_format (the default): unwraps the
BatchedTensor, calls clone(), and rewraps the result as a
BatchedTensor.
- errors out with an NYI message for all other memory formats, including
torch.contiguous_format. There are some weird semantics for memory
layouts under vmap that I still need to figure out; those are noted in
the comments for `clone_batching_rule`
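The supported path can be exercised from Python; a minimal sketch, again assuming a PyTorch build with public `torch.vmap` (the `contiguous_format` branch is not shown, since it raises NYI under this PR):

```python
import torch

x = torch.randn(3, 2)

# preserve_format (the default) is supported: each per-example
# slice is cloned and the result is rewrapped as a batched tensor.
y = torch.vmap(lambda t: t.clone(memory_format=torch.preserve_format))(x)
assert torch.allclose(y, x)

# The clone is a true copy: mutating it leaves the input untouched.
y.add_(1.0)
assert not torch.allclose(y, x)
```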
Test Plan: - new tests
Reviewed By: ejguan
Differential Revision: D24741689
Pulled By: zou3519
fbshipit-source-id: e640344b4e4aa8c0d2dbacc5c49901f4c33c6613