inductor: enable weight prepack for LSTM (#103071)
- Enabled LSTM weight prepack in inductor.
- Added a mkldnn decomposition for lstm which won't change for different `seq_lens`. With the previous decomposition, for dynamic shapes use case where `seq_lens` changes, the graph will be different.
- Extended several inductor utility functions to support `List(Tensor`) as input. Previously those functions only supported `Tensor` input.
**Update 2023-07-26:**
- https://github.com/pytorch/pytorch/pull/103851 has moved CPU weight packing to be after AOTAutograd. Fixed the support in this PR to follow the same way (mainly in https://github.com/pytorch/pytorch/commit/3b207f7f1c91aee224324b5baab83b38dbd5dae6#diff-6dffed1ade0ba3e887f9a4eafa3bfcec267ab2365b8adcb91bd391f49b3fd2e3).
LSTM is decomposed in `aten.mkldnn_rnn_layer` by layer and by direction. The weight prepack is done at the `mkldnn_rnn_layer` level.
- Add a fix in rnn `__get_state__` function in case we need to recompile an `LSTM` module.
When compiling the module, the weights tensors which are the `named_parameters` of the module are converted to `functional_tensor` here:
https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/nn/utils/stateless.py#L125-L128
The forward function of LSTM will be called:
https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_functorch/aot_autograd.py#L3379-L3381
In the forward function, the `_flat_weights` are updated to be the same as the weights, thus becoming `functional_tensor`:
https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/nn/modules/rnn.py#L775-L778
The weights tensors are converted back to the original tensors (which are not `functional_tensor` anymore) before exiting the `_reparametrize_module` context here:
https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/nn/utils/stateless.py#L130-L142
But since `_flat_weights` is not in the `named_parameters` of the module, it's still `functional_tensor` ([link of the parameters that will be converted to functional and reverted back](https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_functorch/aot_autograd.py#L3695-L3698)).
At this moment, if we need to recompile the model, `deepcopy` will be called:
https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_dynamo/utils.py#L915-L917
And it will report `UnImplemented` since we have `functional_tensor` (`_flat_weights`) and will trigger graph break which is not what we expect:
https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_subclasses/meta_utils.py#L514
Added a fix in the `__get_state__` to update the `_flat_weights` if ever weights have changed to fix this issue. The fix is covered in the `test_lstm_packed` UT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103071
Approved by: https://github.com/jgong5, https://github.com/jansel