pytorch
1f7ab1c8 - fix performance issue in torch.sparse.mm reduce mode (#94969) (#95018)

Commit
2 years ago
fix performance issue in torch.sparse.mm reduce mode (#94969) (#95018) Fix performance bug for `torch.sparse.mm()` with reduce flag. Found this bug within internal benchmarking. Made a mistake when updating previous patch which causes load imbalance between threads: Test on ogbn-products datasets on Xeon CLX with 24 cores: #### before ``` sparse.mm: mean: 1156.148 ms sparse.mm: sum: 1163.754 ms sparse.mm: (using mkl): 703.227 ms ``` #### after ``` sparse.mm: mean: 662.578 ms sparse.mm: sum: 662.301 ms sparse.mm: (using mkl): 700.178 ms ``` The result also indicates that the current spmm kernel is no worse than MKL's sparse_mm . Also update results on `pyg benchmark` with: ``` python gnn.py --use_sage --epochs=3 --runs=1 --inference ``` * Out of box: `13.32s` * Without the fix in this PR: `5.87s` * With the fix in this PR: `3.19s` Pull Request resolved: https://github.com/pytorch/pytorch/pull/94969 Approved by: https://github.com/jgong5
Author
Parents
Loading