Reimplement spectral_norm using new parametrization functionality (#57784)
Summary:
Adds a new file under `torch/nn/utils/parametrizations.py` which should contain all the parametrization implementations
For spectral_norm we add the `SpectralNorm` module which can be registered using `torch.nn.utils.parametrize.register_parametrization` or using a wrapper: `spectral_norm`, the same API the old implementation provided.
Most of the logic is borrowed from the old implementation:
- Just like the old implementation, there should be cases when retrieving the weight should perform another power iteration (thus updating the weight) and cases where it shouldn't. For example in eval mode `self.training=True`, we do not perform power iteration.
There are also some differences/difficulties with the new implementation:
- Using new parametrization functionality as-is there doesn't seem to be a good way to tell whether a 'forward' call was the result of parametrizations are unregistered (and leave_parametrizations=True) or when the injected property's getter was invoked. The issue is that we want perform power iteration in the latter case but not the former, but we don't have this control as-is. So, in this PR I modified the parametrization functionality to change the module to eval mode before triggering their forward call
- Updates the vectors based on weight on initialization to fix https://github.com/pytorch/pytorch/issues/51800 (this avoids silently update weights in eval mode). This also means that we perform twice any many power iterations by the first forward.
- right_inverse is just the identity for now, but maybe it should assert that the passed value already satisfies the constraints
- So far, all the old spectral_norm tests have been cloned, but maybe we don't need so much testing now that the core functionality is already well tested
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57784
Reviewed By: ejguan
Differential Revision: D28413201
Pulled By: soulitzer
fbshipit-source-id: e8f1140f7924ca43ae4244c98b152c3c554668f2