[caffe2] fix bug when weight_decay is used with fused rowwise + SLWS grad (#57090)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57090
We applied loop-invariant code motion to avoid multiplying by in_weight_temp for each element, but this breaks down when weight decay is nonzero.
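A minimal sketch of why the hoisting fails, not the actual kernel code. It assumes the per-element gradient has the form in_weight * g + weight_decay * w and that the rowwise update needs the mean of its squares over the row. The function names and arguments are hypothetical, chosen for illustration only.

```python
def rowwise_mean_sq_exact(grad_out, row, in_weight, weight_decay):
    """Mean squared gradient over one row, scaling each element by in_weight.

    Assumed gradient form: x_j = in_weight * g_j + weight_decay * w_j.
    """
    total = 0.0
    for g, w in zip(grad_out, row):
        x = in_weight * g + weight_decay * w
        total += x * x
    return total / len(row)


def rowwise_mean_sq_hoisted(grad_out, row, in_weight, weight_decay):
    """Same quantity with in_weight hoisted out of the loop.

    This is only equivalent when weight_decay == 0, because then
    sum((in_weight * g_j)^2) == in_weight^2 * sum(g_j^2).
    """
    total = 0.0
    for g, w in zip(grad_out, row):
        x = g + weight_decay * w  # in_weight is not applied per element
        total += x * x
    return in_weight * in_weight * total / len(row)


# Illustrative values only.
grad_out = [1.0, 2.0]
row = [0.5, -1.0]
in_weight = 3.0

# The two versions agree without weight decay and disagree with it.
agree_without_decay = (
    rowwise_mean_sq_exact(grad_out, row, in_weight, 0.0)
    == rowwise_mean_sq_hoisted(grad_out, row, in_weight, 0.0)
)
agree_with_decay = (
    rowwise_mean_sq_exact(grad_out, row, in_weight, 0.5)
    == rowwise_mean_sq_hoisted(grad_out, row, in_weight, 0.5)
)
```

Under these assumptions the hoisted version wrongly scales the weight-decay term by in_weight as well, so the multiplication has to stay inside the loop whenever weight decay is nonzero.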
Test Plan:
On a devgpu:
buck test mode/dev-nosan //caffe2/caffe2/fb/net_transforms/tests:fuse_sparse_ops_test -- test_fuse_sparse_adagrad_with_sparse_lengths_weighted_sum_gradient --run-disabled
Reviewed By: jianyuh
Differential Revision: D28051026
fbshipit-source-id: f8906b72a41a87c2d43c447197b5fd695373ae23