[Vulkan] Implement GRU operator (#72692)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72692
Implemented GRU operator in the Vulkan GPU backend:
* This is an initial implementation to support an internal model.
* The internal op name for GRU is `aten::gru.input`.
* There should be 2 weights and 2 biases per layer. See the [GRU > Variables](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html) section.
* For num_layers=1, the weights list should contain [weight_ih, weight_hh, bias_ih, bias_hh] (4 elements).
* The input and hidden state need to be reshaped to 2D, since the Vulkan `mm` and `addmm` ops accept only 2D tensors (see the sketch after the references below).
* By design, all weights and biases should be on the CPU, while the input sequence and hidden state should be on the Vulkan GPU (a usage sketch follows the reference Python code below).
* Input arguments and return values:
* `input_vk`: input tensor of shape (L, N, H_in) when batch_first=False or (N, L, H_in) when batch_first=True containing the features of the input sequence
* `hx_vk`: initial hidden state for each element in the batch; tensor of shape (D * num_layers, N, H_out)
* `output`: tensor of shape (N, L, D * H_out) when batch_first=True
* `h_n`: tensor of shape (D * num_layers, N, H_out)
* where
* L = sequence length
* N = batch size
* D = 2 if bidirectional=True otherwise 1
* H_in = input_size (# of expected features in the input x)
* H_out = hidden_size (# of features in the hidden state h)
* This initial implementation has some limitations:
* Tensor dim should be 3 for the input sequence and hidden state.
* has_biases=True
* train=False
* bidirectional=False
* batch_first=True
* dropout=0.0
* D=1 since bidirectional=False
* N=1 (batch size)
* L=1 (sequence length)
* GRU high-level Python reference code:
```
import torch
from torch import nn

H_in = 10        # input_size
H_out = 10       # hidden_size
num_layers = 2
D = 1            # bidirectional=False

gru = nn.GRU(H_in, H_out, num_layers)
input = torch.randn(1, 1, H_in)             # (L, N, H_in) with L=1, N=1
h0 = torch.randn(D * num_layers, 1, H_out)  # (D * num_layers, N, H_out)
output, h_n = gru(input, h0)
print(output)
print(h_n)
print(gru._all_weights)

# The same result can be calculated directly from the flat weights.
x = input
output = x
h_n = []
for i in range(num_layers):
    h = h0[i]
    # Each layer contributes [weight_ih, weight_hh, bias_ih, bias_hh].
    W_ih, W_hh, b_ih, b_hh = gru._flat_weights[i * 4 : (i + 1) * 4]
    # The stacked reset/update/new-gate parameters are split into chunks of H_out rows.
    W_ir, W_iz, W_in = W_ih.split(H_out)
    W_hr, W_hz, W_hn = W_hh.split(H_out)
    b_ir, b_iz, b_in = b_ih.split(H_out)
    b_hr, b_hz, b_hn = b_hh.split(H_out)
    r = torch.sigmoid(x @ W_ir.T + b_ir + h @ W_hr.T + b_hr)     # reset gate
    z = torch.sigmoid(x @ W_iz.T + b_iz + h @ W_hz.T + b_hz)     # update gate
    n = torch.tanh(x @ W_in.T + b_in + r * (h @ W_hn.T + b_hn))  # new gate
    h = (1 - z) * n + z * h
    x = h
    output = x
    h_n.append(h[0])
print(output)
print(h_n)
```
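* The reference computation above can be cross-checked against the Vulkan backend. The snippet below is a minimal sketch, assuming a Vulkan-enabled PyTorch build; the variable names and the use of `Tensor.vulkan()` / `Tensor.cpu()` for device transfer are illustrative, not part of the operator itself:
```
import torch
from torch import nn

H_in, H_out, num_layers = 10, 10, 2
gru = nn.GRU(H_in, H_out, num_layers, batch_first=True)  # weights/biases stay on the CPU
gru.eval()  # train=False

input_cpu = torch.randn(1, 1, H_in)          # (N, L, H_in) with N=1, L=1
h0_cpu = torch.randn(num_layers, 1, H_out)   # (D * num_layers, N, H_out) with D=1

# Only the input sequence and the hidden state are moved to the Vulkan GPU.
input_vk = input_cpu.vulkan()
h0_vk = h0_cpu.vulkan()

with torch.no_grad():
    output_cpu, h_n_cpu = gru(input_cpu, h0_cpu)
    output_vk, h_n_vk = gru(input_vk, h0_vk)

# Move the Vulkan results back to the CPU for comparison with the CPU run.
print(torch.allclose(output_cpu, output_vk.cpu(), atol=1e-5))
print(torch.allclose(h_n_cpu, h_n_vk.cpu(), atol=1e-5))
```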
* References
* PyTorch Docs > torch.nn > [GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html)
* Dive into Deep Learning > [9.1. Gated Recurrent Units (GRU)](https://d2l.ai/chapter_recurrent-modern/gru.html)
* [Gated Recurrent Unit (GRU) With PyTorch](https://blog.floydhub.com/gru-with-pytorch/)
* [From GRU to Transformer](https://ogunlao.github.io/blog/2020/06/12/from_gru_to_transformer.html)
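* Sketch of the 2D constraint mentioned above: Vulkan `mm`/`addmm` operate on 2D tensors, so the 3D input and hidden state are flattened before the gate matmuls. The names below (`x_3d`, `W_ir`, ...) are illustrative only and do not reflect the operator's actual C++ internals:
```
import torch

N, L, H_in, H_out = 1, 1, 10, 10

x_3d = torch.randn(N, L, H_in)      # (N, L, H_in), batch_first=True
h_3d = torch.randn(1, N, H_out)     # one layer's slice of (D * num_layers, N, H_out)

x_2d = x_3d.reshape(N * L, H_in)    # 2D views for mm/addmm
h_2d = h_3d.reshape(N, H_out)

# Reset-gate slices of weight_ih / weight_hh and the matching biases.
W_ir = torch.randn(H_out, H_in)
W_hr = torch.randn(H_out, H_out)
b_ir = torch.randn(H_out)
b_hr = torch.randn(H_out)

# r = sigmoid(x @ W_ir.T + b_ir + h @ W_hr.T + b_hr), expressed with 2D addmm
r = torch.sigmoid(torch.addmm(b_ir, x_2d, W_ir.t()) + torch.addmm(b_hr, h_2d, W_hr.t()))
print(r.shape)  # (N * L, H_out); results would be reshaped back to the 3D output shape
```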
Test Plan:
Build & test on Android:
```
cd ~/fbsource
buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:pt_vulkan_api_test_binAndroid\#android-arm64 --show-output
adb push buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_api_test
adb shell "/data/local/tmp/vulkan_api_test"
```
Test result on Android (Google Pixel 5):
```
[ RUN ] VulkanAPITest.gru_mclareninputs_success
[ OK ] VulkanAPITest.gru_mclareninputs_success (59 ms)
[ RUN ] VulkanAPITest.gru_invalidinputs_exceptions
[ OK ] VulkanAPITest.gru_invalidinputs_exceptions (17 ms)
```
Reviewed By: SS-JIA
Differential Revision: D33995221
fbshipit-source-id: d7875298ec37425c7eb1df34d163178b61a84fc9
(cherry picked from commit 90aa32915d8bcd87df1efd3e20b2b07ccffd4677)