pytorch
6e30d1c5 - [Vulkan] Optimize GRU operator with pre-packing (#73599)

Commit
2 years ago
[Vulkan] Optimize GRU operator with pre-packing (#73599) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73599 Optimized GRU operator by using pre-packing for weights and biases in the Vulkan GPU backend: * The weights and biases are always on the CPU side by design. * To reduce the overhead by retrieving the weight and bias tensors every time, it is the best way to store them by pre-packing. * A custom op context `GruOpContext` (derived from `torch::jit::CustomClassHolder`) is created to hold both packed and unpacked data. It corresponds to the unpacked_ struct which represents the data needed to construct the op context. This data will be pre-packed and be stored in the packed_ struct. The constructor of the `GruOpContext` loads the data into the unpacked_ and packed_ structs. * `at::native::vulkan::ops::gru_prepack` and `at::native::vulkan::ops::gru_run` methods use the op context. The `gru_prepack` takes in whatever data is needed to construct the op context and returns a pointer to a created context. The `gru_run` takes input tensors and a pointer to the op context that uses the data stored in the context to process the inputs. * Lastly, we need to register the op context class and ops in [Register.cpp](https://github.com/pytorch/pytorch/blob/11dc1581298c5bb2b322897c7b3999d1a3971720/aten/src/ATen/native/vulkan/ops/Register.cpp). And rewrite the subgraph function of GRU op in [vulkan_rewrite.cpp](https://github.com/pytorch/pytorch/blob/11dc1581298c5bb2b322897c7b3999d1a3971720/torch/csrc/jit/passes/vulkan_rewrite.cpp) so that `gru_prepack` and `gru_run` ops can be executed instead in the Vulkan GPU backend. * To avoid `"Undefined symbols for architecture x86_64"` compiler error on the x86_64 platform, `c10::Dispatcher::callBoxed()` API is used to call `vulkan_prepack::gru_prepack` and `vulkan_prepack::gru_run` by name. Otherwise, the test methods can't resolve the symbols. * Added new tests for the GRU pre-packing and run operations: `gru_prepack_success` and gru_prepack_invalidinputs_exceptions` * To build your PyTorch OSS on your local machine: ``` python setup.py clean git submodule update --init --recursive USE_VULKAN=1 USE_VULKAN_FP16_INFERENCE=1 python3 setup.py install --cmake python setup.py develop && python -c "import torch" ``` * To run and dump a model containing GRU operators in Python: ``` import torch from torch.utils import mobile_optimizer model = torch.jit.load("Mclaren_traced.pt") vk_model = mobile_optimizer.optimize_for_mobile(model, backend="vulkan") print(vk_model.graph) ``` * The following torch scripts are the updated version by GRU pre-packing: ``` %15 : Tensor[] = prim::ListConstruct(%weight_ih_l0.1, %weight_hh_l0.1, %bias_ih_l0.1, %bias_hh_l0.1, %weight_ih_l1.1, %weight_hh_l1.1, %bias_ih_l1.1, %bias_hh_l1.1) %19 : __torch__.torch.classes.vulkan.GruOpContext = vulkan_prepack::gru_prepack(%15, %4, %5, %6, %3, %3, %4) %20 : Tensor, %21 : Tensor = vulkan_prepack::gru_run(%input.1, %hx.1, %19) %18 : (Tensor, Tensor) = prim::TupleConstruct(%21, %20) return (%18) ``` * This implementation has some limitations: * Tensor dim should be 3 for input sequence and hidden state. * has_biases=True * train=False * bidirectional=False * batch_first=True * dropout=0.0 * D=1 since bidirectional=False * N=1 (batch size) * L=1 (sequence length) Test Plan: Build & test on Android: ``` cd ~/fbsource buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:pt_vulkan_api_test_binAndroid\#android-arm64 --show-output adb push buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_api_test adb shell "/data/local/tmp/vulkan_api_test" ``` Build & test on MacOS (x86_64): ``` cd ~/fbsource buck build //xplat/caffe2:pt_vulkan_api_test_binAppleMac ./buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAppleMac\#macosx-x86_64 ``` Test result on Android (Google Pixel 5): ``` Running main() from gtest_main.cc [==========] Running 4 tests from 1 test case. [----------] Global test environment set-up. [----------] 4 tests from VulkanAPITest [ RUN ] VulkanAPITest.gru_mclareninputs_success [ OK ] VulkanAPITest.gru_mclareninputs_success (1037 ms) [ RUN ] VulkanAPITest.gru_invalidinputs_exceptions [ OK ] VulkanAPITest.gru_invalidinputs_exceptions (16 ms) [ RUN ] VulkanAPITest.gru_prepack_success [ OK ] VulkanAPITest.gru_prepack_success (45 ms) [ RUN ] VulkanAPITest.gru_prepack_invalidinputs_exceptions [ OK ] VulkanAPITest.gru_prepack_invalidinputs_exceptions (16 ms) [----------] 4 tests from VulkanAPITest (1114 ms total) [----------] Global test environment tear-down [==========] 4 tests from 1 test case ran. (1114 ms total) [ PASSED ] 4 tests. ``` Test result on MacOS (x86_64): ``` Running main() from gtest_main.cc [==========] Running 4 tests from 1 test case. [----------] Global test environment set-up. [----------] 4 tests from VulkanAPITest [ RUN ] VulkanAPITest.gru_mclareninputs_success [ OK ] VulkanAPITest.gru_mclareninputs_success (1012 ms) [ RUN ] VulkanAPITest.gru_invalidinputs_exceptions [ OK ] VulkanAPITest.gru_invalidinputs_exceptions (40 ms) [ RUN ] VulkanAPITest.gru_prepack_success [ OK ] VulkanAPITest.gru_prepack_success (99 ms) [ RUN ] VulkanAPITest.gru_prepack_invalidinputs_exceptions [ OK ] VulkanAPITest.gru_prepack_invalidinputs_exceptions (39 ms) [----------] 4 tests from VulkanAPITest (1190 ms total) [----------] Global test environment tear-down [==========] 4 tests from 1 test case ran. (1190 ms total) [ PASSED ] 4 tests. ``` Reviewed By: SS-JIA Differential Revision: D34556940 fbshipit-source-id: dce918de238fb8a4a0ea5e966e05ca99ed910c28 (cherry picked from commit cd1d95ff8d0fa7810cf18a54ba64539e46daa26a)
Author
Committer
Parents
Loading