Remove sync in embedding (#70943)
Summary:
Together with https://github.com/pytorch/pytorch/pull/66580 and https://github.com/pytorch/pytorch/pull/68376, this PR removes all syncs in embedding.
This PR includes https://github.com/pytorch/pytorch/pull/68376; please review it after https://github.com/pytorch/pytorch/pull/68376 is merged.
This PR introduces perf regressions and increases memory usage:
- `exclusive_sum` now computes over the entire `numel` elements instead of `num_of_segments` elements, and the trailing `numel - num_of_segments` results are discarded.
- Some allocations now need space for `numel` elements instead of `num_of_segments` or `num_of_partial_segments` elements.
This is the price we pay for a sync-free implementation.
I haven't benchmarked this yet; I will do so later.
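The trade-off above can be sketched in plain Python (names like `segment_lengths` are hypothetical, for illustration only): a sized scan would need `num_of_segments` on the host, which forces a device sync, so the sync-free version instead scans all `numel` slots and discards the tail.

```python
def exclusive_sum(xs):
    # Exclusive prefix sum: out[i] = sum of xs[:i].
    out, total = [], 0
    for x in xs:
        out.append(total)
        total += x
    return out

# Hypothetical example: 3 real segments padded with zeros to numel slots.
numel = 8
num_of_segments = 3  # in the sync-free kernel this count stays on the device
segment_lengths = [3, 2, 3] + [0] * (numel - num_of_segments)

# Sync-free style: scan all numel slots...
full = exclusive_sum(segment_lengths)
# ...then use only the first num_of_segments results; the rest are discarded.
offsets = full[:num_of_segments]
print(offsets)  # [0, 3, 5]
```

The padded tail costs extra compute and memory proportional to `numel - num_of_segments`, which is the regression the summary describes.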
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70943
Reviewed By: H-Huang
Differential Revision: D34881660
Pulled By: ngimel
fbshipit-source-id: b0760fa33608c46cd4145ceb09878bf94a9f959d
(cherry picked from commit d959fa4783cfee84bf17c1fa6d0f5d6bde268d75)