Add meta registration for _foreach_norm (#118604)
This PR also fixes a discrepancy between the `_foreach_norm` fast path and slow path, where the `storage_offset`s of the returned tensors would differ between the two. The profile results below show that we aren't significantly slower. Do note that we're replacing N `as_strided`/`select` calls with N `empty` calls.
For the script:
```
import torch

ts = [torch.rand(32, 16, device="cuda") for _ in range(128)]

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as p:
    res = torch._foreach_norm(ts)

print(p.key_averages().table(sort_by="cpu_time_total"))
```
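The `storage_offset` discrepancy can be pictured without the profiler: one path produced results as views into a single shared buffer (so each result carried a different storage offset), while the other allocated each result independently (offset 0 everywhere). A minimal plain-Python model of the two layouts, with illustrative names that are not PyTorch APIs:

```python
# Model a result as a (backing_buffer, storage_offset) pair.

def view_layout(n):
    """N results as views into one shared allocation: offsets 0..n-1."""
    buffer = [0.0] * n
    return [(buffer, i) for i in range(n)]

def alloc_layout(n):
    """N results as independent single-element allocations: offset 0 each."""
    return [([0.0], 0) for _ in range(n)]

offsets_before = [off for _, off in view_layout(4)]   # differing offsets
offsets_after = [off for _, off in alloc_layout(4)]   # all zero
```

Making both paths produce the second layout is what trades the N `as_strided`/`select` calls for N `empty` calls.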
Baseline before this PR:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (7cf98987)]$ python playground2.py
STAGE:2024-01-30 13:16:48 2740431:2740431 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-01-30 13:16:48 2740431:2740431 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-01-30 13:16:48 2740431:2740431 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::_foreach_norm 25.36% 4.209ms 99.94% 16.586ms 16.586ms 8.000us 88.89% 9.000us 9.000us 1
cudaLaunchKernel 61.21% 10.159ms 61.21% 10.159ms 2.540ms 0.000us 0.00% 0.000us 0.000us 4
aten::zeros 0.43% 71.000us 58.35% 9.683ms 9.683ms 0.000us 0.00% 1.000us 1.000us 1
aten::zero_ 0.33% 55.000us 57.35% 9.517ms 9.517ms 0.000us 0.00% 1.000us 1.000us 1
aten::fill_ 0.42% 69.000us 57.01% 9.462ms 9.462ms 1.000us 11.11% 1.000us 1.000us 1
aten::select 8.04% 1.335ms 11.29% 1.873ms 14.633us 0.000us 0.00% 0.000us 0.000us 128
aten::as_strided 3.24% 538.000us 3.24% 538.000us 4.203us 0.000us 0.00% 0.000us 0.000us 128
aten::empty 0.90% 150.000us 0.90% 150.000us 75.000us 0.000us 0.00% 0.000us 0.000us 2
cudaDeviceSynchronize 0.06% 10.000us 0.06% 10.000us 10.000us 0.000us 0.00% 0.000us 0.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 11.11% 1.000us 1.000us 1
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 66.67% 6.000us 3.000us 2
void at::native::lpnorm_cleanup<float, (at::native::... 0.00% 0.000us 0.00% 0.000us 0.000us 2.000us 22.22% 2.000us 2.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 16.596ms
Self CUDA time total: 9.000us
```
And here's after this PR:
```
STAGE:2024-02-05 08:27:02 1127843:1127843 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-05 08:27:02 1127843:1127843 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-05 08:27:02 1127843:1127843 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::_foreach_norm 30.95% 4.653ms 99.95% 15.026ms 15.026ms 9.000us 90.00% 10.000us 10.000us 1
cudaLaunchKernel 52.41% 7.879ms 52.41% 7.879ms 1.970ms 0.000us 0.00% 0.000us 0.000us 4
aten::zeros 0.39% 58.000us 48.29% 7.260ms 7.260ms 0.000us 0.00% 1.000us 1.000us 1
aten::zero_ 0.35% 53.000us 47.25% 7.103ms 7.103ms 0.000us 0.00% 1.000us 1.000us 1
aten::fill_ 0.43% 65.000us 46.90% 7.050ms 7.050ms 1.000us 10.00% 1.000us 1.000us 1
aten::empty 15.42% 2.318ms 15.42% 2.318ms 17.969us 0.000us 0.00% 0.000us 0.000us 129
cudaDeviceSynchronize 0.05% 7.000us 0.05% 7.000us 7.000us 0.000us 0.00% 0.000us 0.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 10.00% 1.000us 1.000us 1
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 60.00% 6.000us 3.000us 2
void at::native::lpnorm_cleanup<float, (at::native::... 0.00% 0.000us 0.00% 0.000us 0.000us 3.000us 30.00% 3.000us 3.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 15.033ms
Self CUDA time total: 10.000us
```
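For context, a meta registration computes output metadata (shapes, dtypes, devices) without allocating or reading real data, which is what lets tracing and compilation reason about the op. The shape logic for `_foreach_norm` is simple: each input tensor reduces to a 0-dim scalar. A hypothetical plain-Python sketch of that shape function (illustrative only, not the actual registration code):

```python
def foreach_norm_meta(input_shapes):
    """Hypothetical shape function: _foreach_norm reduces each input
    tensor to a scalar, so every output has 0-dim shape ()."""
    return [() for _ in input_shapes]

# One scalar output per input tensor, regardless of the input shapes.
out_shapes = foreach_norm_meta([(32, 16)] * 128)
```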
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118604
Approved by: https://github.com/albanD