[AO] Callable norm function for sparsifier (#85236)
The `WeightNormSparsifier` currently supports only the L2-norm. This change lets users specify the function that is applied to compute the norm. In addition, the L1-norm is added, implemented as an `.abs` function.
## Implementation details
- The functions referred to as "norms" are not norms in the strict sense. For example, the L2-norm of `x` is computed as `F.avg_pool(x * x, ...)`, and the L1-norm of `x` as `F.avg_pool(x.abs(), ...)`.
- A callable passed as the norm must fit the same pattern: the sparsifier computes `F.avg_pool(norm_fn(x), ...)`, so `norm_fn` should return the per-element values to be averaged rather than a reduced scalar.
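The averaging pattern described above can be illustrated with a small dependency-free sketch. Plain Python lists stand in for tensors, and `block_scores` is a hypothetical helper written only for this illustration; it is not part of the sparsifier, and the real pooling op may treat edge blocks differently.

```python
from typing import Callable, List, Tuple

Matrix = List[List[float]]


def block_scores(weights: Matrix,
                 norm_fn: Callable[[float], float],
                 block: Tuple[int, int] = (1, 2)) -> Matrix:
    """Average norm_fn(w) over non-overlapping blocks (hypothetical helper)."""
    bh, bw = block
    rows, cols = len(weights), len(weights[0])
    scores = []
    for r in range(0, rows, bh):
        row_scores = []
        for c in range(0, cols, bw):
            # Apply norm_fn to each element of this block, clipping the
            # block at the matrix edge, then take the mean.
            vals = [norm_fn(weights[i][j])
                    for i in range(r, min(r + bh, rows))
                    for j in range(c, min(c + bw, cols))]
            row_scores.append(sum(vals) / len(vals))
        scores.append(row_scores)
    return scores


weights = [[1.0, -3.0, 0.0, 2.0],
           [-2.0, 2.0, 4.0, 0.0]]

# "L2": mean of squares per block -> [[5.0, 2.0], [4.0, 8.0]]
l2_scores = block_scores(weights, lambda w: w * w)
# "L1": mean of absolute values per block -> [[2.0, 1.0], [2.0, 2.0]]
l1_scores = block_scores(weights, abs)
```

In this toy input, three blocks tie under the L1 variant but are ranked differently under the L2 variant, which is why the choice of `norm_fn` can change which blocks would be considered least important.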
## Example:
```python
>>> # L3-norm
>>> l3 = lambda T: T.abs() ** 3
>>> sparsifier = WeightNormSparsifier(norm=l3)
>>>
>>> # L0-norm
>>> l0 = lambda T: torch.logical_or(torch.zeros(T.shape), T != 0).to(T.dtype)
>>> sparsifier = WeightNormSparsifier(norm=l0)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85236
Approved by: https://github.com/jcaip