[spmd compile api] pre-flatten state container and pass the flattened state container to transforms (#98392)
Move the responsibility of flattening the input arguments from the graph module to the caller. This serves two purposes:
- Transformations that add or remove state need to manipulate a state container that keeps the state tensors in the same order as they appear in the graph placeholders.
- Runtime cost is reduced, because the state container is flattened only once, up front.
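
As an illustration of the two points above, here is a minimal sketch in plain Python. It is not the spmd compile API itself: `flatten_state`, `compiled_step`, and `drop_state` are hypothetical names, and floats stand in for state tensors. The caller flattens the nested state once, a stand-in "graph" consumes the flat list directly, and a transform removes a state entry by position in that same list.

```python
from typing import Any, List


def flatten_state(container: Any) -> List[Any]:
    """Flatten nested dicts/lists/tuples into a list of leaves.

    Leaves are emitted in insertion order, so the flat list can mirror
    the order of the graph placeholders (an assumption of this sketch).
    """
    if isinstance(container, dict):
        return [leaf for value in container.values() for leaf in flatten_state(value)]
    if isinstance(container, (list, tuple)):
        return [leaf for item in container for leaf in flatten_state(item)]
    return [container]


def compiled_step(flat_state: List[float], batch: float) -> float:
    # Hypothetical stand-in for a graph module whose placeholders are
    # already flat: it does no flattening of its own.
    return sum(flat_state) * batch


def drop_state(flat_state: List[float], index: int) -> List[float]:
    # Hypothetical transform that removes one state entry. It edits the
    # flat container by position, matching the placeholder it removes.
    return flat_state[:index] + flat_state[index + 1:]


state = {"params": {"w": 1.0, "b": 2.0}, "opt": [3.0, 4.0]}

# The caller flattens once, up front, and reuses the result on every call.
flat_state = flatten_state(state)
outputs = [compiled_step(flat_state, batch) for batch in (1.0, 2.0)]

# A transform that removes state works on the same flat container.
trimmed_state = drop_state(flat_state, 1)
```

In real code the flattening would presumably be done with a pytree utility rather than a hand-written helper; the hand-written version is used here only to keep the sketch self-contained.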
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98392
Approved by: https://github.com/mrshenli