e5938165 - [Vulkan] Implement GRU operator (#72692)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72692

Implemented the GRU operator in the Vulkan GPU backend:

* This is an initial implementation to support an internal model.
* The internal name for GRU is `aten::gru.input`.
* There should be 2 weights and 2 biases per layer; see the [GRU >> Variables](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html) section.
  * For num_layers=1 the weights should contain [weight_ih, weight_hh, bias_ih, bias_hh] (4 elements).
* The input and hidden state need to be reshaped to 2D, since the Vulkan `mm` and `addmm` ops accept only 2D tensors.
* By design, all weights and biases should be on the CPU, while the input sequence and hidden state should be on the Vulkan GPU.
* Input arguments and return values:
  * `input_vk`: input tensor of shape (L, N, H_in) when batch_first=False, or (N, L, H_in) when batch_first=True, containing the features of the input sequence
  * `hx_vk`: initial hidden state for each element in the batch; tensor of shape (D * num_layers, N, H_out)
  * `output`: tensor of shape (N, L, D * H_out) when batch_first=True
  * `h_n`: tensor of shape (D * num_layers, N, H_out)
  * where
    * L = sequence length
    * N = batch size
    * D = 2 if bidirectional=True, otherwise 1
    * H_in = input_size (# of expected features in the input x)
    * H_out = hidden_size (# of features in the hidden state h)
* This initial implementation has some limitations (a usage sketch under these constraints follows below):
  * Tensor dim must be 3 for the input sequence and hidden state.
  * has_biases=True
  * train=False
  * bidirectional=False
  * batch_first=True
  * dropout=0.0
  * D=1 since bidirectional=False
  * N=1 (batch size)
  * L=1 (sequence length)
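Under these limitations, the intended call pattern is roughly the following. This is a minimal, hypothetical sketch (not code from this PR): it assumes a Vulkan-enabled PyTorch build, that tensors can be moved to the backend with `Tensor.to("vulkan")` (the C++ API test uses `Tensor::vulkan()`), and that `nn.GRU` dispatches to the new kernel when its inputs live on Vulkan while its weights and biases stay on the CPU, as described above.

```
# Hypothetical usage sketch, not part of this PR; requires a Vulkan-enabled build.
import torch
from torch import nn

H_in, H_out, num_layers = 10, 10, 2   # arbitrary sizes for illustration
N, L, D = 1, 1, 1                     # fixed by the current limitations

# Weights and biases stay on the CPU; eval() because only train=False is supported.
gru = nn.GRU(H_in, H_out, num_layers, batch_first=True).eval()

input_cpu = torch.randn(N, L, H_in)             # (N, L, H_in) since batch_first=True
h0_cpu = torch.randn(D * num_layers, N, H_out)  # (D * num_layers, N, H_out)

with torch.no_grad():
    # Reference result computed on the CPU.
    out_cpu, hn_cpu = gru(input_cpu, h0_cpu)

    # Same call with the input sequence and hidden state on the Vulkan backend
    # (assumed Python-side transfer API; the C++ tests use .vulkan()).
    out_vk, hn_vk = gru(input_cpu.to("vulkan"), h0_cpu.to("vulkan"))

# Copy the Vulkan results back to the CPU and compare against the reference.
print(torch.allclose(out_cpu, out_vk.cpu(), atol=1e-5))
print(torch.allclose(hn_cpu, hn_vk.cpu(), atol=1e-5))
```

The C++ test `VulkanAPITest.gru_mclareninputs_success` in the test plan below presumably exercises the same idea from the native API: run the op with CPU and Vulkan inputs and compare the outputs.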
GRU high-level Python code:

```
import torch
from torch import nn

H_in = 10
H_out = 10
num_layers = 2
D = 1

gru = nn.GRU(H_in, H_out, num_layers)
input = torch.randn(1, 1, H_in)
h0 = torch.randn(D * num_layers, 1, H_out)
output, h_n = gru(input, h0)
print(output)
print(h_n)
print(gru._all_weights)

# The same result can be calculated directly from the flat weights.
x = input
output = x
h_n = []
for i in range(num_layers):
    h = h0[i]
    W_ih, W_hh, b_ih, b_hh = gru._flat_weights[i * 4 : (i + 1) * 4]
    # Each weight/bias stacks the reset (r), update (z) and new (n) gates along dim 0,
    # so the chunk size is H_out.
    W_ir, W_iz, W_in = W_ih.split(H_out)
    W_hr, W_hz, W_hn = W_hh.split(H_out)
    b_ir, b_iz, b_in = b_ih.split(H_out)
    b_hr, b_hz, b_hn = b_hh.split(H_out)
    r = torch.sigmoid(x @ W_ir.T + b_ir + h @ W_hr.T + b_hr)
    z = torch.sigmoid(x @ W_iz.T + b_iz + h @ W_hz.T + b_hz)
    n = torch.tanh(x @ W_in.T + b_in + r * (h @ W_hn.T + b_hn))
    h = (1 - z) * n + z * h
    x = h
    output = x
    h_n.append(h[0])
print(output)
print(h_n)
```

References:

* PyTorch Docs > torch.nn > [GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html)
* Dive into Deep Learning > [9.1. Gated Recurrent Units (GRU)](https://d2l.ai/chapter_recurrent-modern/gru.html)
* [Gated Recurrent Unit (GRU) With PyTorch](https://blog.floydhub.com/gru-with-pytorch/)
* [From GRU to Transformer](https://ogunlao.github.io/blog/2020/06/12/from_gru_to_transformer.html)

Test Plan:

Build & test on Android:

```
cd ~/fbsource
buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:pt_vulkan_api_test_binAndroid\#android-arm64 --show-output
adb push buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_api_test
adb shell "/data/local/tmp/vulkan_api_test"
```

Test result on Android (Google Pixel 5):

```
[ RUN      ] VulkanAPITest.gru_mclareninputs_success
[       OK ] VulkanAPITest.gru_mclareninputs_success (59 ms)
[ RUN      ] VulkanAPITest.gru_invalidinputs_exceptions
[       OK ] VulkanAPITest.gru_invalidinputs_exceptions (17 ms)
```

Reviewed By: SS-JIA

Differential Revision: D33995221

fbshipit-source-id: d7875298ec37425c7eb1df34d163178b61a84fc9
(cherry picked from commit 90aa32915d8bcd87df1efd3e20b2b07ccffd4677)