pytorch
6226a3cf - [Vulkan] Implement permute operator (#68274)

Commit
4 years ago
[Vulkan] Implement permute operator (#68274) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68274 Implemented `permute` operator on the Vulkan backend: * Supports only <= 4D tensors. * Builds up shader operations from the output texture point of view to avoid the nondeterministic order of GPU shader operations between texels. See [incoherent memory access](https://www.khronos.org/opengl/wiki/Memory_Model#Incoherent_memory_access) * Generalized input tensors to 4D ones to simplify input/output texture handling. For example, {2, 3} is treated as {1,1,2,3} internally. * 1D to 4D inputs with all possible permutations are used for test cases. * Reference on CPU implementation of `permute` operator: [TensorShape.cpp](https://github.com/pytorch/pytorch/blob/cbf596bf8ee48f4ed895964fe4faf75a851b49c4/aten/src/ATen/native/TensorShape.cpp#L936) * When shuffling dims, a new depth size of output texture needs to be determined by `ceil(batch*channel)/4`. This logic needs to be handled in a separate change. * The depth of texture cannot exceed a certain number, depending on the device's capability. It is typically 2048 on most of android devices but less than or equal to 16,384 (see [Value distribution for maxImageDimension3D on Android](https://vulkan.gpuinfo.org/displaydevicelimit.php?name=maxImageDimension3D&platform=android)). i.e., 2048 on MacOS and Google Pixel 5. * Due to this limitation, `permute` op needs to throw an exception if the depth of output texture is greater than or equal to `VkImageFormatProperties.maxExtent.depth`. * Otherwise, the following error will occur: `-[MTLTextureDescriptorInternal validateWithDevice:]:1325: failed assertion "Texture Descriptor Validation MTLTextureDescriptor has depth (10664) greater than the maximum allowed size of 2048."` * Vulkan `permute` operator tensor conversion: {F679505029} {F679505223} * Vulkan `permute` operator shader equation: {F679504799} * Error/edge cases: ``` X = torch.randint(0, 23, (2, 3, 2, 2)) O = torch.permute(X, (2, 2, 1, 0)) # RuntimeError: repeated dim in permute O = torch.permute(X, (2, 1, 0)) # RuntimeError: number of dims don't match in permute O = torch.permute(X, (4, 3, 2, 1, 0)) # RuntimeError: number of dims don't match in permute O = torch.permute(X, (3, 2, -1, 0)) # RuntimeError: repeated dim in permute data2 = [0,1,2] X2 = torch.tensor(data2) O2 = torch.permute(X2, (0)) # permute(): argument 'dims' (position 2) must be tuple of ints, not int # TypeError: permute(): argument 'dims' (position 2) must be tuple of ints, not int O = torch.permute(X, (0, 1, 2, 3)) # do nothing since the dims doesn't change? ``` * Shader debug traces with a 4D tensor size {2,3,2,2} with permute by {3,2,1,0}: ``` output tensor: (1,1,.,.) = 0.4395 0.5652 0.1309 0.9768 0.0490 0.1127 (2,1,.,.) = 0.7058 0.2238 0.6542 0.4064 0.4813 0.0500 (1,2,.,.) = 0.1716 0.4951 0.2225 0.3255 0.0758 0.7150 (2,2,.,.) = 0.3762 0.0228 0.6367 0.4411 0.7682 0.7599 [ CPUFloatType{2,2,3,2} ] shader debug traces: src_index:0, b c h w: 0 0 0 0, posIn: (0 0 0) i:0 -> b c h w: 0 0 0 0, dst_index: 0, posOut: (0 0 0) j:0 -> inval[0.439453] outval[0.439453] -> inval[0.439453 0.130859 0.049011 0.564941] outval[0.439453 0.000000 0.000000 0.000000] src_index:3, b c h w: 1 0 0 0, posIn: (0 0 0) i:3 -> b c h w: 0 0 0 1, dst_index: 0, posOut: (1 0 0) j:0 -> inval[0.564941] outval[0.564941] -> inval[0.439453 0.130859 0.049011 0.564941] outval[0.564941 0.000000 0.000000 0.000000] src_index:1, b c h w: 0 1 0 0, posIn: (0 0 0) i:1 -> b c h w: 0 0 1 0, dst_index: 0, posOut: (0 1 0) j:0 -> inval[0.130859] outval[0.130859] -> inval[0.439453 0.130859 0.049011 0.564941] outval[0.130859 0.000000 0.000000 0.000000] src_index:4, b c h w: 1 1 0 0, posIn: (0 0 1) i:0 -> b c h w: 0 0 1 1, dst_index: 0, posOut: (1 1 0) j:0 -> inval[0.976562] outval[0.976562] -> inval[0.976562 0.112671 -65504.000000 -65504.000000] outval[0.976562 0.000000 0.000000 0.000000] src_index:2, b c h w: 0 2 0 0, posIn: (0 0 0) i:2 -> b c h w: 0 0 2 0, dst_index: 0, posOut: (0 2 0) j:0 -> inval[0.049011] outval[0.049011] -> inval[0.439453 0.130859 0.049011 0.564941] outval[0.049011 0.000000 0.000000 0.000000] src_index:5, b c h w: 1 2 0 0, posIn: (0 0 1) i:1 -> b c h w: 0 0 2 1, dst_index: 0, posOut: (1 2 0) j:0 -> inval[0.112671] outval[0.112671] -> inval[0.976562 0.112671 -65504.000000 -65504.000000] outval[0.112671 0.000000 0.000000 0.000000] src_index:0, b c h w: 0 0 1 0, posIn: (0 1 0) i:0 -> b c h w: 0 1 0 0, dst_index: 1, posOut: (0 0 0) j:1 -> inval[0.171509] outval[0.171509] -> inval[0.171509 0.222412 0.075745 0.494873] outval[0.439453 0.171509 0.000000 0.000000] src_index:3, b c h w: 1 0 1 0, posIn: (0 1 0) i:3 -> b c h w: 0 1 0 1, dst_index: 1, posOut: (1 0 0) j:1 -> inval[0.494873] outval[0.494873] -> inval[0.171509 0.222412 0.075745 0.494873] outval[0.564941 0.494873 0.000000 0.000000] src_index:1, b c h w: 0 1 1 0, posIn: (0 1 0) i:1 -> b c h w: 0 1 1 0, dst_index: 1, posOut: (0 1 0) j:1 -> inval[0.222412] outval[0.222412] -> inval[0.171509 0.222412 0.075745 0.494873] outval[0.130859 0.222412 0.000000 0.000000] src_index:4, b c h w: 1 1 1 0, posIn: (0 1 1) i:0 -> b c h w: 0 1 1 1, dst_index: 1, posOut: (1 1 0) j:1 -> inval[0.325439] outval[0.325439] -> inval[0.325439 0.714844 -65504.000000 -65504.000000] outval[0.976562 0.325439 0.000000 0.000000] src_index:2, b c h w: 0 2 1 0, posIn: (0 1 0) i:2 -> b c h w: 0 1 2 0, dst_index: 1, posOut: (0 2 0) j:1 -> inval[0.075745] outval[0.075745] -> inval[0.171509 0.222412 0.075745 0.494873] outval[0.049011 0.075745 0.000000 0.000000] src_index:5, b c h w: 1 2 1 0, posIn: (0 1 1) i:1 -> b c h w: 0 1 2 1, dst_index: 1, posOut: (1 2 0) j:1 -> inval[0.714844] outval[0.714844] -> inval[0.325439 0.714844 -65504.000000 -65504.000000] outval[0.112671 0.714844 0.000000 0.000000] src_index:0, b c h w: 0 0 0 1, posIn: (1 0 0) i:0 -> b c h w: 1 0 0 0, dst_index: 2, posOut: (0 0 0) j:2 -> inval[0.705566] outval[0.705566] -> inval[0.705566 0.653809 0.481201 0.223755] outval[0.439453 0.171509 0.705566 0.000000] src_index:3, b c h w: 1 0 0 1, posIn: (1 0 0) i:3 -> b c h w: 1 0 0 1, dst_index: 2, posOut: (1 0 0) j:2 -> inval[0.223755] outval[0.223755] -> inval[0.705566 0.653809 0.481201 0.223755] outval[0.564941 0.494873 0.223755 0.000000] src_index:1, b c h w: 0 1 0 1, posIn: (1 0 0) i:1 -> b c h w: 1 0 1 0, dst_index: 2, posOut: (0 1 0) j:2 -> inval[0.653809] outval[0.653809] -> inval[0.705566 0.653809 0.481201 0.223755] outval[0.130859 0.222412 0.653809 0.000000] src_index:4, b c h w: 1 1 0 1, posIn: (1 0 1) i:0 -> b c h w: 1 0 1 1, dst_index: 2, posOut: (1 1 0) j:2 -> inval[0.406250] outval[0.406250] -> inval[0.406250 0.049957 -65504.000000 -65504.000000] outval[0.976562 0.325439 0.406250 0.000000] src_index:2, b c h w: 0 2 0 1, posIn: (1 0 0) i:2 -> b c h w: 1 0 2 0, dst_index: 2, posOut: (0 2 0) j:2 -> inval[0.481201] outval[0.481201] -> inval[0.705566 0.653809 0.481201 0.223755] outval[0.049011 0.075745 0.481201 0.000000] src_index:5, b c h w: 1 2 0 1, posIn: (1 0 1) i:1 -> b c h w: 1 0 2 1, dst_index: 2, posOut: (1 2 0) j:2 -> inval[0.049957] outval[0.049957] -> inval[0.406250 0.049957 -65504.000000 -65504.000000] outval[0.112671 0.714844 0.049957 0.000000] src_index:0, b c h w: 0 0 1 1, posIn: (1 1 0) i:0 -> b c h w: 1 1 0 0, dst_index: 3, posOut: (0 0 0) j:3 -> inval[0.376221] outval[0.376221] -> inval[0.376221 0.636719 0.768066 0.022751] outval[0.439453 0.171509 0.705566 0.376221] outval_after[0.439453 0.171509 0.705566 0.376221] src_index:3, b c h w: 1 0 1 1, posIn: (1 1 0) i:3 -> b c h w: 1 1 0 1, dst_index: 3, posOut: (1 0 0) j:3 -> inval[0.022751] outval[0.022751] -> inval[0.376221 0.636719 0.768066 0.022751] outval[0.564941 0.494873 0.223755 0.022751] outval_after[0.564941 0.494873 0.223755 0.022751] src_index:1, b c h w: 0 1 1 1, posIn: (1 1 0) i:1 -> b c h w: 1 1 1 0, dst_index: 3, posOut: (0 1 0) j:3 -> inval[0.636719] outval[0.636719] -> inval[0.376221 0.636719 0.768066 0.022751] outval[0.130859 0.222412 0.653809 0.636719] outval_after[0.130859 0.222412 0.653809 0.636719] src_index:4, b c h w: 1 1 1 1, posIn: (1 1 1) i:0 -> b c h w: 1 1 1 1, dst_index: 3, posOut: (1 1 0) j:3 -> inval[0.440918] outval[0.440918] -> inval[0.440918 0.759766 -65504.000000 -65504.000000] outval[0.976562 0.325439 0.406250 0.440918] outval_after[0.976562 0.325439 0.406250 0.440918] src_index:2, b c h w: 0 2 1 1, posIn: (1 1 0) i:2 -> b c h w: 1 1 2 0, dst_index: 3, posOut: (0 2 0) j:3 -> inval[0.768066] outval[0.768066] -> inval[0.376221 0.636719 0.768066 0.022751] outval[0.049011 0.075745 0.481201 0.768066] outval_after[0.049011 0.075745 0.481201 0.768066] src_index:5, b c h w: 1 2 1 1, posIn: (1 1 1) i:1 -> b c h w: 1 1 2 1, dst_index: 3, posOut: (1 2 0) j:3 -> inval[0.759766] outval[0.759766] -> inval[0.440918 0.759766 -65504.000000 -65504.000000] outval[0.112671 0.714844 0.049957 0.759766] outval_after[0.112671 0.714844 0.049957 0.759766] ``` Test Plan: Build & test on Android: ``` cd ~/fbsource buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:pt_vulkan_api_test_binAndroid\#android-arm64 --show-output adb push buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_api_test adb shell "/data/local/tmp/vulkan_api_test" ``` Build & test on MacOS: ``` cd ~/fbsource buck build //xplat/caffe2:pt_vulkan_api_test_binAppleMac ./buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAppleMac\#macosx-x86_64 ``` Test result on Android (Google Pixel 5): ``` [ RUN ] VulkanAPITest.permute_2d_success [ OK ] VulkanAPITest.permute_2d_success (26 ms) [ RUN ] VulkanAPITest.permute_3d_success [ OK ] VulkanAPITest.permute_3d_success (6 ms) [ RUN ] VulkanAPITest.permute_4d_success [ OK ] VulkanAPITest.permute_4d_success (10 ms) [ RUN ] VulkanAPITest.permute_4dmclaren_success [ OK ] VulkanAPITest.permute_4dmclaren_success (1 ms) [ RUN ] VulkanAPITest.permute_4dbig_success [ OK ] VulkanAPITest.permute_4dbig_success (234 ms) [ RUN ] VulkanAPITest.permute_negativedims_success [ OK ] VulkanAPITest.permute_negativedims_success (0 ms) [ RUN ] VulkanAPITest.permute_1d_nochange [ OK ] VulkanAPITest.permute_1d_nochange (0 ms) [ RUN ] VulkanAPITest.permute_sameDims_nochange [ OK ] VulkanAPITest.permute_sameDims_nochange (1 ms) [ RUN ] VulkanAPITest.permute_invalidinputs_exceptions [ OK ] VulkanAPITest.permute_invalidinputs_exceptions (1 ms) ``` Test result on MacOS: ``` [ RUN ] VulkanAPITest.permute_2d_success [ OK ] VulkanAPITest.permute_2d_success (154 ms) [ RUN ] VulkanAPITest.permute_3d_success [ OK ] VulkanAPITest.permute_3d_success (13 ms) [ RUN ] VulkanAPITest.permute_4d_success [ OK ] VulkanAPITest.permute_4d_success (33 ms) [ RUN ] VulkanAPITest.permute_4dmclaren_success [ OK ] VulkanAPITest.permute_4dmclaren_success (2 ms) [ RUN ] VulkanAPITest.permute_4dbig_success [ OK ] VulkanAPITest.permute_4dbig_success (251 ms) [ RUN ] VulkanAPITest.permute_negativedims_success [ OK ] VulkanAPITest.permute_negativedims_success (2 ms) [ RUN ] VulkanAPITest.permute_1d_nochange [ OK ] VulkanAPITest.permute_1d_nochange (1 ms) [ RUN ] VulkanAPITest.permute_sameDims_nochange [ OK ] VulkanAPITest.permute_sameDims_nochange (0 ms) [ RUN ] VulkanAPITest.permute_invalidinputs_exceptions [ OK ] VulkanAPITest.permute_invalidinputs_exceptions (2 ms) ``` Reviewed By: SS-JIA Differential Revision: D32292554 fbshipit-source-id: dbeaee6ff98633022cf34d6da90662d81eac6b0e
Author
Parents
Loading