[Gradient Compression] Replace torch.sqrt(torch.sum(col ** 2)) by torch.norm() (#51629)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51629
Leverage existing utility functions as much as possible for a potential performance gain.
Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 120919883
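
For reference, a minimal sketch of the substitution (illustrative only; `col` here is a stand-in tensor, not the actual low-rank factor column handled inside the PowerSGD hook):
```
import torch

# `col` stands in for a column of the matrix orthogonalized during PowerSGD
# compression; the real hook operates on tensors produced inside the hook.
col = torch.randn(1024)

# Before: manual L2 norm, which materializes the intermediate `col ** 2` tensor.
norm_before = torch.sqrt(torch.sum(col ** 2))

# After: a single call to the built-in norm utility.
norm_after = torch.norm(col)

assert torch.allclose(norm_before, norm_after)
```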
Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl
No performance regression:
f248664994 (with `torch.norm()`):
```
total:
  32 GPUs -- 32 GPUs:
    p25: 1.050  30/s  (batch size 32)
    p50: 1.230  26/s  (batch size 32)
    p75: 1.449  22/s  (batch size 32)
    p90: 1.611  19/s  (batch size 32)
    p95: 1.702  18/s  (batch size 32)
backward:
  32 GPUs -- 32 GPUs:
    p25: 0.769  41/s  (batch size 32)
    p50: 0.920  34/s  (batch size 32)
    p75: 1.139  28/s  (batch size 32)
    p90: 1.322  24/s  (batch size 32)
    p95: 1.440  22/s  (batch size 32)
```
f248678690 (without `torch.norm()`, baseline):
```
total:
  32 GPUs -- 32 GPUs:
    p25: 1.056  30/s  (batch size 32)
    p50: 1.249  25/s  (batch size 32)
    p75: 1.443  22/s  (batch size 32)
    p90: 1.608  19/s  (batch size 32)
    p95: 1.711  18/s  (batch size 32)
backward:
  32 GPUs -- 32 GPUs:
    p25: 0.777  41/s  (batch size 32)
    p50: 0.939  34/s  (batch size 32)
    p75: 1.127  28/s  (batch size 32)
    p90: 1.322  24/s  (batch size 32)
    p95: 1.448  22/s  (batch size 32)
```
Reviewed By: pritamdamania87
Differential Revision: D26219835
fbshipit-source-id: 31d8ad3401d4efced4a6069f4f1e169ea3372697