0830808d - [spmd expansion] speed up expansion by ~5x (#98389)

Committed 2 years ago
[spmd expansion] speed up expansion by ~5x (#98389)

According to profiling, the two most expensive operations in spmd expansion are propagate_op_sharding and make_fx (invoked for every dispatcher op node). This PR makes the following changes to speed up spmd expansion:

- We were unnecessarily calling propagate_op_sharding twice for every op. Remove one call.
- When no tensor redistribution is required, we only need to update the non-tensor args of the node according to op_schema, avoiding building a GraphModule just for that node.

On a DDP use case with foreach Adam, this change speeds up spmd expansion by ~5x (~10 min -> ~2 min).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98389
Approved by: https://github.com/mrshenli
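The shape of both optimizations can be sketched in plain Python. This is a hypothetical illustration, not the actual PyTorch code: `propagate_op_sharding`, `Node`, and the redistribution flag are stand-ins for the real spmd-expansion internals. The point is structural: one propagation call per node instead of two, and an in-place args update on the fast path instead of rebuilding a GraphModule via make_fx.

```python
# Hypothetical sketch of the two optimizations described in the commit.
# Names like propagate_op_sharding are stand-ins, not the real DTensor APIs.

class Node:
    """Minimal stand-in for an FX graph node."""
    def __init__(self, op, args):
        self.op = op
        self.args = args

calls = {"propagate": 0, "rebuild": 0}

def propagate_op_sharding(node):
    # Stand-in for the expensive sharding propagation. Returns whether
    # tensor redistribution is needed plus the (possibly updated)
    # non-tensor args from the resulting op_schema.
    calls["propagate"] += 1
    needs_redistribute = False  # assume placements already match
    return needs_redistribute, tuple(node.args)

def expand_node(node):
    # Before the PR: propagate_op_sharding ran twice per node, and a
    # GraphModule was built (via make_fx) for every node.
    # After the PR: propagate once; on the common no-redistribution
    # path, just update the node's args in place.
    needs_redistribute, new_args = propagate_op_sharding(node)
    if needs_redistribute:
        calls["rebuild"] += 1   # only here would make_fx be invoked
    else:
        node.args = new_args    # cheap in-place update, no retracing

nodes = [Node("aten.add", (1, 2)), Node("aten.mul", (3, 4))]
for n in nodes:
    expand_node(n)

print(calls["propagate"])  # one propagation per node
print(calls["rebuild"])    # no GraphModule rebuilds on the fast path
```

Running the sketch shows exactly one propagation per node and zero rebuilds, which is where the ~5x wall-clock improvement on the DDP + foreach Adam workload comes from.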
Author: Yifu Wang