pytorch/pytorch
953aa6d9 - [TP] Enable more generic attn in Tensor Parallelism (#100508)

To make TP more generic for the Attention module, we came up with this new col/rowwise parallel style. The idea behind it is that we only perform DTensor ops for the col/rowwise-sharded parts; the rest of the ATen ops are left to plain Tensor ops. We set this behavior as the default for the Colwise and Rowwise parallel styles. If people want to customize it, they can always pass in a different prepare_input or prepare_output.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100508
Approved by: https://github.com/wanchaol
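
To illustrate the style described above, here is a minimal sketch of applying the col/rowwise parallel styles to an attention block via torch.distributed.tensor.parallel. The ToyAttention module, its submodule names, and the sizes are assumptions made for illustration, not code from this PR, and import locations such as DeviceMesh have shifted across PyTorch releases.

```python
# Illustrative sketch only, not code from the PR: the ToyAttention module and
# its submodule names are made up, and import paths (e.g. DeviceMesh) have
# moved between PyTorch releases. Assumes a torchrun launch, one GPU per rank,
# with the head count divisible by the world size.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyAttention(nn.Module):
    def __init__(self, dim: int, head_dim: int):
        super().__init__()
        self.head_dim = head_dim
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        bsz, seq, _ = x.shape
        # Infer the head count from the tensor itself (-1), so the same code
        # works after colwise sharding leaves each rank with a subset of heads.
        q = self.q_proj(x).view(bsz, seq, -1, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seq, -1, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seq, -1, self.head_dim).transpose(1, 2)
        # Plain Tensor ops: each rank attends over its own heads locally.
        scores = torch.softmax(q @ k.transpose(-2, -1) / self.head_dim**0.5, dim=-1)
        out = (scores @ v).transpose(1, 2).reshape(bsz, seq, -1)
        return self.out_proj(out)

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))

model = ToyAttention(dim=256, head_dim=32).cuda()

# DTensor ops run only inside the four sharded projections; everything in
# between (view, softmax, matmuls) stays on plain local Tensors, which is
# the default behavior this commit describes.
model = parallelize_module(
    model,
    mesh,
    {
        "q_proj": ColwiseParallel(),
        "k_proj": ColwiseParallel(),
        "v_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)

out = model(torch.randn(8, 16, 256, device="cuda"))  # replicated full output
```

The view with -1 is the piece that makes the generic approach work: after colwise sharding, each rank holds only its slice of the heads, the per-head softmax and matmuls run entirely on local Tensors, and the rowwise out_proj reduces the partial results back into a replicated output.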