pytorch
8f5f15a6 - optimize scatter_add performance for gnn usage on CPU (#82703)

Commit
2 years ago
optimize scatter_add performance for gnn usage on CPU (#82703) ### Motivation of this PR This PR is targeting at improving performance of `scatter_add` for GNN usage scenarios on PyG. Currently only CPU optimizations is covered. `Message Passing` is the major step in GNN learning which means exchanging/aggregating info between nodes. And from the perf point of view, if the `EdgeIndex` is stored as [2, num_edges], `scatter_reduce` would be a major perf hotspot on current pytorch implementations. To be more specific, in the process of message passing, `scatter_add` is used in a very similar way as `index_select`, except that the `self` tensor is written into while `index_select` is only reading. Therefore, the `index` tensor passed to `scatter_add` is an expanded tensor on dim0, which means all the rest of dims would end up with the same value. ### Algorithm Current impl on scatter would do parallel on the inner dims for such case which would cause bad perf: non-contiguous memory access pattern and non-vectorized. This PR did sorting on the `index` to solve the write conflicts if we directly parallel on dim0. The algorithm is equivalent to: * convert memory format from `COO` to `CSR` * do spmm reduce ### Perf improvement The benchmark comes from https://github.com/pyg-team/pytorch_geometric/tree/master/examples, `python reddit.py` which runs model SAGE on dataset reddit. CPU type: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz ` aten::scatter_add_` has been reduced from **37.797s** to **5.989s**: * breakdown before ``` ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ aten::scatter_add_ 49.00% 37.797s 49.00% 37.797s 41.445ms 912 aten::index_select 19.74% 15.223s 19.74% 15.227s 6.678ms 2280 aten::linear 0.01% 5.706ms 15.04% 11.602s 12.721ms 912 aten::addmm 6.62% 5.108s 7.92% 6.112s 13.403ms 456 aten::matmul 0.00% 2.339ms 7.10% 5.475s 12.006ms 456 ``` * breakdown after ``` ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ aten::index_select 32.41% 14.677s 32.42% 14.681s 6.439ms 2280 aten::linear 0.01% 6.665ms 26.43% 11.968s 13.123ms 912 aten::addmm 11.76% 5.328s 13.76% 6.232s 13.667ms 456 aten::scatter_add_ 13.22% 5.989s 13.22% 5.989s 6.566ms 912 aten::matmul 0.01% 2.303ms 12.63% 5.720s 12.543ms 456 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/82703 Approved by: https://github.com/jgong5, https://github.com/ezyang
Author
Committer
Parents
Loading