pytorch
d101816f - [ONNX] RNN scripting (#57564) (#58691)

Commit
3 years ago
[ONNX] RNN scripting (#57564) (#58691) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58691 Note the first commit in this PR has its own pull request here since it seemed self-contained: https://github.com/pytorch/pytorch/pull/57082 * [ONNX] simplify batch_first logic in RNN tests * [ONNX] support GRU with packed input in scripting mode This required two changes: * Add as_tensor to symbolic_opset9.py * Change torch::jit::pushPackingPastRnn to recognize and properly replace another use of the batch_sizes output of prim::PackPadded. Previously the code assumed that the first use was as input to the RNN operator. However in some cases, it is also used to compute max_batch_size. For example in this code: https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815 With these changes the GRU tests now pass in scripting mode for opset version >= 11. Test Plan: Imported from OSS Reviewed By: driazati Differential Revision: D28714805 Pulled By: SplitInfinity fbshipit-source-id: f19647a04533d9ec76399a8793b3f712ea0337d2 Co-authored-by: Gary Miguel <garymiguel@microsoft.com>
Author
Parents
Loading