Permutation extended (#1614)
Extends permutation support in integration (see #1601 for more details). This update allows us to better support permutation propagation on tensors, specifically for binary ops with inputs of different ranks. Our goal is to avoid permuting tensors unless absolutely necessary. We try to preserve aten's permutation propagation rule, with some known limitations at this time.
The idea in this implementation is the same as in our existing code: permute input/output tensors outside of codegen. Consider a simplified binary op scenario: `output = binaryOp(input0, input1)`
1. In the simple case where `input0` and `input1` have the same rank and permutation order, the output preserves that permutation;
2. Where `input0` and `input1` have different ranks but **compatible** permutations, the tensor with the higher rank dictates the permutation of the output;
3. Where `input0` and `input1` have different ranks and **incompatible** permutations, permutation propagation fails and the output tensor is contiguous.
By **compatible** permutation, we mean that we can permute the higher-rank tensor to contiguous format and then apply a second permutation to the lower-rank tensor to match their axes. This check is implemented in `MemoryFormat::broadcastToRank(int lower_rank)`.
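The check can be sketched in pure Python. Note this is illustrative only; the helper names below are hypothetical, and the actual logic lives in the C++ `MemoryFormat::broadcastToRank`. A permutation is compatible when the logical dims missing from the lower-rank tensor are the outermost dims of the higher-rank tensor's memory layout (so broadcasting, which prepends dims, can line the axes up):

```python
# Illustrative sketch only -- not the actual C++ implementation.

def memory_order(perm):
    """Given the argument passed to tensor.permute(), return the logical
    dim indices ordered from outermost to innermost in memory.
    E.g. contiguous (b, h, w, c).permute([0, 3, 1, 2]) has logical shape
    (b, c, h, w) laid out in memory as (b, h, w, c), i.e. memory order
    [0, 2, 3, 1] in logical dim indices."""
    return sorted(range(len(perm)), key=lambda i: perm[i])

def broadcast_to_rank(perm, lower_rank):
    """Return the permutation the lower-rank tensor would need so its
    axes line up with the higher-rank tensor's memory layout, or None
    if the layouts are incompatible (hypothetical helper mirroring the
    role of MemoryFormat::broadcastToRank)."""
    diff = len(perm) - lower_rank
    order = memory_order(perm)
    # The broadcast (missing) dims must occupy the outermost positions.
    if any(d >= diff for d in order[:diff]):
        return None  # incompatible: output falls back to contiguous
    return [d - diff for d in order[diff:]]

# Rank-4 channels-last-like layout vs. a rank-3 tensor -> compatible
assert broadcast_to_rank([0, 3, 1, 2], 3) == [1, 2, 0]
# A rank-2 tensor missing the innermost memory dim -> incompatible
assert broadcast_to_rank([0, 3, 1, 2], 2) is None
```

Under this sketch, cases 2-3 below are compatible because the lower-rank tensor is only missing the batch dim, which sits outermost in memory; case 4 is incompatible because the rank-2 tensor is also missing the innermost `c` dim.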
Some concrete examples (note that we comply with eager propagation in cases 1-3, but diverge in behavior for cases 4 and 5):
1. different rank & same permutation
```
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2]) # stride (hwc, 1, wc, c)
t1 = torch.randn(h, w, c).cuda().permute([2, 0, 1]) # stride (1, wc, c)
out = scripted_add(t0, t1) # stride (hwc, 1, wc, c) preserving memory format of t0
```
2. different rank & compatible permutation
```
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2]) # stride (hwc, 1, wc, c)
t1 = torch.randn(c, h, w).cuda() # stride (hw, w, 1)
out = scripted_add(t0, t1) # stride (hwc, 1, wc, c) preserving memory format of t0
```
3. different rank & compatible permutation with broadcasting
```
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2]) # stride (hwc, 1, wc, c)
t1 = torch.randn(c).cuda().unsqueeze(-1).unsqueeze(-1) # stride (1, 1, 1)
out = scripted_add(t0, t1) # stride (hwc, 1, wc, c) preserving memory format of t0
```
4. different rank & incompatible permutation
```
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2]) # stride (hwc, 1, wc, c)
t1 = torch.randn(h, w).cuda() # stride (w, 1)
jit_out = scripted_add(t0, t1) # stride (hwc, wc, c, 1) # nvfuser outputs a contiguous tensor
eager_out = eager_add(t0, t1) # stride (hwc, 1, wc, c) # TensorIterator preserves memory format of the LHS operand
```
5. different rank & incompatible permutation
```
t0 = torch.randn(c, h, w).cuda() # stride (hw, w, 1)
t1 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2]) # stride (hwc, 1, wc, c)
jit_out = scripted_add(t0, t1) # stride (hwc, 1, wc, c) # nvfuser preserves memory format of the highest-rank tensor
eager_out = eager_add(t0, t1) # stride (hwc, hw, w, 1) # TensorIterator preserves memory format of the LHS operand
```
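For readers cross-checking the stride comments above, `permute` simply reorders the strides of the original tensor along with its dims. A small stdlib-only sketch (the `permuted_strides` helper is illustrative, not part of this PR) with concrete sizes:

```python
# Illustrative helper: strides of contiguous(shape).permute(perm).
def permuted_strides(shape, perm):
    # Row-major (contiguous) strides for the original shape.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    # permute() reorders the strides the same way it reorders the dims.
    return [strides[p] for p in perm]

b, h, w, c = 2, 8, 8, 4
# (b, h, w, c).permute([0, 3, 1, 2]) -> stride (h*w*c, 1, w*c, c)
assert permuted_strides((b, h, w, c), [0, 3, 1, 2]) == [h*w*c, 1, w*c, c]
# (h, w, c).permute([2, 0, 1]) -> stride (1, w*c, c)
assert permuted_strides((h, w, c), [2, 0, 1]) == [1, w*c, c]
```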