Parametrizations depending on several inputs (#58488)
Summary:
Makes it possible for the first registered parametrization to depend on several inputs rather than just one. Examples of these kinds of parametrizations are `torch.nn.utils.weight_norm` and low-rank parametrizations via the multiplication of an `n x k` tensor by a `k x m` tensor with `k <= m, n`.
Follows the plan outlined in https://github.com/pytorch/pytorch/pull/33344#issuecomment-768574924. A short summary of the idea is: we call `right_inverse` when registering a parametrization to generate the tensors that we are going to save. If `right_inverse` returns a sequence of tensors, then we save them as `original0`, `original1`... If it returns a `Tensor` or a sequence of length 1, we save it as `original`.
Many-to-one parametrizations are only allowed as the first parametrization registered on a given tensor; any subsequent parametrizations must be one-to-one (see the sketch below).
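As a rough illustration of what this enables, here is a minimal sketch based on the low-rank example above. The `LowRank` and `Scale` modules and the SVD-based `right_inverse` are illustrative choices, not part of this PR:
```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class LowRank(nn.Module):
    """weight = U @ V, with U of size n x k and V of size k x m (k <= m, n)."""
    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, U, V):
        return U @ V

    def right_inverse(self, W):
        # Crude rank-k factorization via a truncated SVD (illustrative only)
        U, S, Vh = torch.linalg.svd(W)
        return U[:, :self.k] * S[:self.k], Vh[:self.k, :]

class Scale(nn.Module):
    """A one-to-one parametrization that may be registered after LowRank."""
    def forward(self, W):
        return 2.0 * W

    def right_inverse(self, W):
        return W / 2.0

linear = nn.Linear(5, 4)  # weight has shape 4 x 5
parametrize.register_parametrization(linear, "weight", LowRank(k=2))
# right_inverse returned two tensors, so they are saved as original0 and original1
print(linear.parametrizations.weight.original0.shape)  # torch.Size([4, 2])
print(linear.parametrizations.weight.original1.shape)  # torch.Size([2, 5])
# Any parametrization registered after the first one must be one-to-one
parametrize.register_parametrization(linear, "weight", Scale())
```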
There were a number of choices in the implementation:
If `right_inverse` returns a sequence of tensors, then we unpack it in the `forward`. This allows writing code such as:
```python
import torch
import torch.nn as nn

class Sum(nn.Module):
    def forward(self, X, Y):
        return X + Y

    def right_inverse(self, Z):
        return Z, torch.zeros_like(Z)
```
rather than having to manually unpack a list or a tuple within the `forward` function.
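A possible usage sketch, assuming the imports and the `Sum` module from the snippet above (the module and shapes are illustrative):
```python
import torch.nn.utils.parametrize as parametrize

linear = nn.Linear(3, 4)
parametrize.register_parametrization(linear, "weight", Sum())
# right_inverse returned two tensors, so they are stored as original0 and original1
original0 = linear.parametrizations.weight.original0
original1 = linear.parametrizations.weight.original1
# Accessing linear.weight recomputes Sum.forward(original0, original1)
assert torch.allclose(linear.weight, original0 + original1)
```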
At the moment, the errors are raised a bit all over the place. This is to avoid having to check some properties of `forward` and `right_inverse` when they are registered. I left it like this for now, but I believe it would be better to call these functions at registration time to make sure the invariants hold and throw errors as soon as possible.
The invariants are the following:
1. The following code should be well-formed
```python
X = module.weight
Y = param.right_inverse(X)
assert isinstance(Y, Tensor) or isinstance(Y, collections.abc.Sequence)
Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
```
in other words, if `Y` is a `Sequence` of `Tensor`s (we also check that the elements of the sequence are `Tensor`s), then it must have the same length as the number of parameters that `param.forward` accepts.
2. Always: `X.dtype == Z.dtype and X.shape == Z.shape`. This is to protect users from shooting themselves in the foot, as it is very unusual for a parametrization to change the metadata of a tensor.
3. If the parametrization is one-to-one: `X.dtype == Y.dtype`. This makes it possible to do `X.set_(Y)`, so that if a user first instantiates the optimizer and then registers the parametrization, we reuse `X` and the user does not need to add a new parameter to the optimizer. Alas, this is not possible when the parametrization is many-to-one. The current implementations of `spectral_norm` and `weight_norm` do not seem to care about this, so this would not be a regression. I left a warning in the documentation though, as this case is a bit tricky; see the sketch below.
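A minimal sketch of the scenario that invariant 3 is meant to support, assuming a simple one-to-one parametrization (the `Symmetric` module below is illustrative, not part of this PR):
```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    # One-to-one: both forward and right_inverse preserve shape and dtype
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)

    def right_inverse(self, A):
        return A.triu()

linear = nn.Linear(3, 3)
opt = torch.optim.SGD(linear.parameters(), lr=0.1)  # optimizer created first
weight_before = linear.weight
parametrize.register_parametrization(linear, "weight", Symmetric())
# Since right_inverse returns a tensor with X's dtype, X.set_(Y) can reuse the
# original Parameter, so the optimizer keeps tracking the same tensor
print(linear.parametrizations.weight.original is weight_before)  # expected: True
```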
I still need to go over the formatting of the documentation; I'll do that tomorrow.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58488
Reviewed By: soulitzer
Differential Revision: D29100708
Pulled By: albanD
fbshipit-source-id: b9e91f439cf6b5b54d5fa210ec97c889efb9da38