[Vulkan] Implement LSTM operator (#78943)
Summary:
Implemented LSTM operator for the Vulkan backend:
Inputs:
- `input_vk`: input tensor containing the features of the input sequence. It has shape (L, N, H_in) when batch_first=False or (N, L, H_in) when batch_first=True.
- `hx`: list of two tensors: `hx_vk` and `cx_vk`
- `hx_vk`: tensor of shape (D * num_layers, N, H_out) containing the initial hidden state for each element in the input sequence.
- `cx_vk`: tensor of shape (D * num_layers, N, H_cell) containing the initial cell state for each element in the input sequence.
- `params_cpu`: list of 4 * num_layers tensors containing the weights and biases: 2 weight tensors and 2 bias tensors per layer. See the [LSTM > Variables](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) section.
- `has_biases`: This initial implementation only supports has_biases=True.
- `num_layers`: Number of recurrent layers.
- `dropout`: This initial implementation only supports dropout=0.0.
- `train`: This initial implementation only supports train=False.
- `bidirectional`: This initial implementation only supports bidirectional=False.
- `batch_first`: This initial implementation only supports batch_first=True.
Outputs:
- `output`: tensor of shape (L, N, D * H_out) when batch_first=False or (N, L, D * H_out) when batch_first=True containing the output features (h_t) from the last layer of the LSTM, for each t.
- `h_n`: tensor of shape (D * num_layers, N, H_out) containing the final hidden state for each element in the sequence.
- `c_n`: tensor of shape (D * num_layers, N, H_cell) containing the final cell state for each element in the sequence.
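For reference, the shape relations above can be sketched as a small pure-Python helper. The function name and signature are made up for this sketch; it only encodes the documented formulas:

```python
def lstm_output_shapes(L, N, H_in, H_out, H_cell, num_layers,
                       bidirectional=False, batch_first=True):
    """Expected LSTM output shapes per the torch.nn.LSTM documentation."""
    D = 2 if bidirectional else 1
    output = (N, L, D * H_out) if batch_first else (L, N, D * H_out)
    h_n = (D * num_layers, N, H_out)
    c_n = (D * num_layers, N, H_cell)
    return output, h_n, c_n
```

With the constraints of this initial implementation (L = 1, N = 1, bidirectional=False, batch_first=True), all three outputs collapse to small tensors whose leading dimensions are 1 or num_layers.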
Notation:
- L = sequence length
- N = batch size
- D = 2 if bidirectional=True, otherwise 1
- H_in = input_size (number of expected features in the input x)
- H_cell = hidden_size (number of features in the hidden state h)
- H_out = hidden_size
This initial implementation only supports L = 1 and N = 1.
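Under the L = 1, N = 1 constraint, each layer reduces to a single LSTM cell step. Below is a minimal pure-Python sketch of that recurrence using the standard gate equations with PyTorch's i, f, g, o gate ordering; it illustrates the math only and is not the Vulkan shader code:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def lstm_cell_step(x, h, c, W_ih, W_hh, b_ih, b_hh):
    """One LSTM step for a single element (L = 1, N = 1) of one layer.

    W_ih: (4*H_out, H_in) input-hidden weights,
    W_hh: (4*H_out, H_out) hidden-hidden weights,
    b_ih, b_hh: (4*H_out,) biases; gates are ordered i, f, g, o.
    """
    H = len(h)
    # Pre-activations for all four gates stacked along the first axis.
    gates = [gi + gh + bi + bh for gi, gh, bi, bh in
             zip(matvec(W_ih, x), matvec(W_hh, h), b_ih, b_hh)]
    i = [sigmoid(v) for v in gates[0:H]]            # input gate
    f = [sigmoid(v) for v in gates[H:2 * H]]        # forget gate
    g = [math.tanh(v) for v in gates[2 * H:3 * H]]  # cell candidate
    o = [sigmoid(v) for v in gates[3 * H:4 * H]]    # output gate
    c_next = [fv * cv + iv * gv for fv, cv, iv, gv in zip(f, c, i, g)]
    h_next = [ov * math.tanh(cv) for ov, cv in zip(o, c_next)]
    return h_next, c_next
```

Stacking num_layers of these cells, with each layer's h_next fed as the next layer's input, yields the multi-layer behavior described above.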
References:
- PyTorch Docs > torch.nn > [LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html)
- Dive into Deep Learning > [9.2. Long Short-Term Memory (LSTM)](https://d2l.ai/chapter_recurrent-modern/lstm.html)
- [Long Short-Term Memory: From Zero to Hero with PyTorch](https://blog.floydhub.com/long-short-term-memory-from-zero-to-hero-with-pytorch/)
Differential Revision: D36882986
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78943
Approved by: https://github.com/SS-JIA