[Vulkan] Optimize GRU operator with pre-packing (#73599)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73599
Optimized the GRU operator by pre-packing weights and biases in the Vulkan GPU backend:
* The weights and biases are always on the CPU side by design.
* To avoid the overhead of retrieving and converting the weight and bias tensors on every invocation, they are pre-packed once and stored.
* A custom op context `GruOpContext` (derived from `torch::jit::CustomClassHolder`) is created to hold both packed and unpacked data. The `unpacked_` struct holds the data needed to construct the op context; that data is pre-packed and stored in the `packed_` struct. The constructor of `GruOpContext` populates both structs.
* The `at::native::vulkan::ops::gru_prepack` and `at::native::vulkan::ops::gru_run` methods use the op context. `gru_prepack` takes the data needed to construct the op context and returns a pointer to the created context. `gru_run` takes the input tensors and a pointer to the op context, and uses the data stored in the context to process the inputs.
* Lastly, the op context class and the new ops are registered in [Register.cpp](https://github.com/pytorch/pytorch/blob/11dc1581298c5bb2b322897c7b3999d1a3971720/aten/src/ATen/native/vulkan/ops/Register.cpp), and the GRU subgraph rewrite in [vulkan_rewrite.cpp](https://github.com/pytorch/pytorch/blob/11dc1581298c5bb2b322897c7b3999d1a3971720/torch/csrc/jit/passes/vulkan_rewrite.cpp) is updated so that the `gru_prepack` and `gru_run` ops are executed instead on the Vulkan GPU backend.
* To avoid an `"Undefined symbols for architecture x86_64"` linker error on the x86_64 platform, the `c10::Dispatcher::callBoxed()` API is used to call `vulkan_prepack::gru_prepack` and `vulkan_prepack::gru_run` by name; otherwise, the test methods can't resolve the symbols.
* Added new tests for the GRU pre-packing and run operations: `gru_prepack_success` and `gru_prepack_invalidinputs_exceptions`
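The prepack/run split described above can be illustrated with a minimal pure-Python sketch. This is hypothetical and scalar-sized for brevity; the real `GruOpContext` is a C++ class holding Vulkan GPU tensors, and the real ops operate on full weight matrices. It shows the pattern: `unpacked_` keeps the construction data, `packed_` keeps a form built once that the run step consumes directly.

```python
import math

def _sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

class GruOpContext:
    """Illustrative op-context sketch (not the actual Vulkan implementation).
    unpacked_ mirrors the data needed to construct the context; packed_ holds
    the form gru_run consumes, built once at construction time."""
    def __init__(self, w_ih, w_hh, b_ih, b_hh):
        # unpacked_: raw CPU-side parameters, one scalar per gate (r, z, n)
        self.unpacked_ = {"w_ih": w_ih, "w_hh": w_hh, "b_ih": b_ih, "b_hh": b_hh}
        # packed_: per-gate tuples laid out for direct consumption by the run
        # step; a real backend would upload GPU-resident tensors here instead.
        self.packed_ = [(w_ih[g], w_hh[g], b_ih[g], b_hh[g]) for g in range(3)]

def gru_prepack(w_ih, w_hh, b_ih, b_hh):
    """Analogous to vulkan_prepack::gru_prepack: build and return the context."""
    return GruOpContext(w_ih, w_hh, b_ih, b_hh)

def gru_run(x, h, ctx):
    """Analogous to vulkan_prepack::gru_run: one GRU step using only packed_."""
    (wi_r, wh_r, bi_r, bh_r), (wi_z, wh_z, bi_z, bh_z), (wi_n, wh_n, bi_n, bh_n) = ctx.packed_
    r = _sigmoid(wi_r * x + bi_r + wh_r * h + bh_r)          # reset gate
    z = _sigmoid(wi_z * x + bi_z + wh_z * h + bh_z)          # update gate
    n = math.tanh(wi_n * x + bi_n + r * (wh_n * h + bh_n))   # candidate state
    return (1.0 - z) * n + z * h
```

The design point is that any conversion cost is paid once in `gru_prepack`; `gru_run` only reads `packed_`.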
* To build PyTorch (OSS) on your local machine:
```
python setup.py clean
git submodule update --init --recursive
USE_VULKAN=1 USE_VULKAN_FP16_INFERENCE=1 python3 setup.py install --cmake
python setup.py develop && python -c "import torch"
```
* To run and dump a model containing GRU operators in Python:
```
import torch
from torch.utils import mobile_optimizer
model = torch.jit.load("Mclaren_traced.pt")
vk_model = mobile_optimizer.optimize_for_mobile(model, backend="vulkan")
print(vk_model.graph)
```
* The rewritten TorchScript graph after GRU pre-packing:
```
%15 : Tensor[] = prim::ListConstruct(%weight_ih_l0.1, %weight_hh_l0.1, %bias_ih_l0.1, %bias_hh_l0.1, %weight_ih_l1.1, %weight_hh_l1.1, %bias_ih_l1.1, %bias_hh_l1.1)
%19 : __torch__.torch.classes.vulkan.GruOpContext = vulkan_prepack::gru_prepack(%15, %4, %5, %6, %3, %3, %4)
%20 : Tensor, %21 : Tensor = vulkan_prepack::gru_run(%input.1, %hx.1, %19)
%18 : (Tensor, Tensor) = prim::TupleConstruct(%21, %20)
return (%18)
```
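As a sanity check on the graph above: with `has_biases=True`, each GRU layer contributes four tensors (`weight_ih`, `weight_hh`, `bias_ih`, `bias_hh`), so the `prim::ListConstruct` feeding `gru_prepack` carries 2 × 4 = 8 tensors for this two-layer model. A small hypothetical helper makes the arithmetic explicit:

```python
def num_packed_tensors(num_layers, has_biases=True):
    """Count tensors fed to gru_prepack: weight_ih and weight_hh per layer,
    plus bias_ih and bias_hh when has_biases=True (the only case this
    Vulkan path supports)."""
    return num_layers * (4 if has_biases else 2)

# The two-layer model above packs 8 tensors into %15:
assert num_packed_tensors(2) == 8
```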
* This implementation has some limitations:
  * The input sequence and hidden state tensors must be 3-dimensional.
* has_biases=True
* train=False
* bidirectional=False
* batch_first=True
* dropout=0.0
* D=1 since bidirectional=False
* N=1 (batch size)
* L=1 (sequence length)
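The limitations above can be summarized in a hypothetical validation sketch (the actual checks live in the C++ op and are exercised by the `*_invalidinputs_exceptions` tests; the function and its messages are illustrative only):

```python
def check_vulkan_gru_args(input_shape, hx_shape, has_biases=True, train=False,
                          bidirectional=False, batch_first=True, dropout=0.0):
    """Hypothetical mirror of the Vulkan GRU argument restrictions listed above."""
    assert len(input_shape) == 3, "input sequence must be 3-D: (N, L, H_in)"
    assert len(hx_shape) == 3, "hidden state must be 3-D: (D*num_layers, N, H_out)"
    N, L = input_shape[0], input_shape[1]  # batch_first layout
    assert N == 1, "only batch size N=1 is supported"
    assert L == 1, "only sequence length L=1 is supported"
    assert has_biases, "has_biases must be True"
    assert not train, "train must be False"
    assert not bidirectional, "bidirectional must be False (so D=1)"
    assert batch_first, "batch_first must be True"
    assert dropout == 0.0, "dropout must be 0.0"
```

For example, a `(1, 1, H_in)` input with a `(num_layers, 1, H_out)` hidden state passes, while any longer sequence or larger batch is rejected.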
Test Plan:
Build & test on Android:
```
cd ~/fbsource
buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:pt_vulkan_api_test_binAndroid\#android-arm64 --show-output
adb push buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_api_test
adb shell "/data/local/tmp/vulkan_api_test"
```
Build & test on macOS (x86_64):
```
cd ~/fbsource
buck build //xplat/caffe2:pt_vulkan_api_test_binAppleMac
./buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAppleMac\#macosx-x86_64
```
Test result on Android (Google Pixel 5):
```
Running main() from gtest_main.cc
[==========] Running 4 tests from 1 test case.
[----------] Global test environment set-up.
[----------] 4 tests from VulkanAPITest
[ RUN ] VulkanAPITest.gru_mclareninputs_success
[ OK ] VulkanAPITest.gru_mclareninputs_success (1037 ms)
[ RUN ] VulkanAPITest.gru_invalidinputs_exceptions
[ OK ] VulkanAPITest.gru_invalidinputs_exceptions (16 ms)
[ RUN ] VulkanAPITest.gru_prepack_success
[ OK ] VulkanAPITest.gru_prepack_success (45 ms)
[ RUN ] VulkanAPITest.gru_prepack_invalidinputs_exceptions
[ OK ] VulkanAPITest.gru_prepack_invalidinputs_exceptions (16 ms)
[----------] 4 tests from VulkanAPITest (1114 ms total)
[----------] Global test environment tear-down
[==========] 4 tests from 1 test case ran. (1114 ms total)
[ PASSED ] 4 tests.
```
Test result on macOS (x86_64):
```
Running main() from gtest_main.cc
[==========] Running 4 tests from 1 test case.
[----------] Global test environment set-up.
[----------] 4 tests from VulkanAPITest
[ RUN ] VulkanAPITest.gru_mclareninputs_success
[ OK ] VulkanAPITest.gru_mclareninputs_success (1012 ms)
[ RUN ] VulkanAPITest.gru_invalidinputs_exceptions
[ OK ] VulkanAPITest.gru_invalidinputs_exceptions (40 ms)
[ RUN ] VulkanAPITest.gru_prepack_success
[ OK ] VulkanAPITest.gru_prepack_success (99 ms)
[ RUN ] VulkanAPITest.gru_prepack_invalidinputs_exceptions
[ OK ] VulkanAPITest.gru_prepack_invalidinputs_exceptions (39 ms)
[----------] 4 tests from VulkanAPITest (1190 ms total)
[----------] Global test environment tear-down
[==========] 4 tests from 1 test case ran. (1190 ms total)
[ PASSED ] 4 tests.
```
Reviewed By: SS-JIA
Differential Revision: D34556940
fbshipit-source-id: dce918de238fb8a4a0ea5e966e05ca99ed910c28
(cherry picked from commit cd1d95ff8d0fa7810cf18a54ba64539e46daa26a)