pytorch
96c745df - Fix int() casting in torch.nn.RNN to have correctly traced JIT and ONNX graph. (#92970)

Commit
1 year ago
Fix int() casting in torch.nn.RNN to have correctly traced JIT and ONNX graph. (#92970) Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com> Fixes #91351 As for unit tests - in this PR I only fixed LSTM unit test to properly use dynamic axes and expose export issue by running test with same ONNX for additional inputs. If the changes approved, we should also fix the rest of the tests (RNN/GRU and beyond). I have verified the following updated tests are working with new code and failing with the old code: test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset_version_14_is_script_False_keep_initializers_as_inputs_True::test_rnn_name_lstm_nonlinearity_None_unilayer_bidirectional_no_initial_state_with_variable_length_sequences_with_dropout test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset_version_14_is_script_False_keep_initializers_as_inputs_True::test_rnn_name_lstm_nonlinearity_None_unilayer_bidirectional_with_initial_state_with_variable_length_sequences_with_dropout Pull Request resolved: https://github.com/pytorch/pytorch/pull/92970 Approved by: https://github.com/titaiwangms, https://github.com/kit1980
Author
Committer
Parents
Loading