fix performance issue in torch.sparse.mm reduce mode (#94969)
Fix performance bug for `torch.sparse.mm()` with reduce flag.
Found this bug within internal benchmarking.
Made a mistake when updating previous patch which causes load imbalance between threads:
Test on ogbn-products datasets on Xeon CLX with 24 cores:
#### before
```
sparse.mm: mean: 1156.148 ms
sparse.mm: sum: 1163.754 ms
sparse.mm: (using mkl): 703.227 ms
```
#### after
```
sparse.mm: mean: 662.578 ms
sparse.mm: sum: 662.301 ms
sparse.mm: (using mkl): 700.178 ms
```
The result also indicates that the current spmm kernel is no worse than MKL's sparse_mm .
Also update results on `pyg benchmark` with:
```
python gnn.py --use_sage --epochs=3 --runs=1 --inference
```
* Out of box: `13.32s`
* Without the fix in this PR: `5.87s`
* With the fix in this PR: `3.19s`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94969
Approved by: https://github.com/jgong5