[PyTorch][2/N] Basic implementation of ShardedEmbeddingBag using ShardedTensor. (#67188)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67188
This diff/PR is trying to implement the ShardedEmbeddingBag using the ShardedTensor.
We support both row-wise and column-wise sharding of the embedding bag. The detailed logic can be found in the comment.
Several caveats:
1. Only the sharding of one weight is supported now.
1. We support limited input params for the op. To support more params are on the way.
2. We only support chuck sharding for now.
3. We only support a single local shard per rank for now.
Some other changes include:
1. Refactor the ShardedEmbedding code so that the common logic can be reused.
2. Fix tiny typos and corner cases in API `get_chunked_dim_size`. Where it will return -1 if the we set the dim_size = 5, split_size = 2, idx = 3. (This is a valid case because when chunks = 4, dim_size = 5, then the split_size = 2)
ghstack-source-id: 142325915
Test Plan: Unit test and CI
Reviewed By: pritamdamania87
Differential Revision: D31749458
fbshipit-source-id: ed77e05e4ec94ef1a01b1feda8bbf32dc5d5da1b