13dbad63 - use scatter_add for index_add when dim is the most inner dim (#88729)

### Motivation

When `dim` is -1 and the slice of `source` or `result` is noncontiguous, the original `index_add` is slow: it applies `add` to the sliced tensor, running serially over `index` and in parallel over the sliced tensor to avoid write conflicts. Parallelizing over the sliced tensor is not optimal, since the sliced tensor may be too small to parallelize effectively, and this approach also launches multiple parallel regions. `scatter_add` is used to speed up this case: it parallelizes over the outer dimensions of the input and runs serially over the inner dimension to avoid write conflicts, so it needs only one parallel region, and the combined size of the outer dimensions is larger and better suited to parallelism.

### Testing

- Single core:

Before:

shape | fp32 / s | bf16 / s
-- | -- | --
[10, 128, 20, 20] | 2.82E-03 | 2.11E-03
[10, 128, 50, 50] | 0.023604 | 0.023794

After:

shape | fp32 / s | bf16 / s
-- | -- | --
[10, 128, 20, 20] | 9.30E-04 | 1.66E-03
[10, 128, 50, 50] | 0.005995 | 0.010003

- Single socket (28 cores):

Before:

shape | fp32 / s | bf16 / s
-- | -- | --
[10, 128, 20, 20] | 2.96E-03 | 2.52E-03
[10, 128, 50, 50] | 0.012208 | 0.012568

After:

shape | fp32 / s | bf16 / s
-- | -- | --
[10, 128, 20, 20] | 7.44E-05 | 1.33E-04
[10, 128, 50, 50] | 0.000333 | 0.000469

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88729

Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/malfet
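The equivalence the commit relies on can be sketched in Python-level PyTorch: for `dim = -1`, `index_add` can be expressed as `scatter_add` by broadcasting the 1-D index over the outer dimensions. This is an illustrative sketch of the rewrite, not the actual ATen kernel; the helper name is hypothetical.

```python
import torch

def index_add_via_scatter_add(result, index, source):
    # Illustrative rewrite of result.index_add(-1, index, source)
    # as a scatter_add along the innermost dimension.
    # Expand the 1-D index to source's shape: idx[i, ..., j] = index[j],
    # so scatter_add does result[i, ..., index[j]] += source[i, ..., j].
    idx = index.view(*([1] * (source.dim() - 1)), -1).expand_as(source)
    return result.scatter_add(-1, idx, source)

x = torch.zeros(2, 5)
src = torch.ones(2, 3)
index = torch.tensor([0, 2, 4])

out_ref = x.index_add(-1, index, src)        # reference path
out_new = index_add_via_scatter_add(x, index, src)
assert torch.equal(out_ref, out_new)
```

The performance point of the commit is visible in the index shape: `scatter_add` walks the expanded index with parallelism over the outer dimensions (here size 2, but 10*128*20 in the benchmarks) and stays serial only within each innermost row, avoiding write conflicts with a single parallel region.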