clip_grad_norm can use fast foreach path for inf norm (#120623)
Now that foreach_norm supports inf, we should not special case it.
For a mere 256 parameters, we get a win of 30ms in CPU time and ~800us -> 300us decrease in CUDA time. This win is only bigger for more parameters.
New profile:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (bf1c0490|REBASE-i|detached HEAD)]$ python playground2.py
STAGE:2024-02-26 13:14:10 395517:395517 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:14:11 395517:395517 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:14:11 395517:395517 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
cudaLaunchKernel 67.01% 102.262ms 67.01% 102.262ms 5.382ms 2.000us 0.66% 2.000us 0.105us 19
aten::linalg_vector_norm 0.20% 311.000us 23.44% 35.776ms 35.776ms 3.000us 0.99% 3.000us 3.000us 1
aten::to 0.79% 1.208ms 14.62% 22.311ms 86.143us 0.000us 0.00% 263.000us 1.015us 259
aten::clamp 0.12% 182.000us 13.96% 21.303ms 21.303ms 1.000us 0.33% 1.000us 1.000us 1
aten::_to_copy 2.38% 3.628ms 13.83% 21.103ms 163.589us 0.000us 0.00% 263.000us 2.039us 129
aten::_foreach_norm 4.71% 7.185ms 13.54% 20.659ms 10.329ms 19.000us 6.29% 23.000us 11.500us 2
aten::add 0.14% 211.000us 10.86% 16.580ms 16.580ms 1.000us 0.33% 1.000us 1.000us 1
aten::stack 3.11% 4.744ms 9.59% 14.642ms 14.642ms 0.000us 0.00% 6.000us 6.000us 1
aten::copy_ 5.71% 8.721ms 9.27% 14.152ms 109.705us 258.000us 85.43% 263.000us 2.039us 129
aten::reciprocal 0.13% 193.000us 7.93% 12.100ms 12.100ms 1.000us 0.33% 1.000us 1.000us 1
aten::cat 0.67% 1.017ms 4.67% 7.129ms 7.129ms 6.000us 1.99% 6.000us 6.000us 1
aten::zeros 0.05% 79.000us 4.46% 6.800ms 3.400ms 0.000us 0.00% 2.000us 1.000us 2
aten::zero_ 0.05% 79.000us 4.28% 6.537ms 3.268ms 0.000us 0.00% 2.000us 1.000us 2
aten::fill_ 0.09% 131.000us 4.23% 6.458ms 3.229ms 2.000us 0.66% 2.000us 1.000us 2
aten::_foreach_mul_ 1.56% 2.377ms 3.86% 5.896ms 2.948ms 10.000us 3.31% 10.000us 5.000us 2
aten::empty 3.55% 5.414ms 3.55% 5.414ms 20.984us 0.000us 0.00% 0.000us 0.000us 258
aten::empty_strided 2.18% 3.323ms 2.18% 3.323ms 25.760us 0.000us 0.00% 0.000us 0.000us 129
aten::detach 0.85% 1.302ms 2.10% 3.199ms 12.496us 0.000us 0.00% 0.000us 0.000us 256
cudaDeviceEnablePeerAccess 2.01% 3.069ms 2.01% 3.069ms 1.534ms 0.000us 0.00% 0.000us 0.000us 2
aten::unsqueeze 1.24% 1.899ms 1.81% 2.769ms 10.816us 0.000us 0.00% 0.000us 0.000us 256
detach 1.24% 1.897ms 1.24% 1.897ms 7.410us 0.000us 0.00% 0.000us 0.000us 256
cudaMemcpyAsync 1.01% 1.539ms 1.01% 1.539ms 11.930us 0.000us 0.00% 0.000us 0.000us 129
aten::as_strided 0.58% 881.000us 0.58% 881.000us 3.428us 0.000us 0.00% 0.000us 0.000us 257
cudaStreamWaitEvent 0.35% 540.000us 0.35% 540.000us 2.093us 0.000us 0.00% 0.000us 0.000us 258
cudaEventRecord 0.18% 278.000us 0.18% 278.000us 1.078us 5.000us 1.66% 5.000us 0.019us 258
aten::mul 0.08% 125.000us 0.09% 138.000us 138.000us 1.000us 0.33% 1.000us 1.000us 1
cudaDeviceSynchronize 0.01% 13.000us 0.01% 13.000us 6.500us 0.000us 0.00% 0.000us 0.000us 2
cudaDeviceCanAccessPeer 0.00% 5.000us 0.00% 5.000us 2.500us 0.000us 0.00% 0.000us 0.000us 2
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 2.000us 0.66% 2.000us 1.000us 2
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 13.000us 4.30% 13.000us 3.250us 4
void at::native::lpnorm_cleanup<float, (at::native::... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 1.99% 6.000us 3.000us 2
Memcpy PtoP (Device -> Device) 0.00% 0.000us 0.00% 0.000us 0.000us 258.000us 85.43% 258.000us 2.000us 129
void at::native::(anonymous namespace)::CatArrayBatc... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 1.99% 6.000us 3.000us 2
void at::native::reduce_kernel<512, 1, at::native::R... 0.00% 0.000us 0.00% 0.000us 0.000us 3.000us 0.99% 3.000us 3.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 10.000us 3.31% 10.000us 2.500us 4
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 152.613ms
Self CUDA time total: 302.000us
```
Compared to on main:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5a0a9644)]$ python playground2.py
STAGE:2024-02-26 13:09:56 285045:285045 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:09:57 285045:285045 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:09:57 285045:285045 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
cudaLaunchKernel 61.42% 113.375ms 61.42% 113.375ms 424.625us 45.000us 5.66% 45.000us 0.169us 267
aten::linalg_vector_norm 14.04% 25.909ms 37.67% 69.534ms 271.617us 514.000us 64.65% 559.000us 2.184us 256
aten::to 0.78% 1.433ms 12.87% 23.751ms 91.703us 0.000us 0.00% 278.000us 1.073us 259
aten::_to_copy 2.02% 3.730ms 12.09% 22.318ms 173.008us 0.000us 0.00% 278.000us 2.155us 129
aten::clamp 0.09% 174.000us 11.43% 21.103ms 21.103ms 1.000us 0.13% 1.000us 1.000us 1
aten::add 0.11% 205.000us 9.08% 16.768ms 16.768ms 1.000us 0.13% 1.000us 1.000us 1
aten::copy_ 4.94% 9.112ms 8.15% 15.043ms 116.612us 258.000us 32.45% 278.000us 2.155us 129
aten::stack 2.76% 5.091ms 7.97% 14.719ms 14.719ms 0.000us 0.00% 6.000us 6.000us 1
aten::reciprocal 0.11% 194.000us 7.01% 12.933ms 12.933ms 1.000us 0.13% 1.000us 1.000us 1
aten::max 0.09% 165.000us 6.43% 11.868ms 11.868ms 3.000us 0.38% 3.000us 3.000us 1
aten::detach 1.58% 2.911ms 4.12% 7.596ms 14.836us 0.000us 0.00% 0.000us 0.000us 512
aten::cat 0.56% 1.042ms 3.73% 6.882ms 6.882ms 6.000us 0.75% 6.000us 6.000us 1
aten::_foreach_mul_ 1.36% 2.503ms 3.33% 6.145ms 3.072ms 10.000us 1.26% 10.000us 5.000us 2
detach 2.54% 4.685ms 2.54% 4.685ms 9.150us 0.000us 0.00% 0.000us 0.000us 512
aten::empty_strided 1.92% 3.545ms 1.92% 3.545ms 27.481us 0.000us 0.00% 0.000us 0.000us 129
cudaDeviceEnablePeerAccess 1.64% 3.022ms 1.64% 3.022ms 1.511ms 0.000us 0.00% 0.000us 0.000us 2
aten::unsqueeze 1.03% 1.892ms 1.49% 2.746ms 10.727us 0.000us 0.00% 0.000us 0.000us 256
aten::as_strided 1.35% 2.494ms 1.35% 2.494ms 4.862us 0.000us 0.00% 0.000us 0.000us 513
cudaMemcpyAsync 1.01% 1.868ms 1.01% 1.868ms 14.481us 4.000us 0.50% 4.000us 0.031us 129
cudaStreamWaitEvent 0.41% 760.000us 0.41% 760.000us 2.946us 8.000us 1.01% 8.000us 0.031us 258
cudaEventRecord 0.15% 276.000us 0.15% 276.000us 1.070us 8.000us 1.01% 8.000us 0.031us 258
aten::mul 0.08% 139.000us 0.08% 153.000us 153.000us 1.000us 0.13% 1.000us 1.000us 1
aten::empty 0.02% 35.000us 0.02% 35.000us 35.000us 0.000us 0.00% 0.000us 0.000us 1
cudaDeviceSynchronize 0.01% 14.000us 0.01% 14.000us 7.000us 0.000us 0.00% 0.000us 0.000us 2
cudaDeviceCanAccessPeer 0.00% 5.000us 0.00% 5.000us 2.500us 0.000us 0.00% 0.000us 0.000us 2
void at::native::reduce_kernel<512, 1, at::native::R... 0.00% 0.000us 0.00% 0.000us 0.000us 514.000us 64.65% 514.000us 2.008us 256
Memcpy PtoP (Device -> Device) 0.00% 0.000us 0.00% 0.000us 0.000us 258.000us 32.45% 258.000us 2.000us 129
void at::native::(anonymous namespace)::CatArrayBatc... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 0.75% 6.000us 3.000us 2
void at::native::reduce_kernel<512, 1, at::native::R... 0.00% 0.000us 0.00% 0.000us 0.000us 3.000us 0.38% 3.000us 3.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.13% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.13% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.13% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.13% 1.000us 1.000us 1
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 10.000us 1.26% 10.000us 2.500us 4
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 184.579ms
Self CUDA time total: 795.000us
```
For script:
```
import torch
from math import inf
from torch.nn.utils import clip_grad_norm_
params = [torch.rand(32, 16, device="cuda:3")*5 for _ in range(128)] + [torch.rand(32, 16, device="cuda:4")*-7 for _ in range(128)]
for p in params:
p.grad = torch.rand_like(p)
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
) as p:
total_norm = clip_grad_norm_(params, 10.0, norm_type=inf)
torch.cuda.synchronize()
print(p.key_averages().table(sort_by="cpu_time_total"))
print(total_norm)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120623
Approved by: https://github.com/Skylion007, https://github.com/mikaylagawarecki