Make torch.cond work with retracing (#92646)
We simplify the handling of branch submodules by only working with flattened input/output so that there is no need for adjusting in_spec and out_spec in the second round of tracing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92646
Approved by: https://github.com/zhxchen17, https://github.com/voznesenskym