[TP] Refactor Parallel Style to make it more usable (#111160)
One thing we find it challenging for users is that we don't want to expose the concept of prepare_input and prepare_out to users since there are so many func names for users to select from which is quite confusing. On the other hand, the colwise and rowwise parallel always need input(out) and output(in) to be certain layout so we can somehow simplify the logic here and make it more usable.
So we added three public attributes to the parallelStyle here and the code logic is like:
```python
class ParallelStyle(ABC):
"""
The parallel style user wants the module or submodule to be parallelized.
We can add more in future, but this seems sufficient for immediate needs. Users can extend this class to build their own parallel style with customized input/output preparations.
"""
input_layouts: Union[placement, Tuple[placement]]
output_layouts: Union[placement, Tuple[placement]]
use_local: bool
class RowwiseParallel(ParallelStyle):
"""
Partitioning the row of a module. We assume the input to be a sharded DTensor and output to be a replicate Tensor.
"""
def __init__(self):
super().__init__(input_layouts=Shard(-1), output_layouts=Replicate(), use_local=True)
Class ColwiseParallel(ParallelStyle):
"""
Partitioning the column of a module. We assume the input to be a Replicated DTensor and output to be a sharded DTensor.
"""
def __init__(self):
super().__init__(input_layouts=Replicate(), output_layouts=Shard(-1), use_local=True)
# For the case of Sequence parallel, users just set different input_shard, Shard(0) or Shard(1) instead of Replicate()
Class PrepareModuleInput(ParallelStyle):
"""
Only used to specify the input distribute spec for a module.
"""
def __init__(self):
super().__init__(input_layouts=Shard(0), output_layouts=Replicate(), use_local=False)
Class PrepareModuleOutput(ParallelStyle):
"""
Only used to specify the output distribute spec for a module.
"""
def __init__(self):
super().__init__(input_layouts=Replicate(), output_layouts=Shard(0), use_local=True)
parallelize_plan = {
"embedding": ColwiseParallel(output_shard=Replicate()),
"attn": PrepareModuleInput(),
"attn.w1": ColwiseParallel(),
"attn.w2": ColwiseParallel(),
"attn.w3": ColwiseParallel(),
"attn.wo": RowwiseParallel(),
}
parallelize_module(
module=block, # this can be a submodule or module
device_mesh=mesh['tp'],
parallelize_plan=parallelize_plan,
)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111160
Approved by: https://github.com/wanchaol