pytorch
21d2bd78 - stack_module_state should return unrelated parameters (#92278)

`torch.func.stack_module_state` is our replacement for `functorch.combine_state_for_ensemble`. The most common usage of combine_state_for_ensemble is to:

- create stacked parameters and buffers
- use vmap to run the forward pass
- use regular PyTorch autograd to run the backward pass (e.g., Tensor.backward)
- optimize directly over the stacked parameters (this is more performant than optimizing over the unstacked parameters)

Right now, stack_module_state returns stacked parameters that cannot be optimized directly (only leaf tensors can have a .grad field); this PR fixes that by turning the stacked parameters back into leaf tensors.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92278
Approved by: https://github.com/soulitzer
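A minimal sketch of the usage pattern the commit message describes (the ensemble size and `Linear(3, 2)` model are illustrative choices, not from the commit):

```python
import torch
from torch.func import stack_module_state, functional_call

# An ensemble of identically structured models (illustrative example).
models = [torch.nn.Linear(3, 2) for _ in range(4)]

# Step 1: create stacked parameters and buffers.
params, buffers = stack_module_state(models)

# A "stateless" template module; its own weights are not used below,
# functional_call substitutes the stacked slices instead.
base = models[0]

def fmodel(params, buffers, x):
    return functional_call(base, (params, buffers), (x,))

# Step 2: use vmap to run the forward pass, one model per batch slice.
x = torch.randn(4, 3)
out = torch.vmap(fmodel, in_dims=(0, 0, 0))(params, buffers, x)

# Step 3: regular autograd for the backward pass. After this PR the
# stacked parameters are leaf tensors, so .grad is populated directly,
# which is what lets an optimizer work on them (step 4).
out.sum().backward()
```

With the fix, `params['weight'].is_leaf` is true and `params['weight'].grad` holds the stacked gradients, so the dict values can be handed straight to a `torch.optim` optimizer.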