pytorch
79dc500a - Add error message for sequence length to be equal to 0 case for RNNs (#60269)

Commit
4 years ago
Add error message for sequence length to be equal to 0 case for RNNs (#60269) Summary: Fixes #https://github.com/pytorch/pytorch/issues/50192 It has been discussed in the issue that, currently RNN apis do not support inputs with `seq_len=0` and the error message does not reflect this issue clearly. This PR is suggesting a solution to this issue, by adding a more clear error message that, none of RNN api (nn.RNN, nn.GRU and nn.LSTM) do not support `seq_len=0` for neither one-directional nor bi-directional layers. ``` import torch input_size = 5 hidden_size = 6 rnn = torch.nn.GRU(input_size, hidden_size) for seq_len in reversed(range(4)): output, h_n = rnn(torch.zeros(seq_len, 10, input_size)) print('{}, {}'.format(output.shape, h_n.shape)) ``` Previously was giving output as : ``` torch.Size([3, 10, 6]), torch.Size([1, 10, 6]) torch.Size([2, 10, 6]), torch.Size([1, 10, 6]) torch.Size([1, 10, 6]), torch.Size([1, 10, 6]) Traceback (most recent call last): File "test.py", line 8, in <module> output, h_n = rnn(torch.zeros(seq_len, 10, input_size)) File "/opt/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) File "/opt/miniconda3/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 739, in forward result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers, RuntimeError: stack expects a non-empty TensorList ``` However, after adding this PR, this error message change for any combination of [RNN, GRU and LSTM] x [one-directional, bi-directional]. Let's illustrate the change with the following code snippet: ``` import torch input_size = 5 hidden_size = 6 rnn = torch.nn.LSTM(input_size, hidden_size, bidirectional=True) output, h_n = rnn(torch.zeros(0, 10, input_size)) ``` would give output as following: ``` Traceback (most recent call last): File "<stdin>", line 2, in <module> File "/fsx/users/iramazanli/pytorch/torch/nn/modules/module.py", line 1054, in _call_impl return forward_call(*input, **kwargs) File "/fsx/users/iramazanli/pytorch/torch/nn/modules/rnn.py", line 837, in forward result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers, RuntimeError: Expected sequence length to be larger than 0 in RNN ``` *********************************** The change for Packed Sequence didn't seem to be necessary because from the following code snippet error message looks clear about the issue: ``` import torch import torch.nn.utils.rnn as rnn_utils import torch.nn as nn packed = rnn_utils.pack_sequence([]) ``` returns: ``` Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/fsx/users/iramazanli/pytorch/torch/nn/utils/rnn.py", line 398, in pack_sequence return pack_padded_sequence(pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted) File "/fsx/users/iramazanli/pytorch/torch/nn/utils/rnn.py", line 363, in pad_sequence return torch._C._nn.pad_sequence(sequences, batch_first, padding_value) RuntimeError: received an empty list of sequences ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/60269 Reviewed By: mrshenli Differential Revision: D29299914 Pulled By: iramazanli fbshipit-source-id: 5ca98faa28d4e6a5a2f7600a30049de384a3b132
Author
Parents
Loading