b80da898 - Batching rule for Tensor.new_empty_strided (#47226)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47226

The batching rule is a little weird because it's not immediately obvious what the strides of the result should be. If `tensor.new_empty_strided(size, stride)` is called inside vmap and `tensor` is being vmapped over, the result is a physical tensor with:
- size `[batch_shape] + size`
- strides `[S0, S1, ..., Sn] + stride`

such that S0...Sn form a contiguous subspace and Sn equals the storage size of `torch.empty_strided(size, stride)`.

I refactored some of the logic that computes the storage size for `torch.empty_strided(size, stride)` into a helper function, `native::storage_size_for`, and use it in the batching rule.

Test Plan:
- New tests in test/test_vmap.py

Reviewed By: ejguan

Differential Revision: D24741690

Pulled By: zou3519

fbshipit-source-id: f09b5578e923470d456d50348d86687a03b598d2
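As a rough illustration of the stride computation described above, here is a minimal Python sketch, not the actual C++ batching rule. The helper names `storage_size_for` and `batched_new_empty_strided_shape` are hypothetical, written for this example to approximate the logic attributed to `native::storage_size_for` and the batching rule.

```python
def storage_size_for(sizes, strides):
    # Approximation of the storage-size computation described above: the
    # number of elements that a tensor with these sizes/strides can reach.
    # A tensor with any zero-sized dimension touches no storage.
    if any(s == 0 for s in sizes):
        return 0
    return 1 + sum((size - 1) * stride for size, stride in zip(sizes, strides))


def batched_new_empty_strided_shape(batch_shape, size, stride):
    # Sketch of the result metadata for the batching rule: the batch dims
    # form a contiguous subspace whose innermost stride Sn equals the
    # per-example storage size of torch.empty_strided(size, stride).
    per_example_storage = storage_size_for(size, stride)
    batch_strides = []
    s = per_example_storage
    for b in reversed(batch_shape):
        batch_strides.append(s)
        s *= b
    batch_strides.reverse()
    return list(batch_shape) + list(size), batch_strides + list(stride)


# Example: vmap over a batch of 4, where each example calls
# new_empty_strided((2, 3), (1, 2)). storage_size_for((2, 3), (1, 2)) is
# 1 + 1*1 + 2*2 = 6, so the physical result has size [4, 2, 3] and
# strides [6, 1, 2].
print(batched_new_empty_strided_shape((4,), (2, 3), (1, 2)))
```

Under these assumptions, each example in the batch occupies its own contiguous chunk of storage large enough to hold `torch.empty_strided(size, stride)`, which is what makes the batch strides well defined even for non-contiguous per-example strides.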