Explicit attribute setting for pruning and weight_norm upon reparam removal (#34170)
Summary:
To address one of the problems with RNNs that emerged in https://github.com/pytorch/pytorch/issues/33618, I modified the `remove` methods in `torch.nn.utils.prune` and `torch.nn.utils.weight_norm` to make an explicit call to `setattr`, which, in `rnn.py` directly modifies `_flat_weights` (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py#L96) to include the new element.
This is important so that `_flat_weights` can reflect the presence of the `Parameter` after the (pruning or weight norm) reparametrization is removed. Without this, the weight in `_flat_weights` would remain a tensor, as originally set by the reparametrization.
Simple testing is added, which depends on the current naming scheme for the LSTM module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34170
Differential Revision: D21265965
Pulled By: mickypaganini
fbshipit-source-id: 29de4a6b17052d42ccfe67c8560b7f83c20fd09d