[PyTorch] Make rearrangement in sharded linear work as expected. (#66603)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66603
Found the issue reported in https://github.com/pytorch/pytorch/issues/66281 by making the test cases more complicated.
After closely reading the code again, it turns out my original understanding was also wrong. Let's use the example mentioned in the issue to explain:
If the placement is like:
```
"rank:3/cuda:3",
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
```
First, we split the columns or rows following the order [3, 0, 1, 2].
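As an illustration, here is a minimal sketch of that assignment (not the PR's actual code; `placement`, `weight`, and `shard_on_rank` are made up for this example):
```python
import torch

placement = [3, 0, 1, 2]  # shard i is placed on rank placement[i]
weight = torch.randn(8, 16)

# Column-wise sharding chunks along dim -1; row-wise would use dim 0.
shards = torch.chunk(weight, len(placement), dim=-1)

# Rank 3 holds the 1st shard, rank 0 the 2nd, and so on.
shard_on_rank = {placement[i]: shards[i] for i in range(len(placement))}
```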
In the case of column-wise sharding:
We need to rearrange the results from ranks 0-3.
Step 1: we split the output based on the original sharding strategy, i.e., rank 3 gets the 1st shard, rank 0 gets the 2nd shard, etc.
Step 2: we rearrange the results from ranks 0-3 following the order [3, 0, 1, 2], i.e., the result from rank 3 goes first, and so forth.
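A hedged sketch of this reordering, assuming the per-rank partial outputs have already been gathered into a list indexed by rank id (`gathered` and `output` are illustrative names, not the PR's code):
```python
import torch

placement = [3, 0, 1, 2]  # shard i of the output was computed on rank placement[i]
world_size = len(placement)

# Pretend gathered[r] is the partial output computed on rank r.
gathered = [torch.randn(2, 4) for _ in range(world_size)]

# Shard i belongs at position i and came from rank placement[i], so
# concatenating in placement order restores the correct column order.
output = torch.cat([gathered[r] for r in placement], dim=-1)
print(output.shape)  # torch.Size([2, 16])
```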
In the case of row-wise sharding:
We need to rearrange the input being sent to ranks 0-3.
Step 1: we reorder the input chunks following the map [3, 0, 1, 2]: the 1st shard goes to rank 3, so it is placed in rank 3's slot; the 2nd shard goes to rank 0, so it is placed in rank 0's slot, and so on.
Step 2: the shard size each rank receives is decided by the original placement [3, 0, 1, 2], i.e., rank 3 gets the 1st shard and its size, etc.
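A minimal sketch of both steps, under the assumption that shard sizes follow PyTorch's usual chunking (earlier shards absorb any remainder); all names here are illustrative:
```python
import torch

placement = [3, 0, 1, 2]  # shard i of the weight rows lives on rank placement[i]
world_size = len(placement)

in_features = 10  # not divisible by world_size, so shard sizes differ
x = torch.randn(2, in_features)

# Shard sizes in shard order; the first `rem` shards get one extra element.
base, rem = divmod(in_features, world_size)
shard_sizes = [base + (1 if i < rem else 0) for i in range(world_size)]
shards = torch.split(x, shard_sizes, dim=-1)

# Step 1: shard i must be sent to rank placement[i], so place it in
# that rank's slot before the scatter/all-to-all.
# Step 2: rank placement[i] therefore receives a chunk of size shard_sizes[i].
rearranged = [None] * world_size
for i, r in enumerate(placement):
    rearranged[r] = shards[i]

for r, t in enumerate(rearranged):
    print(f"rank {r} receives a chunk with {t.shape[-1]} features")
```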
Update the unit tests to reflect this change.
Also correct some formatting and comments in the sharded linear implementation.
ghstack-source-id: 141055689
Test Plan: unit tests, and wait for CI.
Reviewed By: pritamdamania87, bowangbj
Differential Revision: D31634590
fbshipit-source-id: 677a9c2b42da1e2c63220523ed2c004565bbecc7