Add option to automatically handle unsorted variable-length sequences in RNNs (#15225)
Summary:
Fixes #3584.
Motivation: manually sorting sequences, packing them, and then unsorting them
is something a lot of users have complained about doing, especially when we can
offer library support for them.
Overview: we internally sort sequences before packing them and store a list of
`unsorted_indices` that represent how to unsort the sequences inside
PackedSequence. The packing helper functions return PackedSequence with the
`permutation` field and the unpacking helper functions use it to unsort.
To implement this, the following changes were made:
- PackedSequence now keeps `sorted_indices` and `unsorted_indices`.
These two can be thought of as permutations and are inverses of each other.
`sorted_indices` is how the sequences were sorted; `unsorted_indices` is how
to unsort the sequences.
- Added an `enforce_sorted` argument to pack_sequence and pack_padded_sequence
that maintains the legacy behavior of error-ing out on unsorted-sequences.
When `enforce_sorted=True`, these functions maintain their ONNX exportability.
- pack_sequence(sequences, enforce_sorted) takes in unsorted sequences.
- pack_padded_sequence can take in a padded tensor that represents padded,
unsorted sequences.
- pad_packed_sequence unsorts the PackedSequence such that it is still the
inverse operation of packed_padded_sequence.
- RNNs apply `sort_indices` to their input hidden state and apply
`unsort_indices` to their output hidden state. This is to ensure that the
hidden state batches correspond to the user's ordering of input sequences.
NOT BC-Breaking
- The default for pack_sequence and pack_padded_sequence is
`enforce_sorted=True` to avoid breaking ONNX export. To use the new
functionality, pass in `enforce_sorted=False`
Testing Plan
- Modified TestNN.test_pack_sequence, TestNN.test_packed_padded_sequence,
and TestNN.test_variable_sequence (RNN test) to check the behavior
of unsorted sequences, sorted sequences, and sorted sequences with
enforce_sorted=True
- test/test_jit.py has a test to see if RNNs are exportable with
enforce_sorted=True
cc colesbury
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15225
Reviewed By: soumith
Differential Revision: D13507138
Pulled By: zou3519
fbshipit-source-id: b871dccd6abefffca81bc4e3efef1873faa242ef