[ONNX] RNN scripting (#57564) (#58691)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58691
Note the first commit in this PR has its own pull request here since it seemed self-contained:
https://github.com/pytorch/pytorch/pull/57082
* [ONNX] simplify batch_first logic in RNN tests
* [ONNX] support GRU with packed input in scripting mode
This required two changes:
* Add as_tensor to symbolic_opset9.py
* Change torch::jit::pushPackingPastRnn to recognize and properly
replace another use of the batch_sizes output of prim::PackPadded.
Previously the code assumed that the first use was as input to the
RNN operator. However in some cases, it is also used to compute
max_batch_size. For example in this code:
https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815
With these changes the GRU tests now pass in scripting mode for opset
version >= 11.
Test Plan: Imported from OSS
Reviewed By: driazati
Differential Revision: D28714805
Pulled By: SplitInfinity
fbshipit-source-id: f19647a04533d9ec76399a8793b3f712ea0337d2
Co-authored-by: Gary Miguel <garymiguel@microsoft.com>