2766662c - [PyTorch][2/N] Basic implementation of ShardedEmbeddingBag using ShardedTensor. (#67188)

[PyTorch][2/N] Basic implementation of ShardedEmbeddingBag using ShardedTensor. (#67188)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67188

This diff/PR implements ShardedEmbeddingBag on top of ShardedTensor. Both row-wise and column-wise sharding of the embedding bag weight are supported; the detailed logic is documented in the code comments, and a usage sketch follows below.

Several caveats:
1. Only the sharding of one weight is supported for now.
2. Only a limited set of input params is supported for the op; support for more params is on the way.
3. Only chunk sharding is supported for now.
4. Only a single local shard per rank is supported for now.

Some other changes include:
1. Refactored the ShardedEmbedding code so that the common logic can be reused.
2. Fixed tiny typos and a corner case in the API `get_chunked_dim_size`, which returned -1 when dim_size = 5, split_size = 2, idx = 3. (This is a valid input, because with chunks = 4 and dim_size = 5, the split_size is 2.) See the second sketch after this message.

ghstack-source-id: 142325915

Test Plan: Unit test and CI

Reviewed By: pritamdamania87

Differential Revision: D31749458

fbshipit-source-id: ed77e05e4ec94ef1a01b1feda8bbf32dc5d5da1b
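For context, here is a minimal sketch of what chunk sharding an `nn.EmbeddingBag` weight might look like. These are private APIs whose module paths have moved between PyTorch releases (`torch.distributed._sharded_tensor` at the time of this commit, `torch.distributed._shard` in later versions), so the import paths and the `shard_parameter` call below are assumptions, not the exact API introduced by this PR:

```python
# Hypothetical usage sketch: shard an nn.EmbeddingBag weight row-wise
# across 4 ranks with chunk sharding. Import paths are assumptions and
# differ between PyTorch releases.
import torch.nn as nn
from torch.distributed._shard import shard_parameter
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# Row-wise sharding splits dim 0 (the num_embeddings dimension) across
# ranks; column-wise sharding would use dim=1 (the embedding_dim) instead.
spec = ChunkShardingSpec(
    dim=0,
    placements=[
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ],
)

ebag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=8, mode="sum")

# Replaces ebag.weight with a ShardedTensor holding this rank's shard.
# Requires an initialized process group; the sharded embedding-bag op is
# then dispatched when the module is called.
shard_parameter(ebag, "weight", spec)
```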
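And a standalone sketch of the `get_chunked_dim_size` corner-case fix described above. The real helper lives in PyTorch's sharding-spec internals; this version is illustrative only, reconstructed from the example in the commit message:

```python
def get_chunked_dim_size(dim_size: int, split_size: int, idx: int) -> int:
    # Size of chunk `idx` when a dimension of length `dim_size` is split
    # into chunks of at most `split_size` elements. The outer max(..., 0)
    # clamp is the fix: trailing empty chunks now report size 0 instead of
    # a negative size (dim_size=5, split_size=2, idx=3 previously gave -1).
    return max(min(split_size, dim_size - split_size * idx), 0)

# With chunks = 4 and dim_size = 5, split_size = ceil(5 / 4) = 2, so the
# per-chunk sizes are [2, 2, 1, 0]; idx = 3 is a valid (empty) chunk.
assert [get_chunked_dim_size(5, 2, i) for i in range(4)] == [2, 2, 1, 0]
```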