pytorch
061e71b1 - Parametrizations depending on several inputs (#58488)

Commit
3 years ago
Parametrizations depending on several inputs (#58488) Summary: Makes possible that the first register parametrization depends on a number of parameters rather than just one. Examples of these types of parametrizations are `torch.nn.utils.weight_norm` and low rank parametrizations via the multiplication of a `n x k` tensor by a `k x m` tensor with `k <= m, n`. Follows the plan outlined in https://github.com/pytorch/pytorch/pull/33344#issuecomment-768574924. A short summary of the idea is: we call `right_inverse` when registering a parametrization to generate the tensors that we are going to save. If `right_inverse` returns a sequence of tensors, then we save them as `original0`, `original1`... If it returns a `Tensor` or a sequence of length 1, we save it as `original`. We only allow to have many-to-one parametrizations in the first parametrization registered. The next parametrizations would need to be one-to-one. There were a number of choices in the implementation: If the `right_inverse` returns a sequence of parameters, then we unpack it in the forward. This is to allow to write code as: ```python class Sum(nn.Module): def forward(self, X, Y): return X + Y def right_inverse(Z): return Z, torch.zeros_like(Z) ``` rather than having to unpack manually a list or a tuple within the `forward` function. At the moment the errors are a bit all over the place. This is to avoid having to check some properties of `forward` and `right_inverse` when they are registered. I left this like this for now, but I believe it'd be better to call these functions when they are registered to make sure the invariants hold and throw errors as soon as possible. The invariants are the following: 1. The following code should be well-formed ```python X = module.weight Y = param.right_inverse(X) assert isinstance(Y, Tensor) or isinstance(Y, collections.Sequence) Z = param(Y) if isisntance(Y, Tensor) else param(*Y) ``` in other words, if `Y` is a `Sequence` of `Tensor`s (we check also that the elements of the sequence are Tensors), then it is of the same length as the number parameters `param.forward` accepts. 2. Always: `X.dtype == Z.dtype and X.shape == Z.shape`. This is to protect the user from shooting themselves in the foot, as it's too odd for a parametrization to change the metadata of a tensor. 3. If it's one-to-one: `X.dtype == Y.dtype`. This is to be able to do `X.set_(Y)` so that if a user first instantiates the optimiser and then puts the parametrisation, then we reuse `X` and the user does not need to add a new parameter to the optimiser. Alas, this is not possible when the parametrisation is many-to-one. The current implementation of `spectral_norm` and `weight_norm` does not seem to care about this, so this would not be a regression. I left a warning in the documentation though, as this case is a bit tricky. I'm still missing to go over the formatting of the documentation, I'll do that tomorrow. Pull Request resolved: https://github.com/pytorch/pytorch/pull/58488 Reviewed By: soulitzer Differential Revision: D29100708 Pulled By: albanD fbshipit-source-id: b9e91f439cf6b5b54d5fa210ec97c889efb9da38
Author
Parents
Loading