pytorch
0dc98728 - Basic implementation of ShardedLinear using ShardedTensor. (#64128)

Commit
3 years ago
Basic implementation of ShardedLinear using ShardedTensor. (#64128) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64128 This PR implements a sharded nn.Linear layer using ShardedTensors with the following limitations: 1) Works only for ChunkShardingSpec. 2) Implementation is only aimed to demonstrate functionality and is most likely not performant at all. The PR also introduces a `shard_parameter` API to easily shard parameters of `nn.Modules`. This also has the following limitations: 1) Works only for ChunkShardingSpec. 2) Is not performant since it uses broadcast instead of scatter since ProcessGroupNCCL doesn't yet support scatter. Overall user API for running a sharded linear would be something like this: ``` # SPMD programming paradigm running same code on all nodes. fc = nn.Linear(10, 10) # Setup sharding. sharding_spec=ChunkShardingSpec(...) shard_parameter(fc, 'weight', sharding_spec, src_rank=0) # Run as a normal linear layer. inp = torch.rand(10, 10) output = fc(inp) ``` ghstack-source-id: 138500985 Test Plan: 1) unit tests. 2) waitforbuildbot Reviewed By: wanchaol, bowangbj Differential Revision: D30621215 fbshipit-source-id: 1aa7478568c18a4572f6c3462fdf24a4cbde01d6
Author
Parents
Loading