pytorch
11dc1581 - Remove sync in embedding (#70943)

Summary:
Together with https://github.com/pytorch/pytorch/pull/66580 and https://github.com/pytorch/pytorch/pull/68376, this removes all host-device syncs in embedding. This PR includes https://github.com/pytorch/pytorch/pull/68376; please review it after https://github.com/pytorch/pytorch/pull/68376 is merged.

This PR introduces perf regressions and increases memory usage:
- `exclusive_sum` now computes over all `numel` elements instead of `num_of_segments` elements, and the trailing `numel - num_of_segments` results are discarded.
- Some allocations now need space for `numel` elements instead of `num_of_segments` or `num_of_partial_segments`.

This is the price we pay for a sync-free implementation. I haven't benchmarked it yet; I will do that later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70943

Reviewed By: H-Huang

Differential Revision: D34881660

Pulled By: ngimel

fbshipit-source-id: b0760fa33608c46cd4145ceb09878bf94a9f959d
(cherry picked from commit d959fa4783cfee84bf17c1fa6d0f5d6bde268d75)
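The core trick behind the two bullets above: `num_of_segments` only exists on the GPU, so reading it on the host would force a device-to-host sync; instead, the scan and every dependent buffer are sized by the host-known upper bound `numel`. Below is a minimal, hypothetical CUB sketch of that pattern — the function name `exclusive_sum_sync_free`, the parameter names, and the raw cudaMalloc/cudaFree calls are illustrative assumptions, not the actual PyTorch kernel code.

```cpp
// Minimal sketch (assumed names, not the actual PyTorch kernel code) of the
// sync-free pattern: scan a host-known upper bound `numel` instead of the
// device-only `num_of_segments`.
#include <cub/device/device_scan.cuh>
#include <cuda_runtime.h>

// d_counts:  device input with `num_of_segments` valid entries, padded to `numel`
// d_offsets: device output; only the first `num_of_segments` results matter
void exclusive_sum_sync_free(const int64_t* d_counts,
                             int64_t* d_offsets,
                             int64_t numel,          // host-known upper bound
                             cudaStream_t stream) {
  // --- The sync-ful alternative this avoids ----------------------------
  // int64_t h_num_segments;
  // cudaMemcpy(&h_num_segments, d_num_segments, sizeof(int64_t),
  //            cudaMemcpyDeviceToHost);   // blocks the host on the GPU
  // ----------------------------------------------------------------------

  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // First call only computes the required temp-storage size.
  cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_counts, d_offsets,
                                static_cast<int>(numel), stream);
  // Raw cudaMalloc/cudaFree keep the sketch self-contained; PyTorch's caching
  // allocator would avoid their cost (and implicit syncs) in the real kernel.
  cudaMalloc(&d_temp, temp_bytes);
  // Scan all `numel` entries; the trailing numel - num_of_segments outputs
  // are computed and simply never read -- the perf/memory price of no sync.
  cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_counts, d_offsets,
                                static_cast<int>(numel), stream);
  cudaFree(d_temp);
}
```

The trade-off named in the summary is visible here: the scan does `numel - num_of_segments` elements of wasted work and `d_offsets` must hold `numel` entries, in exchange for the host never blocking on the GPU.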