pytorch
762a81cb - [spmd compile api] pre-flatten state container and pass the flattened state container to transforms (#98392)

Commit
1 year ago
[spmd compile api] pre-flatten state container and pass the flattened state container to transforms (#98392) Move the responsibility of flattening the input arguments from the graph module to the caller. This serves two purposes: - Transformations that add/remove state need to manipulate a state container that maintains the state tensors in the same order as they appear in graph placeholders. - Reduced runtime cost. The state container is only flattened once upfront. Pull Request resolved: https://github.com/pytorch/pytorch/pull/98392 Approved by: https://github.com/mrshenli
Author
Committer
Parents
Loading