Batching rule for Tensor.new_empty_strided (#47226)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47226
The batching rule is a little weird because it's not immediately obvious
what the strides of the result should be. If
`tensor.new_empty_strided(size, stride)` is called inside vmap and
`tensor` is being vmapped over, the result is a physical tensor with:
- size `[batch_shape] + size`
- strides `[S0, S1, ..., Sn] + stride`, where S0...Sn form a contiguous
subspace over the batch dimensions and Sn equals the storage size of
`torch.empty_strided(size, stride)` (see the sketch below).
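
As a rough illustration of the resulting layout, here is a hedged Python
sketch (not part of this PR; the exact strides observed can depend on the
vmap implementation in a given release):

```python
import torch


def f(x):
    # Each per-example call asks for a [2, 3] tensor with strides [3, 1].
    return x.new_empty_strided([2, 3], [3, 1])


x = torch.randn(3, 5)    # batch of 3 examples being vmapped over
out = torch.vmap(f)(x)   # values are uninitialized; only the metadata matters
assert out.shape == (3, 2, 3)
# Under the rule above, the physical strides are [6, 3, 1]: the batch stride 6
# equals the storage size of torch.empty_strided([2, 3], [3, 1]).
print(out.stride())
```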
I refactored some of the logic that computes the storage size for
`torch.empty_strided(size, stride)` into a helper function
`native::storage_size_for` and used it in the batching rule.
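
For reference, the storage size in question can be computed as in the
following minimal Python sketch (the actual helper is the C++ function
`native::storage_size_for`; `storage_size_for_sketch` is a hypothetical
name used only for illustration):

```python
def storage_size_for_sketch(size, stride):
    # Storage size needed by torch.empty_strided(size, stride):
    # 0 if any dimension is empty, otherwise 1 plus the largest element
    # offset, which for non-negative strides is sum((size[i] - 1) * stride[i]).
    if any(s == 0 for s in size):
        return 0
    return 1 + sum((s - 1) * st for s, st in zip(size, stride))


# e.g. storage_size_for_sketch([2, 3], [3, 1]) == 6, matching the example
# above, and storage_size_for_sketch([2, 3], [10, 1]) == 13 for
# non-contiguous strides.
```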
Test Plan: - New tests in test/test_vmap.py
Reviewed By: ejguan
Differential Revision: D24741690
Pulled By: zou3519
fbshipit-source-id: f09b5578e923470d456d50348d86687a03b598d2