Add torch.empty_permuted (#95069)
torch.empty_permuted is a generalized version of torch.empty(memory_format=...), where you can pass an arbitrary physical layout as a tuple of dims to allow you to setup dense, non-overlapping tensors with non-standard memory format. Check the docblock for a full description of semantics.
The initial motivation for this PR is with guard-less unbacked SymInts. Traditionally, the way we allocate dense tensors with arbitrary layout is with `empty_strided`. However, `empty_strided` does not know that the given strides are actually contiguous, and must test this manually to find out if it is the case. With `empty_permuted`, this is known statically to be the case and helps us skip some 0/1 guards.
However, I also think torch.empty_permuted is a useful API in its own right. It is technically possible to simulate this with an empty and a permute; however, there are some downsides:
* The manual incant is tricky to work out. To allocate an NHWC tensor, the invocation is `torch.empty(N, H, W, C).permute(0, 3, 1, 2)`; the permute call has to take NHWC to NCHW, and is the *inverse* of the permutation people are typically thinking of when they talk about NHWC (0, 2, 3, 1). Instead, torch.empty_permuted lets you say `torch.empty_permuted((N, C, H, W), (0, 2, 3, 1))`, letting you provide the intuitive permutation. It can be literally be read off as NHWC if you assign N=0, C=1, H=2, W=3.
* An empty(requires_grad=True).permute() is no longer a leaf tensor. You can force it to be a leaf with a detach(), but it is more straightforward and less error prone to allow directly allocating a tensor with the correct permutation.
It is also technically possible to simulate this with empty_strided. However, this requires the user to manually compute the contiguous output strides and is bad from a reduction of guards perspective. For what it's worth, this is one of the more common uses of as_strided in the wild, and it would be nice to get rid of it.
A nice enhancement of this feature would be to accept `physical_layout` anywhere `memory_format` is accepted. However, this would be a pretty involved change, so I'm doing the easy thing instead.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95069
Approved by: https://github.com/malfet, https://github.com/ngimel, https://github.com/albanD, https://github.com/dagitses