fix weight norm backward bug on CPU when OMP_NUM_THREADS <= 2 (#80930)
fix https://github.com/pytorch/pytorch/issues/80569
root cause: `weight_norm_backward_last_dim_kernel` will create a temp buffer to
do vertical reduction, size of [num_threads, N] (N is the size of last dimension of v)
to save additional memory allocation, the original kernel reuse the buffer after
the vertical sum:
1st row stores the final result of sum
2nd row stores coefficient a
3rd row stores coefficient b
when OMP_NUM_THREADS <=2, this will cause illegal memory access since the buffer size
will be only 1*N or 2*N;
the fix is use a separate buffer (`a_b`) to store the coefficients of a and b.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80930
Approved by: https://github.com/frank-wei, https://github.com/malfet