pytorch
2dc5e166 - [TP][Inference] Enable DTensor TP inference (#110751)

Commit
1 year ago
[TP][Inference] Enable DTensor TP inference (#110751) In https://github.com/pytorch/pytorch/pull/109977, we observed that during inference mode, aten.Linear does not get decomposed. So instead of enabling sharding propagation for linear op, we use func.decompose so that it gets decomposed to matmul and mm. Pull Request resolved: https://github.com/pytorch/pytorch/pull/110751 Approved by: https://github.com/bdhirsh, https://github.com/wanchaol
Author
Committer
Parents
Loading