pytorch
06c206ce - [SPMD] Add the default graph module transformation that is applied after tracing and expansion (#98182)

Commit
1 year ago
[SPMD] Add the default graph module transformation that is applied after tracing and expansion (#98182) This PR adds the GraphModuleTransformation class that can be used as the default transformation after the `train_step()` is traced and expand. The current implementation includes: 1. Wrap the input graph module with IterGraphModule. This will enable the futher graph optimizations which are all implemented based on IterGraphModule. 2. Ability to lower the graph module to the Inductor. To achieve this goal, `lower_to_inductor()` is implemented. TODO: 1. The `override` and `gm_transofmation` have overlapping functions -- `override.transform` can be used to achieve the same function as `gm_transformation`. However, the current semantics of `override` is to override and transform partial graphs while `gm_transformation` is to transform the entire expaned GM. The final UX of `compile()` needs some discussion. 2. The current `lower_to_inductor()` assumes that the entire graph can be lowered to Inductor. This assumption is okay for integration of graph optimizations but is too restrictive for many models. We should upstream `partial_lowering()`. Differential Revision: [D44616783](https://our.internmc.facebook.com/intern/diff/D44616783/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/98182 Approved by: https://github.com/mrshenli
Author
Committer
Parents
Loading