[ao][sparsity] Vectorized WeightNormSparsifier (#80059)
The previous implementation was using loops to compute the sparsity within a block in a mask, as well as across the mask blocks. This implements the vectorized version.
## Vectorization:
A high level overview of the vectorization procedure falls into a two step process:
### Tensor-level masking
A tensor-level masking is a mask generation routine that has a granularity of `sparse_block_shape`. That means that only patches of that shape can be considered sparse/dense. To vectorize:
1. Reshape the data such that one of the dimensions represents the patches of sparse_block_shape.
2. Create a mask of the same shape as the reshaped data
3. Find the smallest `k` elements in the the data, given the dimension of the sparse "patches". `k` represents a derived paramter specifying the sparsity level.
4. Apply the 0/1 to the patches in the mask
5. Reshape the mask back to the original dimensions
Note: because the shape of the mask might not be multiple of the sparse_block_shape, we nudge the sshape of the mask, and truncate it afterwards.
## Block-level masking
A block-level masking is a mask generation routine that concerns itself only with sparsity within a patch of shape `sparse_block_shape`. This is useful when block sparsity allows partial block sparsification.
To vectorize:
Overall the block-level masking follows the same routine as the tensor-level algorithm described above. One distinction is that when reshaping the data/mask tensors we aim for creating a dimension that captures the internals of each patch. For example, if a `sparse_block_shape` is `(2, 2)`, we want to reshape the data/mask into `(2, 2, -1)`. That allows us to sort the internal elements on the last axis, and zero-out the ones that obey the sparse logic.
Differential Revision: [D37352494](https://our.internmc.facebook.com/intern/diff/D37352494/)
**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D37352494/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80059
Approved by: https://github.com/jerryzh168