[pruning][core][feature] LSTM Structured Pruning prune_functions + pattern (#90801)
Summary:
This PR adds in support for LSTM Structured Pruning.
- Adds in LSTMSaliencyPruner, an implemented pruner that splits the packed weights, finds the appropriate mask for each piece individually based on saliency, and then combines to create an overall mask for the LSTM.
- Adds in pruning functions for LSTM pruning, which will split the weights, apply the masks, and then recombine the pruned weights. Works for both single and multiple-layer LSTMs.
Also added a basic pattern to the default set of of patterns for
LSTM -> Linear pruning
LSTM -> LayerNorm -> Linear pruning
Adds in test to check that LSTM pruning works, as well as for LSTMSaliencyPruner
Test Plan:
`python test/test_ao_sparsity.py -- TestBaseStructuredSparsifier.test_prune_lstm_linear_single_layer`
`python test/test_ao_sparsity.py -- TestBaseStructuredSparsifier.test_prune_lstm_linear_multiple_layer`
`python test/test_ao_sparsity.py -- TestBaseStructuredSparsifier.test_prune_lstm_layernorm_linear_single_layer`
`python test/test_ao_sparsity.py -- TestBaseStructuredSparsifier.test_prune_lstm_layernorm_linear_multiple_layer`
`python test/test_ao_sparsity.py -- TestSaliencyPruner.test_lstm_saliency_pruner_update_mask`
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: [D42199001](https://our.internmc.facebook.com/intern/diff/D42199001)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90801
Approved by: https://github.com/jerryzh168