pytorch
0336308b - [AO] Callable norm function for sparsifier (#85236)

Commit
2 years ago
[AO] Callable norm function for sparsifier (#85236) The `WeightNormSparsifier` currently only supports L2-norm. This allows the users specify the function that is applied to compute the norm. In addition, L1-norm is also added, as an `.abs` function. ## Implementation details - The functions that are referred to as "norms", are not strictly such. For example, L2-norm of `x` is computed as `F.avg_pool(x * x, ...)`. Similarly, L1-norm of `x` is computed as `F.avg_pool(x.abs(), ...)`. - When passing callable functions for the norm, the above assumption must hold: `F.avg_pool(norm_fn(x), ...)` will be applied. ## Example: ```python >>> # L3-norm >>> l3 = lambda T: T * T * T >>> sparsifier = WeightNormSparsifier(norm=l3) >>> >>> # L0-norm >>> l0 = lambda T: (torch.logical_or(torch.zeros(T.shape), T != 0).to(T.dtype) >>> sparsifier = WeightNormSparsifier(norm=l0) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/85236 Approved by: https://github.com/jcaip
Author
Committer
Parents
Loading