pytorch
c620ece7 - port sparse_mm.reduce to pytorch and optimize it on CPU (#83727)

Commit
1 year ago
port sparse_mm.reduce to pytorch and optimize it on CPU (#83727) ### Motivation of this PR This patch is to migrate `spmm_reduce` from `torch-sparse` (a 3rd party dependency for PyG) to `torch`, which is a response to the initial proposal for fusion of **Gather, Apply Scatter** in Message Passing of GNN inference/training. https://github.com/pytorch/pytorch/issues/71300 **GAS** is the major step for Message Passing, the behavior of **GAS** can be classified into 2 kinds depending on the storage type of `EdgeIndex` which records the connections of nodes: * COO: the hotspot is `scatter_reduce` * CSR: the hotspot is `spmm_reduce` The reduce type can be choose from: "max", "mean", "max", "min". extend `torch.sparse.mm` with an `reduce` argument, maps to `torch.sparse_mm.reduce` internally. `sparse_mm_reduce` is registered under the TensorTypeId of `SparseCsrCPU`, and this operator requires an internal interface `_sparse_mm_reduce_impl` which has dual outputs: * `out` - the actual output * `arg_out` - records output indices in the non zero elements if the reduce type is "max" or "min", this is only useful for training. So for inference, it will not be calculated. ### Performance Benchmark on GCN for obgn-products on Xeon single socket, the workload is improved by `4.3x` with this patch. Performance benefit for training will be bigger, the original backward impl for `sum|mean` is sequential; the original backward impl for `max|min` is not fused. #### before: ``` ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ torch_sparse::spmm_sum 97.09% 56.086s 97.09% 56.088s 6.232s 9 aten::linear 0.00% 85.000us 1.38% 795.485ms 88.387ms 9 aten::matmul 0.00% 57.000us 1.38% 795.260ms 88.362ms 9 aten::mm 1.38% 795.201ms 1.38% 795.203ms 88.356ms 9 aten::relu 0.00% 50.000us 0.76% 440.434ms 73.406ms 6 aten::clamp_min 0.76% 440.384ms 0.76% 440.384ms 73.397ms 6 aten::add_ 0.57% 327.801ms 0.57% 327.801ms 36.422ms 9 aten::log_softmax 0.00% 23.000us 0.10% 55.503ms 18.501ms 3 ``` #### after ``` ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ aten::spmm_sum 87.35% 11.826s 87.36% 11.827s 1.314s 9 aten::linear 0.00% 92.000us 5.87% 794.451ms 88.272ms 9 aten::matmul 0.00% 62.000us 5.87% 794.208ms 88.245ms 9 aten::mm 5.87% 794.143ms 5.87% 794.146ms 88.238ms 9 aten::relu 0.00% 53.000us 3.35% 452.977ms 75.496ms 6 aten::clamp_min 3.35% 452.924ms 3.35% 452.924ms 75.487ms 6 aten::add_ 2.58% 348.663ms 2.58% 348.663ms 38.740ms 9 aten::argmax 0.42% 57.473ms 0.42% 57.475ms 14.369ms 4 aten::log_softmax 0.00% 22.000us 0.39% 52.605ms 17.535ms 3 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/83727 Approved by: https://github.com/jgong5, https://github.com/cpuhrsch, https://github.com/rusty1s, https://github.com/pearu
Author
Committer
Parents
Loading