Support reduction vectorization (#87356)
This PR optimizes the reduction implementation via `at::vec`. The main idea is the same as the aten implementation:
- Step 1: Parallelize and vectorize the reduction implementation
- Step 2: Invoke `at::vec::vec_reduce_all` to reduce the vector generated in Step 1 to a single scalar
- Step 3: Handle the tail elements
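The three steps above can be sketched in plain Python. This is only an illustrative model: the function name and the 8-lane width are assumptions, and the lane list stands in for `at::vec::Vectorized<float>`.

```python
VEC_WIDTH = 8  # assumed lane count, standing in for the SIMD register width

def vectorized_sum(data):
    """Illustrative model of the vectorized reduction pattern."""
    n = len(data)
    main = n - n % VEC_WIDTH
    # Step 1: accumulate into VEC_WIDTH per-lane partial sums
    # (models the vectorized, parallelizable main loop).
    acc = [0.0] * VEC_WIDTH
    for i in range(0, main, VEC_WIDTH):
        for lane in range(VEC_WIDTH):
            acc[lane] += data[i + lane]
    # Step 2: horizontally reduce the lanes to one scalar
    # (what at::vec::vec_reduce_all does).
    total = sum(acc)
    # Step 3: handle the tail elements that do not fill a full vector.
    for i in range(main, n):
        total += data[i]
    return total
```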
For the implementation, we create two kernels - `CppVecKernel` and `CppKernel`. The code blocks are generated step by step as follows.
- Gen the non-reduction loop - [Code](https://github.com/pytorch/pytorch/blob/gh/EikanWang/9/head/torch/_inductor/codegen/cpp.py#L1008-L1010)
- Gen the reduction initialization for both the vectorization and non-vectorization kernels - [Code](https://github.com/pytorch/pytorch/blob/gh/EikanWang/9/head/torch/_inductor/codegen/cpp.py#L1015)
- Gen the reduction loop for the vectorization kernel - [Code](https://github.com/pytorch/pytorch/blob/gh/EikanWang/9/head/torch/_inductor/codegen/cpp.py#L1021-L1023)
- Gen the code to reduce the vector to a scalar - [Code](https://github.com/pytorch/pytorch/blob/gh/EikanWang/9/head/torch/_inductor/codegen/cpp.py#L1033)
- Gen the reduction loop for the non-vectorization kernel - [Code](https://github.com/pytorch/pytorch/blob/gh/EikanWang/9/head/torch/_inductor/codegen/cpp.py#L1042)
- Do post-reduction work, such as storing the reduction value - [Code](https://github.com/pytorch/pytorch/blob/gh/EikanWang/9/head/torch/_inductor/codegen/cpp.py#L1049)
```python
# Gen the non-reduction loop
for loop in CppVecKernel.NoneReductionLoop:
    # Gen the reduction initialization for both the vectorization and non-vectorization kernels
    CppVecKernel.ReductionPrefix
    # Gen the reduction loop for the vectorization kernel
    for loop in CppVecKernel.ReductionLoop:
        CppVecKernel.Loads
        CppVecKernel.Compute
        CppVecKernel.Stores
    # Gen the code to reduce the vector to a scalar
    CppVecKernel.ReductionSuffix
    # Gen the reduction loop for the non-vectorization kernel
    for loop in CppKernel.ReductionLoop:
        CppKernel.Loads
        CppKernel.Compute
        CppKernel.Stores
    # The reduction is almost finished. Do some post-reduction work, such as storing the reduction value.
    CppKernel.ReductionSuffix
```
The following generated code for a sum reduction exemplifies the idea. More detailed comments are inlined.
```C++
{
    // Declare the reduction for at::vec::Vectorized since it is not a built-in data type
    #pragma omp declare reduction(+:at::vec::Vectorized<float>:omp_out += omp_in) initializer(omp_priv={{0}})
    float tmp4 = 0;
    // tmp4_vec is used to vectorize the sum reduction for tmp4
    auto tmp4_vec = at::vec::Vectorized<float>(tmp4);
    float tmp6 = 0;
    // tmp6_vec is used to vectorize the sum reduction for tmp6
    auto tmp6_vec = at::vec::Vectorized<float>(tmp6);
    #pragma omp parallel num_threads(48)
    {
        // Parallelize the vectorized reduction
        #pragma omp for reduction(+:tmp4_vec) reduction(+:tmp6_vec)
        for(long i0=0; i0<192; i0+=1)
        {
            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 8*i0);
            auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + 8*i0);
            auto tmp2 = tmp0 - tmp1;
            auto tmp3 = tmp2.abs();
            auto tmp5 = tmp2 * tmp2;
            tmp4_vec += tmp3;
            tmp6_vec += tmp5;
        }
        // Reduce tmp4_vec to a scalar and store it in tmp4
        tmp4 = at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp4_vec);
        // Reduce tmp6_vec to a scalar and store it in tmp6
        tmp6 = at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp6_vec);
        // Handle the tail elements that could not be vectorized by aten
        // (empty here since 1536 is a multiple of the vector length)
        #pragma omp for simd simdlen(4) reduction(+:tmp4) reduction(+:tmp6)
        for(long i0=1536; i0<1536; i0+=1)
        {
            auto tmp0 = in_ptr0[i0];
            auto tmp1 = in_ptr1[i0];
            auto tmp2 = tmp0 - tmp1;
            auto tmp3 = std::abs(tmp2);
            auto tmp5 = tmp2 * tmp2;
            tmp4 += tmp3;
            tmp6 += tmp5;
        }
    }
    out_ptr0[0] = tmp4;
    out_ptr1[0] = tmp6;
}
```
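As a sanity check of the pattern, the math of the generated kernel above (a fused sum of `|a - b|` and sum of `(a - b)^2` with a vectorized main loop and a scalar tail loop) can be modeled in plain Python. The function name is hypothetical and the lane count is a parameter rather than the hardware vector width.

```python
def fused_reductions(a, b, lanes=8):
    """Model of the generated kernel: per-lane partial sums (like
    tmp4_vec / tmp6_vec), a horizontal reduce, then the scalar tail."""
    n = len(a)
    main = n - n % lanes
    abs_acc = [0.0] * lanes  # per-lane accumulator, models tmp4_vec
    sq_acc = [0.0] * lanes   # per-lane accumulator, models tmp6_vec
    # Vectorized main loop: each iteration processes `lanes` elements
    for i in range(0, main, lanes):
        for l in range(lanes):
            d = a[i + l] - b[i + l]
            abs_acc[l] += abs(d)
            sq_acc[l] += d * d
    # Horizontal reduction, the role of at::vec::vec_reduce_all
    tmp4, tmp6 = sum(abs_acc), sum(sq_acc)
    # Scalar tail loop for elements that do not fill a full vector
    for i in range(main, n):
        d = a[i] - b[i]
        tmp4 += abs(d)
        tmp6 += d * d
    return tmp4, tmp6
```

With integer-valued floats the lane-split result matches a naive single loop exactly; with arbitrary floats the two can differ by rounding, since the reduction order changes.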
Performance (measured with operatorbench; the baseline of the speedup ratio is the aten operator performance):
Softmax (1,16,384,384,dim=3) | Speedup ratio (simdlen=None) | Speedup ratio (simdlen=8) + this PR
-- | -- | --
24c | 0.37410838067524177 | 0.9036240100351164
4c | 0.24655829520907663 | 1.0255329993674518
1c | 0.21595768114988007 | 1.000587368005134
HW Configuration:
- SKU: SKX Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz
- MemTotal: 196708148 kB
- MemFree: 89318532 kB
- MemBandwidth: 112195.1 MB/s
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87356
Approved by: https://github.com/jgong5, https://github.com/jansel