pytorch
75fb0b6c - Enable full train_step tracing and customizable dist graph expansion (#97416)

Commit
1 year ago
Enable full train_step tracing and customizable dist graph expansion (#97416) This commit adds an entry point for full `train_step` tracing and expansion. Model forward, backwrd, and optimizer step will be included in one graph. DTensor expansion will be applied on top to insert collective communications. Users can also provide an `Override` implementation to skip non-traceable submodules and directly install submodule logic to the DTensor-expanded graph by inserting `fx.Nodes`. Differential Revision: [D44325177](https://our.internmc.facebook.com/intern/diff/D44325177) Pull Request resolved: https://github.com/pytorch/pytorch/pull/97416 Approved by: https://github.com/yifuwang, https://github.com/wanchaol
Author
Committer
Parents
Loading