Make nn.stateless correctly reset parameters if the forward pass fails (#81262) (#81262) (#81880)
Summary:
This bug came up as I was adding new tests for ExpandedWeights
If the forwards pass errors when the `_reparametrize_module` context manager is still on, the values from reparameterization will remain on the module outside of the context manager, where it should be the original values. This fixes that by putting a try/finally block around the forward call and call to reset the parameters
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81262
Approved by: https://github.com/zou3519
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/56d1c755185c32cdc058cda190edfe139b895410
Reviewed By: DanilBaibak
Differential Revision: D37813203
Pulled By: samdow
fbshipit-source-id: 9c32485c074b10b985b35d2d575c35f16337af5f
Co-authored-by: samdow (Meta Employee) <samdow@fb.com>