pytorch
7aeee284 - Parametrization Functionality (#33344)

Parametrization Functionality (#33344)

Summary:
Provides the implementation for feature request issue https://github.com/pytorch/pytorch/issues/28937. Adds the `Parametrization` functionality and implements `Pruning` on top of it.

It adds the `auto` mode, in which the parametrization is computed only once per forward pass. The previous implementation recomputed the pruning on every forward, which is not optimal when pruning RNNs, for example.

It implements a caching mechanism for parameters, following the approach proposed at the end of the discussion in https://github.com/pytorch/pytorch/issues/7313. In particular, it assumes that the user will not manually change the updated parameters between the call to `backward()` and `optimizer.step()`. If they do, they need to manually call the `.invalidate()` function provided by the implementation. This could be turned into a function that takes a model and invalidates all the parameters in it; it might also have to be called from `.cuda()`, `.to()`, and related functions.

As described in https://github.com/pytorch/pytorch/issues/7313, this could be used to implement the `weight_norm` and `spectral_norm` functions in a cleaner way. It also allows, as described in https://github.com/pytorch/pytorch/issues/28937, implementing constrained optimization on manifolds (i.e. orthogonal constraints, positive definite matrices, invertible matrices, weights on the sphere or in hyperbolic space...).

TODO (once the implementation is validated):
- More thorough tests
- Documentation

Resolves https://github.com/pytorch/pytorch/issues/28937

albanD

Pull Request resolved: https://github.com/pytorch/pytorch/pull/33344

Reviewed By: zhangguanheng66

Differential Revision: D26816708

Pulled By: albanD

fbshipit-source-id: 07c8f0da661f74e919767eae31335a9c60d9e8fe
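A minimal sketch of how the parametrization mechanism described above can be used, written against the `torch.nn.utils.parametrize` module that grew out of this work. The `Symmetric` and `PruneByMask` modules are illustrative only, and the caching entry point shown (`parametrize.cached()`) is an assumption about the released API; the `auto` mode and `.invalidate()` names in the commit message refer to the proposal at the time of this commit.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# A parametrization is just an nn.Module whose forward() maps the raw,
# unconstrained tensor to the constrained parameter the layer actually uses.
class Symmetric(nn.Module):
    def forward(self, X):
        # Build a symmetric matrix from the upper triangle of the raw weight.
        return X.triu() + X.triu(1).transpose(-1, -2)

# A pruning-style parametrization: multiply the weight by a fixed binary mask.
class PruneByMask(nn.Module):
    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, X):
        return X * self.mask

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", Symmetric())

# layer.weight is now recomputed from the raw tensor on access.
assert torch.allclose(layer.weight, layer.weight.T)

pruned = nn.Linear(4, 4)
mask = (torch.rand(4, 4) > 0.5).float()
parametrize.register_parametrization(pruned, "weight", PruneByMask(mask))

# Inside this context the parametrizations are evaluated once and cached,
# instead of being recomputed on every access (useful e.g. for RNNs).
with parametrize.cached():
    out = layer(torch.randn(2, 4)) + pruned(torch.randn(2, 4))
```

As a usage note, constrained optimization on manifolds works the same way: the optimizer updates the raw tensor, while the parametrization's forward() keeps the exposed weight on the constraint set (orthogonal, positive definite, etc.).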