[PT-D] Enable megatron-lm style MLP layers (Changes mainly on sharded linear op) (#69735)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69735
We want to build a prototype of Megatron-LM so that we can apply PT-D op to models like transformer and other Meta flagship models like
The basic idea of Megatron-LM is as following:
1. Col-wise sharding of linear weight. Perform the linear op for the first layer.
2. Perform a math op (optional), such as ReLU or GeLU. We use GeLU in our example unit test. The input is from step 1.
3. Row-wise sharing of linear weight. Perform the linear op for the second layer. The input is from step 2.
We then save communications to concatenate the col-wise sharding results and spreading the input to different ranks for row-wise sharding.
The change is as following:
1. Return a ShardedTensor for the col-wise sharding in the sharded_linear op.
2. Return a PartialTensors for the row-wise sharding in the sharded_linear op.
3. Leverage APIs already defined for `reshard` to merge/aggregate local results to a fully sync local result if needed.
4. Add helper function to create sharded tensor based on the local result.
5. Add a unit test to test the Megatron-LM idea mentioned above and compare with local ops, including the grad and optimizer so that we can ensure the correctness of the implementation.
6. Refactor the unit test of sharded linear to reflect the changes in the code.
ghstack-source-id: 148273049
Test Plan: Unit test + CI
Reviewed By: pritamdamania87
Differential Revision: D32978221
fbshipit-source-id: 565fc92e7807e19d53b0261f8ace3945bef69e3e
(cherry picked from commit 344abe75202493c8313502e1b22d634568e1b225)