pytorch
e23a9dc1 - [C++ API] RNN / GRU / LSTM layer refactoring (#34322)

Commit
4 years ago
[C++ API] RNN / GRU / LSTM layer refactoring (#34322) Summary: This PR refactors RNN / GRU / LSTM layers in C++ API to exactly match the implementation in Python API. **BC-breaking changes:** - Instead of returning `RNNOutput`, RNN / GRU forward method now returns `std::tuple<Tensor, Tensor>`, and LSTM forward method now returns `std::tuple<Tensor, std::tuple<Tensor, Tensor>>`, matching Python API. - RNN / LSTM / GRU forward method now accepts the same inputs (input tensor and optionally hidden state), matching Python API. - RNN / LSTM / GRU now has `forward_with_packed_input` method which accepts `PackedSequence` as input and optionally hidden state, matching the `forward(PackedSequence, ...)` variant in Python API. - In `RNNOptions` - `tanh()` / `relu()` / `activation` are removed. Instead, `nonlinearity` is added which takes either `torch::kTanh` or `torch::kReLU` - `layers` -> `num_layers` - `with_bias` -> `bias` - In `LSTMOptions` - `layers` -> `num_layers` - `with_bias` -> `bias` - In `GRUOptions` - `layers` -> `num_layers` - `with_bias` -> `bias` The majority of the changes in this PR focused on refactoring the implementations in `torch/csrc/api/src/nn/modules/rnn.cpp` to match the Python API. RNN tests are then changed to reflected the revised API design. Pull Request resolved: https://github.com/pytorch/pytorch/pull/34322 Differential Revision: D20311699 Pulled By: yf225 fbshipit-source-id: e2b60fc7bac64367a8434647d74c08568a7b28f7
Author
Will Feng
Parents
Loading