stack_module_state should return unrelated parameters (#92278)
`torch.func.stack_module_state` is our replacement for
`functorch.combine_state_for_ensemble`. The most common usage for
combine_state_for_ensemble is to
- create stacked parameters and buffers
- use vmap to run the forward pass
- use regular PyTorch autograd to run the backward pass (e.g.,
Tensor.backwrd)
- optimize directly over the stacked parameters (this is more performant
than optimizing over the unstacked parameters).
Right now, stack_module_state returns stacked parameters that cannot be
optimized directly (only leaf tensors can have a .grad field); this PR
fixes that by turning the stacked parameters back into leaf tensors.
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92278
Approved by: https://github.com/soulitzer