[Vulkan] Implement permute operator (#68274)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68274
Implemented `permute` operator on the Vulkan backend:
* Supports only tensors with at most 4 dimensions.
* Builds up the shader operations from the output texture's point of view, so the result does not depend on the nondeterministic execution order of GPU shader operations across texels. See [incoherent memory access](https://www.khronos.org/opengl/wiki/Memory_Model#Incoherent_memory_access).
* Input tensors are generalized to 4D to simplify input/output texture handling. For example, {2, 3} is treated as {1, 1, 2, 3} internally.
* 1D to 4D inputs with all possible permutations are used for test cases.
* Reference CPU implementation of the `permute` operator: [TensorShape.cpp](https://github.com/pytorch/pytorch/blob/cbf596bf8ee48f4ed895964fe4faf75a851b49c4/aten/src/ATen/native/TensorShape.cpp#L936)
* When shuffling dims, the new depth of the output texture needs to be determined by `ceil(batch*channel/4)`. This logic needs to be handled in a separate change.
* The depth of a texture cannot exceed a device-dependent limit. The limit is 2048 on most Android devices and never more than 16,384 (see [Value distribution for maxImageDimension3D on Android](https://vulkan.gpuinfo.org/displaydevicelimit.php?name=maxImageDimension3D&platform=android)). For example, it is 2048 on macOS and on Google Pixel 5.
* Due to this limitation, the `permute` op needs to throw an exception if the depth of the output texture is greater than or equal to `VkImageFormatProperties.maxExtent.depth`.
* Otherwise, the following error will occur: `-[MTLTextureDescriptorInternal validateWithDevice:]:1325: failed assertion "Texture Descriptor Validation MTLTextureDescriptor has depth (10664) greater than the maximum allowed size of 2048."`
* Vulkan `permute` operator tensor conversion:
{F679505029}
{F679505223}
* Vulkan `permute` operator shader equation:
{F679504799}
* Error/edge cases:
```
import torch

# Each failing call below raises, so run them one at a time.
X = torch.randint(0, 23, (2, 3, 2, 2))
# Dim 2 is listed twice.
O = torch.permute(X, (2, 2, 1, 0))
# RuntimeError: repeated dim in permute
# Too few dims (3) for a 4D tensor.
O = torch.permute(X, (2, 1, 0))
# RuntimeError: number of dims don't match in permute
# Too many dims (5) for a 4D tensor.
O = torch.permute(X, (4, 3, 2, 1, 0))
# RuntimeError: number of dims don't match in permute
# -1 wraps to 3, which is already listed.
O = torch.permute(X, (3, 2, -1, 0))
# RuntimeError: repeated dim in permute
data2 = [0, 1, 2]
X2 = torch.tensor(data2)
# (0) is the int 0, not a tuple; a one-element tuple is written (0,).
O2 = torch.permute(X2, (0))
# TypeError: permute(): argument 'dims' (position 2) must be tuple of ints, not int
O = torch.permute(X, (0, 1, 2, 3))
# No error: the identity permutation leaves the dims unchanged.
```
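The checks behind these error cases can be sketched in plain Python. This is only an illustration under assumptions, not the ATen or Vulkan code: `normalize_permute_dims` and `permuted_sizes` are hypothetical helpers, the two `RuntimeError` messages mirror the ones quoted above, and the out-of-range check and its message are made up for completeness.

```python
def normalize_permute_dims(ndim, dims):
    """Validate a permutation and wrap negative entries (hypothetical helper)."""
    if len(dims) != ndim:
        raise RuntimeError("number of dims don't match in permute")
    seen = set()
    normalized = []
    for d in dims:
        if not -ndim <= d < ndim:
            # Assumed check; the message is illustrative only.
            raise IndexError(f"dimension {d} out of range for a {ndim}D tensor")
        d = d % ndim  # wrap negatives, e.g. -1 -> ndim - 1
        if d in seen:
            raise RuntimeError("repeated dim in permute")
        seen.add(d)
        normalized.append(d)
    return normalized


def permuted_sizes(sizes, dims):
    """Output sizes after permuting `sizes` by `dims`."""
    return [sizes[d] for d in normalize_permute_dims(len(sizes), dims)]
```

For instance, `permuted_sizes([2, 3, 2, 2], (3, 2, 1, 0))` yields `[2, 2, 3, 2]`, while `(3, 2, -1, 0)` is rejected as a repeated dim because `-1` wraps to `3`.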
* Shader debug traces for a 4D tensor of size {2,3,2,2} permuted by {3,2,1,0}:
```
output tensor:
(1,1,.,.) =
0.4395 0.5652
0.1309 0.9768
0.0490 0.1127
(2,1,.,.) =
0.7058 0.2238
0.6542 0.4064
0.4813 0.0500
(1,2,.,.) =
0.1716 0.4951
0.2225 0.3255
0.0758 0.7150
(2,2,.,.) =
0.3762 0.0228
0.6367 0.4411
0.7682 0.7599
[ CPUFloatType{2,2,3,2} ]
shader debug traces:
src_index:0, b c h w: 0 0 0 0, posIn: (0 0 0) i:0 -> b c h w: 0 0 0 0, dst_index: 0, posOut: (0 0 0) j:0 -> inval[0.439453] outval[0.439453] -> inval[0.439453 0.130859 0.049011 0.564941] outval[0.439453 0.000000 0.000000 0.000000]
src_index:3, b c h w: 1 0 0 0, posIn: (0 0 0) i:3 -> b c h w: 0 0 0 1, dst_index: 0, posOut: (1 0 0) j:0 -> inval[0.564941] outval[0.564941] -> inval[0.439453 0.130859 0.049011 0.564941] outval[0.564941 0.000000 0.000000 0.000000]
src_index:1, b c h w: 0 1 0 0, posIn: (0 0 0) i:1 -> b c h w: 0 0 1 0, dst_index: 0, posOut: (0 1 0) j:0 -> inval[0.130859] outval[0.130859] -> inval[0.439453 0.130859 0.049011 0.564941] outval[0.130859 0.000000 0.000000 0.000000]
src_index:4, b c h w: 1 1 0 0, posIn: (0 0 1) i:0 -> b c h w: 0 0 1 1, dst_index: 0, posOut: (1 1 0) j:0 -> inval[0.976562] outval[0.976562] -> inval[0.976562 0.112671 -65504.000000 -65504.000000] outval[0.976562 0.000000 0.000000 0.000000]
src_index:2, b c h w: 0 2 0 0, posIn: (0 0 0) i:2 -> b c h w: 0 0 2 0, dst_index: 0, posOut: (0 2 0) j:0 -> inval[0.049011] outval[0.049011] -> inval[0.439453 0.130859 0.049011 0.564941] outval[0.049011 0.000000 0.000000 0.000000]
src_index:5, b c h w: 1 2 0 0, posIn: (0 0 1) i:1 -> b c h w: 0 0 2 1, dst_index: 0, posOut: (1 2 0) j:0 -> inval[0.112671] outval[0.112671] -> inval[0.976562 0.112671 -65504.000000 -65504.000000] outval[0.112671 0.000000 0.000000 0.000000]
src_index:0, b c h w: 0 0 1 0, posIn: (0 1 0) i:0 -> b c h w: 0 1 0 0, dst_index: 1, posOut: (0 0 0) j:1 -> inval[0.171509] outval[0.171509] -> inval[0.171509 0.222412 0.075745 0.494873] outval[0.439453 0.171509 0.000000 0.000000]
src_index:3, b c h w: 1 0 1 0, posIn: (0 1 0) i:3 -> b c h w: 0 1 0 1, dst_index: 1, posOut: (1 0 0) j:1 -> inval[0.494873] outval[0.494873] -> inval[0.171509 0.222412 0.075745 0.494873] outval[0.564941 0.494873 0.000000 0.000000]
src_index:1, b c h w: 0 1 1 0, posIn: (0 1 0) i:1 -> b c h w: 0 1 1 0, dst_index: 1, posOut: (0 1 0) j:1 -> inval[0.222412] outval[0.222412] -> inval[0.171509 0.222412 0.075745 0.494873] outval[0.130859 0.222412 0.000000 0.000000]
src_index:4, b c h w: 1 1 1 0, posIn: (0 1 1) i:0 -> b c h w: 0 1 1 1, dst_index: 1, posOut: (1 1 0) j:1 -> inval[0.325439] outval[0.325439] -> inval[0.325439 0.714844 -65504.000000 -65504.000000] outval[0.976562 0.325439 0.000000 0.000000]
src_index:2, b c h w: 0 2 1 0, posIn: (0 1 0) i:2 -> b c h w: 0 1 2 0, dst_index: 1, posOut: (0 2 0) j:1 -> inval[0.075745] outval[0.075745] -> inval[0.171509 0.222412 0.075745 0.494873] outval[0.049011 0.075745 0.000000 0.000000]
src_index:5, b c h w: 1 2 1 0, posIn: (0 1 1) i:1 -> b c h w: 0 1 2 1, dst_index: 1, posOut: (1 2 0) j:1 -> inval[0.714844] outval[0.714844] -> inval[0.325439 0.714844 -65504.000000 -65504.000000] outval[0.112671 0.714844 0.000000 0.000000]
src_index:0, b c h w: 0 0 0 1, posIn: (1 0 0) i:0 -> b c h w: 1 0 0 0, dst_index: 2, posOut: (0 0 0) j:2 -> inval[0.705566] outval[0.705566] -> inval[0.705566 0.653809 0.481201 0.223755] outval[0.439453 0.171509 0.705566 0.000000]
src_index:3, b c h w: 1 0 0 1, posIn: (1 0 0) i:3 -> b c h w: 1 0 0 1, dst_index: 2, posOut: (1 0 0) j:2 -> inval[0.223755] outval[0.223755] -> inval[0.705566 0.653809 0.481201 0.223755] outval[0.564941 0.494873 0.223755 0.000000]
src_index:1, b c h w: 0 1 0 1, posIn: (1 0 0) i:1 -> b c h w: 1 0 1 0, dst_index: 2, posOut: (0 1 0) j:2 -> inval[0.653809] outval[0.653809] -> inval[0.705566 0.653809 0.481201 0.223755] outval[0.130859 0.222412 0.653809 0.000000]
src_index:4, b c h w: 1 1 0 1, posIn: (1 0 1) i:0 -> b c h w: 1 0 1 1, dst_index: 2, posOut: (1 1 0) j:2 -> inval[0.406250] outval[0.406250] -> inval[0.406250 0.049957 -65504.000000 -65504.000000] outval[0.976562 0.325439 0.406250 0.000000]
src_index:2, b c h w: 0 2 0 1, posIn: (1 0 0) i:2 -> b c h w: 1 0 2 0, dst_index: 2, posOut: (0 2 0) j:2 -> inval[0.481201] outval[0.481201] -> inval[0.705566 0.653809 0.481201 0.223755] outval[0.049011 0.075745 0.481201 0.000000]
src_index:5, b c h w: 1 2 0 1, posIn: (1 0 1) i:1 -> b c h w: 1 0 2 1, dst_index: 2, posOut: (1 2 0) j:2 -> inval[0.049957] outval[0.049957] -> inval[0.406250 0.049957 -65504.000000 -65504.000000] outval[0.112671 0.714844 0.049957 0.000000]
src_index:0, b c h w: 0 0 1 1, posIn: (1 1 0) i:0 -> b c h w: 1 1 0 0, dst_index: 3, posOut: (0 0 0) j:3 -> inval[0.376221] outval[0.376221] -> inval[0.376221 0.636719 0.768066 0.022751] outval[0.439453 0.171509 0.705566 0.376221] outval_after[0.439453 0.171509 0.705566 0.376221]
src_index:3, b c h w: 1 0 1 1, posIn: (1 1 0) i:3 -> b c h w: 1 1 0 1, dst_index: 3, posOut: (1 0 0) j:3 -> inval[0.022751] outval[0.022751] -> inval[0.376221 0.636719 0.768066 0.022751] outval[0.564941 0.494873 0.223755 0.022751] outval_after[0.564941 0.494873 0.223755 0.022751]
src_index:1, b c h w: 0 1 1 1, posIn: (1 1 0) i:1 -> b c h w: 1 1 1 0, dst_index: 3, posOut: (0 1 0) j:3 -> inval[0.636719] outval[0.636719] -> inval[0.376221 0.636719 0.768066 0.022751] outval[0.130859 0.222412 0.653809 0.636719] outval_after[0.130859 0.222412 0.653809 0.636719]
src_index:4, b c h w: 1 1 1 1, posIn: (1 1 1) i:0 -> b c h w: 1 1 1 1, dst_index: 3, posOut: (1 1 0) j:3 -> inval[0.440918] outval[0.440918] -> inval[0.440918 0.759766 -65504.000000 -65504.000000] outval[0.976562 0.325439 0.406250 0.440918] outval_after[0.976562 0.325439 0.406250 0.440918]
src_index:2, b c h w: 0 2 1 1, posIn: (1 1 0) i:2 -> b c h w: 1 1 2 0, dst_index: 3, posOut: (0 2 0) j:3 -> inval[0.768066] outval[0.768066] -> inval[0.376221 0.636719 0.768066 0.022751] outval[0.049011 0.075745 0.481201 0.768066] outval_after[0.049011 0.075745 0.481201 0.768066]
src_index:5, b c h w: 1 2 1 1, posIn: (1 1 1) i:1 -> b c h w: 1 1 2 1, dst_index: 3, posOut: (1 2 0) j:3 -> inval[0.759766] outval[0.759766] -> inval[0.440918 0.759766 -65504.000000 -65504.000000] outval[0.112671 0.714844 0.049957 0.759766] outval_after[0.112671 0.714844 0.049957 0.759766]
```
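The index arithmetic visible in these traces can be reproduced with a small pure-Python model. It is a sketch under assumptions, not the GLSL shader from this change: all function names are hypothetical, and the texture layout (x = width, y = height, four consecutive batch*channel planes packed into one texel along depth) is inferred from the `posIn`/`posOut` and `i`/`j` values above.

```python
from itertools import product


def pad_to_4d(sizes):
    """Left-pad sizes with 1s, e.g. [2, 3] -> [1, 1, 2, 3]."""
    return [1] * (4 - len(sizes)) + list(sizes)


def texture_depth(sizes4d):
    """ceil(batch * channel / 4), computed with integer arithmetic."""
    batch, channel = sizes4d[0], sizes4d[1]
    return (batch * channel + 3) // 4


def texel_of(index, sizes4d):
    """Map (b, c, h, w) to ((x, y, z), lane) as reported in the traces."""
    b, c, h, w = index
    plane = b * sizes4d[1] + c  # src_index / dst_index in the traces
    return (w, h, plane // 4), plane % 4


def permute_trace(in_sizes, dims):
    """For every output (texel, lane), find the input (texel, lane) to read."""
    out_sizes = [in_sizes[d] for d in dims]
    mapping = {}
    for out_idx in product(*(range(s) for s in out_sizes)):
        in_idx = [0] * 4
        for out_axis, in_axis in enumerate(dims):
            in_idx[in_axis] = out_idx[out_axis]
        mapping[texel_of(out_idx, out_sizes)] = texel_of(tuple(in_idx), in_sizes)
    return out_sizes, mapping
```

With `in_sizes = [2, 3, 2, 2]` and `dims = (3, 2, 1, 0)`, the output texel `(1, 0, 0)` lane `0` reads input texel `(0, 0, 0)` lane `3`, which matches the second trace line. The input texture depth is `texture_depth([2, 3, 2, 2]) == 2` and the output depth is `texture_depth([2, 2, 3, 2]) == 1`, consistent with every `posOut` having z = 0.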
Test Plan:
Build & test on Android:
```
cd ~/fbsource
buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:pt_vulkan_api_test_binAndroid\#android-arm64 --show-output
adb push buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_api_test
adb shell "/data/local/tmp/vulkan_api_test"
```
Build & test on macOS:
```
cd ~/fbsource
buck build //xplat/caffe2:pt_vulkan_api_test_binAppleMac
./buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAppleMac\#macosx-x86_64
```
Test result on Android (Google Pixel 5):
```
[ RUN ] VulkanAPITest.permute_2d_success
[ OK ] VulkanAPITest.permute_2d_success (26 ms)
[ RUN ] VulkanAPITest.permute_3d_success
[ OK ] VulkanAPITest.permute_3d_success (6 ms)
[ RUN ] VulkanAPITest.permute_4d_success
[ OK ] VulkanAPITest.permute_4d_success (10 ms)
[ RUN ] VulkanAPITest.permute_4dmclaren_success
[ OK ] VulkanAPITest.permute_4dmclaren_success (1 ms)
[ RUN ] VulkanAPITest.permute_4dbig_success
[ OK ] VulkanAPITest.permute_4dbig_success (234 ms)
[ RUN ] VulkanAPITest.permute_negativedims_success
[ OK ] VulkanAPITest.permute_negativedims_success (0 ms)
[ RUN ] VulkanAPITest.permute_1d_nochange
[ OK ] VulkanAPITest.permute_1d_nochange (0 ms)
[ RUN ] VulkanAPITest.permute_sameDims_nochange
[ OK ] VulkanAPITest.permute_sameDims_nochange (1 ms)
[ RUN ] VulkanAPITest.permute_invalidinputs_exceptions
[ OK ] VulkanAPITest.permute_invalidinputs_exceptions (1 ms)
```
Test result on macOS:
```
[ RUN ] VulkanAPITest.permute_2d_success
[ OK ] VulkanAPITest.permute_2d_success (154 ms)
[ RUN ] VulkanAPITest.permute_3d_success
[ OK ] VulkanAPITest.permute_3d_success (13 ms)
[ RUN ] VulkanAPITest.permute_4d_success
[ OK ] VulkanAPITest.permute_4d_success (33 ms)
[ RUN ] VulkanAPITest.permute_4dmclaren_success
[ OK ] VulkanAPITest.permute_4dmclaren_success (2 ms)
[ RUN ] VulkanAPITest.permute_4dbig_success
[ OK ] VulkanAPITest.permute_4dbig_success (251 ms)
[ RUN ] VulkanAPITest.permute_negativedims_success
[ OK ] VulkanAPITest.permute_negativedims_success (2 ms)
[ RUN ] VulkanAPITest.permute_1d_nochange
[ OK ] VulkanAPITest.permute_1d_nochange (1 ms)
[ RUN ] VulkanAPITest.permute_sameDims_nochange
[ OK ] VulkanAPITest.permute_sameDims_nochange (0 ms)
[ RUN ] VulkanAPITest.permute_invalidinputs_exceptions
[ OK ] VulkanAPITest.permute_invalidinputs_exceptions (2 ms)
```
Reviewed By: SS-JIA
Differential Revision: D32292554
fbshipit-source-id: dbeaee6ff98633022cf34d6da90662d81eac6b0e