Use two-pass reduction for deterministic reduction order (#115620)
## Motivation
Address the [non-deterministic reduction order](https://github.com/pytorch/pytorch/issues/93542#issuecomment-1411294181) issue for `omp parallel reduction`. Floating-point addition is not associative, so the order in which per-thread partial results are combined changes the final value.
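For illustration, a minimal sketch (not from this PR, plain OpenMP) of the problem: `reduction(+:sum)` combines the per-thread partial sums in an unspecified order, so the printed value can vary across runs and thread counts at the last-bit level.
```
#include <cstdio>

int main() {
    const long n = 1L << 20;
    float sum = 0.f;
    // OpenMP combines per-thread partials of `sum` in an unspecified order;
    // float addition is not associative, so the result may vary between runs.
    #pragma omp parallel for reduction(+:sum)
    for (long i = 0; i < n; i++) {
        sum += 1.f / static_cast<float>(i + 1);
    }
    printf("%.10f\n", sum);  // compile with -fopenmp
    return 0;
}
```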
## Latest update on 1/15:
https://github.com/pytorch/pytorch/pull/115620/commits/55d81901bc83f89df16fcb6ce84c50acc01e41c6.
Do not reduce into `arr` inside the loop. Instead, reduce into a local scalar and write it to `arr` once the local reduction is done. This lets the compiler keep the reduction variable in a register instead of reading/writing it from memory on every iteration. When the working set of the loop body is large, the gap between register and memory access becomes significant.
```
vaddss %xmm0, %xmm11, %xmm11          -> accumulate in register %xmm0
vaddssl (%rdx,%rdi,4), %xmm0, %xmm0   -> accumulate in memory address (%rdx,%rdi,4)
```
Example code:
```
float tmp0_acc_arr[64];
#pragma omp parallel num_threads(64)
{
    auto tid = omp_get_thread_num();
    #pragma omp for
    for (...) {
        ...
        tmp0_acc_arr[tid] = tmp0_acc_arr[tid] + tmp_x; // the accumulator is read from and written to memory on every iteration
    }
}
```
will be changed to
```
float tmp0_acc_arr[64];
#pragma omp parallel num_threads(64)
{
    auto tid = omp_get_thread_num();
    float tmp0_acc_local = 0;            // local accumulator, can be kept in a register
    #pragma omp for
    for (...) {
        ...
        tmp0_acc_local = tmp0_acc_local + tmp_x;
    }
    tmp0_acc_arr[tid] = tmp0_acc_local;  // single store per thread after the local reduction
}
```
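A side benefit beyond register allocation: the shared `tmp0_acc_arr` is now written a single time per thread rather than on every iteration, so cross-thread cache-line traffic on the array (potential false sharing between neighboring `tid` slots) is limited to that one store.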
## Description
Following ATen, use a `two pass reduction` with `omp parallel` for a deterministic reduction order: each thread reduces its statically scheduled chunk into its own buffer slot in the first pass, and the second pass combines the per-thread buffers in fixed thread-id order.
https://github.com/pytorch/pytorch/blob/9c3ae37fc453505f5e437d1edadefdb278c2c39c/aten/src/ATen/Parallel-inl.h#L39
https://github.com/pytorch/pytorch/blob/9c3ae37fc453505f5e437d1edadefdb278c2c39c/aten/src/ATen/native/TensorIteratorReduce.cpp#L24
```
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
// init reduction buffer per thread
float tmp_acc0_arr[64];
at::vec::Vectorized<float> tmp_acc0_vec_arr[64];
for (int tid = 0; tid < 64; tid++)
{
    tmp_acc0_arr[tid] = 0;
    tmp_acc0_vec_arr[tid] = at::vec::Vectorized<float>(0);
}
#pragma omp parallel num_threads(64)
{
    int tid = omp_get_thread_num();
    // first pass: each thread reduces its chunk into its own buffer slot
    #pragma omp for
    for (long x0 = static_cast<long>(0L); x0 < static_cast<long>(3964928L); x0 += static_cast<long>(16L))
    {
        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x0));
        auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x0));
        auto tmp2 = tmp0 - tmp1;
        auto tmp3 = tmp2 * tmp2;
        // reduce to per-thread buffers
        tmp_acc0_vec_arr[tid] = tmp_acc0_vec_arr[tid] + tmp3;
    }
}
// second pass: reduce across threads in fixed tid order
for (int tid = 0; tid < 64; tid++)
{
    tmp_acc0 = tmp_acc0 + tmp_acc0_arr[tid];
    tmp_acc0_vec = tmp_acc0_vec + tmp_acc0_vec_arr[tid];
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr0[static_cast<long>(0L)] = static_cast<float>(tmp_acc0);
```
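For reference, here is the same pattern as a small self-contained sketch (plain OpenMP, no ATen types; the function name `two_pass_sum` is just for illustration):
```
#include <omp.h>
#include <vector>

float two_pass_sum(const float* data, long n) {
    const int nthreads = omp_get_max_threads();
    std::vector<float> partial(nthreads, 0.f);
    #pragma omp parallel num_threads(nthreads)
    {
        const int tid = omp_get_thread_num();
        float local = 0.f;  // local accumulator, stays in a register
        // first pass: static schedule gives each thread a fixed chunk
        #pragma omp for schedule(static)
        for (long i = 0; i < n; i++) {
            local += data[i];
        }
        partial[tid] = local;  // single store per thread
    }
    // second pass: combine partials in fixed tid order
    float acc = 0.f;
    for (int tid = 0; tid < nthreads; tid++) {
        acc += partial[tid];
    }
    return acc;
}
```
With a fixed thread count and static scheduling, each thread always reduces the same chunk and the partials are always combined in the same order, so the result is bitwise reproducible across runs.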
## Test results
I tested this PR with the dynamo benchmarks on a 32-core ICX system.
Results (average speedup): the change is performance-neutral; the differences are within run-to-run noise.
| suite | before this PR | after this PR |
| ---- | ---- | ---- |
| torchbench | 1.303 | 1.301 |
| huggingface | 1.346 | 1.343 |
| timm_models | 1.971 | 1.970 |

Benchmark script:
```
export LD_PRELOAD=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib/libiomp5.so:${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib/libjemalloc.so
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
export KMP_AFFINITY=granularity=fine,compact,1,0
export KMP_BLOCKTIME=1
multi_threads_test() {
CORES=$(lscpu | grep Core | awk '{print $4}')
export OMP_NUM_THREADS=$CORES
end_core=$(expr $CORES - 1)
numactl -C 0-${end_core} --membind=0 python benchmarks/dynamo/${SUITE}.py --${SCENARIO} --${DT} -dcpu -n50 --no-skip --dashboard --only "${MODEL}" ${Channels_extra} ${BS_extra} ${Shape_extra} ${Mode_extra} ${Wrapper_extra} ${Flag_extra} --timeout 9000 --backend=inductor --output=${LOG_BASE}/${SUITE}.csv
}
SCENARIO=performance
DT=float32
export TORCHINDUCTOR_FREEZING=1
Flag_extra="--freezing"
Mode_extra="--inference"
for suite in timm_models huggingface torchbench
do
export SUITE=$suite
echo $SUITE
export LOG_BASE=`date +%m%d%H%M%S`
mkdir $LOG_BASE
multi_threads_test
done
```
System info
```
ubuntu@ip-172-31-18-205:~/hz/pytorch$ lscpu
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
CPU family: 6
Model: 106
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 1
Stepping: 6
BogoMIPS: 5800.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
Virtualization features:
Hypervisor vendor: KVM
Virtualization type: full
Caches (sum of all):
L1d: 1.5 MiB (32 instances)
L1i: 1 MiB (32 instances)
L2: 40 MiB (32 instances)
L3: 54 MiB (1 instance)
NUMA:
NUMA node(s): 1
NUMA node0 CPU(s): 0-63
Vulnerabilities:
Gather data sampling: Unknown: Dependent on hypervisor status
Itlb multihit: Not affected
L1tf: Not affected
Mds: Not affected
Meltdown: Not affected
Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Retbleed: Not affected
Spec rstack overflow: Not affected
Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Srbds: Not affected
Tsx async abort: Not affected
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115620
Approved by: https://github.com/jgong5, https://github.com/jansel