Optimize scatter_add performance for GNN usage on CPU (#82703)
### Motivation of this PR
This PR improves the performance of `scatter_add` for GNN usage scenarios in PyG. Currently only CPU optimizations are covered.
Message passing, which exchanges and aggregates information between nodes, is the major step in GNN learning. From a perf point of view, if the `EdgeIndex` is stored as `[2, num_edges]`, `scatter_reduce` is a major hotspot in the current PyTorch implementation.
More specifically, during message passing `scatter_add` is used much like `index_select`, except that `scatter_add` writes into the `self` tensor while `index_select` only reads. The `index` tensor passed to `scatter_add` is therefore an expanded tensor that indexes dim0, so its values are identical across all remaining dims.
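As a rough illustration of this usage pattern, the sketch below builds an expanded `index` from a 1-D destination index and aggregates per-edge messages with `scatter_add_`. The toy graph, the tensor names (`col`, `messages`) and the sizes are assumptions made for the example, not taken from PyG or from this PR.

```python
import torch

num_nodes, num_edges, channels = 4, 6, 3

# Hypothetical toy graph: `col` holds the destination node of each edge.
col = torch.tensor([2, 0, 2, 1, 0, 2])
messages = torch.arange(num_edges * channels, dtype=torch.float32).view(num_edges, channels)

# The 1-D edge index is broadcast across the feature dim, so every
# column of `index` is identical (an expanded view with stride 0 on dim 1).
index = col.view(-1, 1).expand(num_edges, channels)

# Aggregate: out[index[e, c], c] += messages[e, c]
out = torch.zeros(num_nodes, channels).scatter_add_(0, index, messages)

# The same aggregation written as a row-wise index_add_, for comparison.
ref = torch.zeros(num_nodes, channels).index_add_(0, col, messages)
```

Because every column of `index` carries the same values, the operation reduces whole rows of `messages` into rows of `out`, which is what makes a row-oriented algorithm applicable.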
### Algorithm
For this case, the current scatter implementation parallelizes over the inner dims, which performs badly: the memory access pattern is non-contiguous and the loop is not vectorized.
Parallelizing directly over dim0 would cause write conflicts, so this PR sorts the `index` first to resolve them. The algorithm is equivalent to:
* convert memory format from `COO` to `CSR`
* do spmm reduce
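The two steps can be sketched in plain Python, with nested lists standing in for tensors. This is only an illustration of the idea under simplifying assumptions (1-D `index`, non-empty 2-D `src`), not the kernel added by this PR; the function names are hypothetical.

```python
def scatter_add_rows(index, src, num_rows):
    """Reference: out[index[e]] += src[e], one edge at a time."""
    out = [[0.0] * len(src[0]) for _ in range(num_rows)]
    for e, row in enumerate(index):
        for c, v in enumerate(src[e]):
            out[row][c] += v
    return out


def scatter_add_sorted(index, src, num_rows):
    # Step 1 (COO -> CSR): stable sort of edge ids by destination row,
    # then a prefix sum of per-row counts gives the row pointers.
    order = sorted(range(len(index)), key=lambda e: index[e])
    crow = [0] * (num_rows + 1)
    for row in index:
        crow[row + 1] += 1
    for r in range(num_rows):
        crow[r + 1] += crow[r]

    # Step 2 (spmm-style reduce): each output row owns a disjoint slice
    # of `order`, so rows could be processed by different threads
    # without write conflicts, and each row is read contiguously.
    width = len(src[0])
    out = [[0.0] * width for _ in range(num_rows)]
    for r in range(num_rows):
        for e in order[crow[r]:crow[r + 1]]:
            for c in range(width):
                out[r][c] += src[e][c]
    return out


index = [2, 0, 2, 1, 0, 2]
src = [[float(e * 3 + c) for c in range(3)] for e in range(6)]
result = scatter_add_sorted(index, src, 4)
```

The outer loop over `r` is the one that would be parallelized, and the innermost loop over `c` walks contiguous elements of a row, which is the part that lends itself to vectorization.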
### Perf improvement
The benchmark is `python reddit.py` from https://github.com/pyg-team/pytorch_geometric/tree/master/examples, which runs the SAGE model on the Reddit dataset.
CPU type: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
Time spent in `aten::scatter_add_` is reduced from **37.797s** to **5.989s** (about 6.3x):
* breakdown before
```
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
    aten::scatter_add_        49.00%       37.797s        49.00%       37.797s      41.445ms           912
    aten::index_select        19.74%       15.223s        19.74%       15.227s       6.678ms          2280
          aten::linear         0.01%       5.706ms        15.04%       11.602s      12.721ms           912
           aten::addmm         6.62%        5.108s         7.92%        6.112s      13.403ms           456
          aten::matmul         0.00%       2.339ms         7.10%        5.475s      12.006ms           456
```
* breakdown after
```
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
    aten::index_select        32.41%       14.677s        32.42%       14.681s       6.439ms          2280
          aten::linear         0.01%       6.665ms        26.43%       11.968s      13.123ms           912
           aten::addmm        11.76%        5.328s        13.76%        6.232s      13.667ms           456
    aten::scatter_add_        13.22%        5.989s        13.22%        5.989s       6.566ms           912
          aten::matmul         0.01%       2.303ms        12.63%        5.720s      12.543ms           456
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82703
Approved by: https://github.com/jgong5, https://github.com/ezyang