pytorch
943acd4d - [FSDP] Fix `nn.Parameter` usage for 2D and `use_orig_params=True` (#89782)

Commit
2 years ago
[FSDP] Fix `nn.Parameter` usage for 2D and `use_orig_params=True` (#89782) This ensures that all elements of `FlatParameter._params` and `FlatParameter._shared_params` are `nn.Parameter`s (as expected). This was violated by the local tensor of a `DTensor` when using 2D parallelism. To fix the breakage, we simply wrap with `nn.Parameter` if needed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89782 Approved by: https://github.com/fduwjj
Author
Committer
Parents
Loading