pytorch
a9a9b6c8 - [Vulkan] Implement LSTM operator (#78943)

Commit
3 years ago
[Vulkan] Implement LSTM operator (#78943) Summary: Implemented LSTM operator for the Vulkan backend: Inputs: - `input_vk`: input tensor containing the features of the input sequence. It has shape (L, N, H_in) when batch_first=False or (N, L, H_in) when batch_first=True - `hx`: list of two tensors: `hx_vk` and `cx_vk` - `hx_vk`: tensor of shape (D * num_layers, N, H_out) containing the initial hidden state for each element in the input sequence. - `cx_vk`: tensor of shape (D * num_layers, N, H_cell) containing the initial cell state for each element in the input sequence. - `params_cpu`: list of tensors containing the weights/biases of size 4 * num_layers. There should be 2 weights and 2 biases per layer. See [LSTM >> Variables](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) section. - `has_biases`: This initial implementation only supports has_biases=True - `num_layers`: Number of recurrent layers. - `dropout`: This initial implementation only supports dropout=0.0 - `train`: This initial implementation only supports train=False - `bidirectional`: This initial implementation only supports bidirectional=False - `batch_first`: This initial implementation only supports batch_first=True Outputs: - `output`: tensor of shape (L, N, D * H_out) when batch_first=False or (N, L, D * H_out) when batch_first=True containing the output features (h_t) from the last layer of the LSTM, for each t - `h_n`: tensor of shape (D * num_layers, N, H_out) containing the final hidden state for each element in the sequence. - `c_n`: tensor of shape (D * num_layers, N, H_cell) containing the final cell state for each element in the sequence. Notation: - L = sequence length. - N = batch size. - D = 2 if bidirectional=True otherwise 1 - H_in = input_size (# of expected features in the input x) - H_cell = hidden_size (# of features in the hidden state h) - H_out = hidden_size This initial implementation only supports L = 1 & N = 1. References - PyTorch Docs > torch.nn > [LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) - Dive into Deep Learning > [9.2. Long Short-Term Memory (LSTM)](https://d2l.ai/chapter_recurrent-modern/lstm.html) - [Long Short-Term Memory: From Zero to Hero with PyTorch](https://blog.floydhub.com/long-short-term-memory-from-zero-to-hero-with-pytorch/) Differential Revision: D36882986 Pull Request resolved: https://github.com/pytorch/pytorch/pull/78943 Approved by: https://github.com/SS-JIA
Committer
Parents
Loading