Fixes moving a module after weight norm application (#32563)
Summary:
This PR updates how RNNs handle their "flat weights." In particular, it allows only some flat weights to be "materialized" when apply is called, and it updates the flattening behavior so that flattening occurs only if all flat weights are (1) materialized, (2) of the same dtype, and (3) acceptable to cuDNN.
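A minimal sketch of that gating condition, for illustration only: the helper `_flattening_allowed` below is hypothetical and not PyTorch's actual code (the real check lives inside the RNN module's `flatten_parameters` logic), but it shows the three conditions listed above.

```python
import torch

def _flattening_allowed(flat_weights):
    # Hypothetical helper illustrating the gating described above: flatten only
    # when every flat weight is a materialized tensor, all weights share one
    # dtype, and each is acceptable to cuDNN.
    if not flat_weights or not all(isinstance(w, torch.Tensor) for w in flat_weights):
        return False  # some weights are still placeholders / not materialized
    dtype = flat_weights[0].dtype
    return all(
        w.dtype == dtype and w.is_cuda and torch.backends.cudnn.is_acceptable(w)
        for w in flat_weights
    )
```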
One test is modified and another is added to cover these changes. One practical effect of the change is that weight norm can now be applied to a module BEFORE that module is moved to an accelerator; previously, doing so would throw an error.
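A hedged usage sketch of that scenario (the module sizes and the `weight_hh_l0` parameter name are illustrative choices, not taken from the PR's tests):

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

# Apply weight norm to one of the LSTM's recurrent weights *before* moving the
# module; with this change the subsequent move no longer raises an error.
lstm = weight_norm(nn.LSTM(input_size=8, hidden_size=16), name='weight_hh_l0')

if torch.cuda.is_available():
    lstm = lstm.cuda()  # flattening is deferred until all flat weights qualify

x = torch.randn(5, 3, 8, device=next(lstm.parameters()).device)
output, _ = lstm(x)
```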
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32563
Differential Revision: D19602725
Pulled By: mruberry
fbshipit-source-id: d8f9441d17815c8c9ba15b256d4be36f784a3cf9