Replace individual detaches with overall torch.no_grad decorator (#120638)
Fixes https://github.com/pytorch/pytorch/issues/120611.
At first, I thought there were too many detaches, but @awgu and I made the conclusion that both `clip_grad_norm_` and `clip_grad_value_` should be run under torch.no_grad similar to optimizer step. One option is to continue calling `detach`, but doing that on many tensors is slower than setting the context to be no_grad (I think?) and Andrew had noticed: "the 1st round of detaches takes 10 ms for FSDP2, whereas existing FSDP's clip_grad_norm_ only takes 3 ms total" since there are more tensors in FSDP2.
This change also disables grad mode for the foreach path of `clip_grad_value_`, which the first attempt that didn't do this was an oversight. Not sure how to add a test case for this since grad mode will be turned back on after the call.
New profile is not much different from the one in the bottom of this stack, but the number of detaches is 0 :D:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (c71bcceb)]$ python playground2.py
STAGE:2024-02-26 13:07:15 211224:211224 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:07:16 211224:211224 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:07:16 211224:211224 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
cudaLaunchKernel 70.63% 110.415ms 70.63% 110.415ms 5.811ms 0.000us 0.00% 0.000us 0.000us 19
aten::linalg_vector_norm 0.18% 284.000us 26.00% 40.636ms 40.636ms 3.000us 0.99% 3.000us 3.000us 1
aten::clamp 0.09% 148.000us 14.88% 23.261ms 23.261ms 1.000us 0.33% 1.000us 1.000us 1
aten::to 0.75% 1.170ms 14.05% 21.970ms 84.826us 0.000us 0.00% 258.000us 0.996us 259
aten::_to_copy 2.28% 3.562ms 13.31% 20.800ms 161.240us 0.000us 0.00% 258.000us 2.000us 129
aten::_foreach_norm 4.44% 6.935ms 12.72% 19.878ms 9.939ms 19.000us 6.29% 21.000us 10.500us 2
aten::add 0.11% 173.000us 10.97% 17.153ms 17.153ms 1.000us 0.33% 1.000us 1.000us 1
aten::stack 2.99% 4.673ms 9.15% 14.300ms 14.300ms 0.000us 0.00% 6.000us 6.000us 1
aten::copy_ 5.49% 8.586ms 8.96% 14.001ms 108.535us 258.000us 85.43% 258.000us 2.000us 129
aten::reciprocal 0.11% 179.000us 8.35% 13.051ms 13.051ms 1.000us 0.33% 1.000us 1.000us 1
aten::cat 0.64% 993.000us 4.42% 6.902ms 6.902ms 6.000us 1.99% 6.000us 6.000us 1
aten::zeros 0.04% 69.000us 4.28% 6.698ms 3.349ms 0.000us 0.00% 2.000us 1.000us 2
aten::zero_ 0.04% 66.000us 4.13% 6.462ms 3.231ms 0.000us 0.00% 2.000us 1.000us 2
aten::fill_ 0.06% 98.000us 4.09% 6.396ms 3.198ms 2.000us 0.66% 2.000us 1.000us 2
aten::_foreach_mul_ 1.50% 2.342ms 3.79% 5.924ms 2.962ms 10.000us 3.31% 10.000us 5.000us 2
aten::empty 3.27% 5.115ms 3.27% 5.115ms 19.826us 0.000us 0.00% 0.000us 0.000us 258
aten::empty_strided 2.07% 3.237ms 2.07% 3.237ms 25.093us 0.000us 0.00% 0.000us 0.000us 129
cudaDeviceEnablePeerAccess 1.93% 3.023ms 1.93% 3.023ms 1.512ms 0.000us 0.00% 0.000us 0.000us 2
aten::unsqueeze 1.21% 1.896ms 1.74% 2.725ms 10.645us 0.000us 0.00% 0.000us 0.000us 256
cudaMemcpyAsync 1.01% 1.572ms 1.01% 1.572ms 12.186us 0.000us 0.00% 0.000us 0.000us 129
aten::as_strided 0.54% 839.000us 0.54% 839.000us 3.265us 0.000us 0.00% 0.000us 0.000us 257
cudaStreamWaitEvent 0.34% 539.000us 0.34% 539.000us 2.089us 0.000us 0.00% 0.000us 0.000us 258
cudaEventRecord 0.18% 274.000us 0.18% 274.000us 1.062us 0.000us 0.00% 0.000us 0.000us 258
aten::mul 0.07% 107.000us 0.08% 132.000us 132.000us 1.000us 0.33% 1.000us 1.000us 1
cudaDeviceSynchronize 0.01% 17.000us 0.01% 17.000us 8.500us 0.000us 0.00% 0.000us 0.000us 2
cudaDeviceCanAccessPeer 0.00% 7.000us 0.00% 7.000us 3.500us 0.000us 0.00% 0.000us 0.000us 2
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 2.000us 0.66% 2.000us 1.000us 2
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 13.000us 4.30% 13.000us 3.250us 4
void at::native::lpnorm_cleanup<float, (at::native::... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 1.99% 6.000us 3.000us 2
Memcpy PtoP (Device -> Device) 0.00% 0.000us 0.00% 0.000us 0.000us 258.000us 85.43% 258.000us 2.000us 129
void at::native::(anonymous namespace)::CatArrayBatc... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 1.99% 6.000us 3.000us 2
void at::native::reduce_kernel<512, 1, at::native::R... 0.00% 0.000us 0.00% 0.000us 0.000us 3.000us 0.99% 3.000us 3.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 10.000us 3.31% 10.000us 2.500us 4
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 156.319ms
Self CUDA time total: 302.000us
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120638
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #120623