pytorch
6ee54a87 - fix weight norm backward bug on CPU when OMP_NUM_THREADS <= 2 (#80930)

Commit
2 years ago
fix weight norm backward bug on CPU when OMP_NUM_THREADS <= 2 (#80930)

Fix https://github.com/pytorch/pytorch/issues/80569

Root cause: `weight_norm_backward_last_dim_kernel` creates a temporary buffer of size [num_threads, N] for the vertical reduction, where N is the size of the last dimension of v. To save an additional memory allocation, the original kernel reuses this buffer after the vertical sum:

- the 1st row stores the final result of the sum
- the 2nd row stores coefficient a
- the 3rd row stores coefficient b

When OMP_NUM_THREADS <= 2, this causes an illegal memory access, because the buffer holds only 1*N or 2*N elements and the 3rd row (and, with a single thread, the 2nd row) lies outside it.

The fix is to use a separate buffer (`a_b`) to store coefficients a and b.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80930
Approved by: https://github.com/frank-wei, https://github.com/malfet
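
The buffer layout problem described in the root cause can be illustrated with a minimal Python sketch. This is not the kernel's C++ code: the function names `vertical_sum`, `reuse_buffer`, and `separate_buffer` are hypothetical, and the values written for a and b are placeholders, not the real coefficients. A Python `IndexError` stands in for the illegal memory access that the out-of-bounds write would be in C++.

```python
def vertical_sum(buffer, num_threads, n):
    """Sum the per-thread rows of a flat [num_threads, n] buffer into row 0."""
    for j in range(n):
        buffer[j] = sum(buffer[t * n + j] for t in range(num_threads))


def reuse_buffer(partials, num_threads, n):
    """Old layout (sketch): rows 1 and 2 of the same buffer hold a and b."""
    buffer = list(partials)  # flat buffer with num_threads * n elements
    vertical_sum(buffer, num_threads, n)
    for j in range(n):
        buffer[1 * n + j] = 1.0  # placeholder for coefficient a (row 1)
        buffer[2 * n + j] = 2.0  # placeholder for coefficient b (row 2)
    return buffer


def separate_buffer(partials, num_threads, n):
    """Fixed layout (sketch): a and b live in their own 2 * n buffer."""
    buffer = list(partials)
    vertical_sum(buffer, num_threads, n)
    a_b = [0.0] * (2 * n)  # independent of num_threads
    for j in range(n):
        a_b[j] = 1.0      # placeholder for coefficient a
        a_b[n + j] = 2.0  # placeholder for coefficient b
    return buffer[:n], a_b


if __name__ == "__main__":
    n = 4
    for num_threads in (1, 2, 3):
        partials = [1.0] * (num_threads * n)
        try:
            reuse_buffer(partials, num_threads, n)
            print(num_threads, "threads: reuse stays in bounds")
        except IndexError:
            print(num_threads, "threads: reuse writes out of bounds")
```

In this sketch the reused layout needs at least three rows, so it only stays in bounds when `num_threads >= 3`, while the separate `a_b` buffer has a fixed size of 2 * N regardless of the thread count.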