[sparsity] Fix for accumulation bug in WeightNormSparsifier (#65293)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65293
This fixes a bug in the WeightNormSparsifier, where the mask is being multiplied by the newly computed mask.
Because the mask elements are binary 0/1, this accumulates the mask over every iteration, eventually collapsing the mask to zero.
This bug accidentally bled through from old versions.
Test Plan: Imported from OSS
Reviewed By: gchanan
Differential Revision: D31186829
Pulled By: z-a-f
fbshipit-source-id: 3f5b2c833148ab0bd8084e7410ce398f1252e65e