c599cf24 - [FSDP] Another fix for `DTensor`, `use_orig_params=True` (#89845)

The issue for `test_2d_parallel.py` is that `DTensor` does not support the idiom `param.data = view` where `view` is a `DTensor`. To work around this, we do not preserve the parameter variable `param` and instead create a new parameter variable altogether via `nn.Parameter(view)`.

Preserving the parameter variable when unsharded was not a strict requirement -- it just made sense to do so given that we already do it when _sharded_, where it _is_ a strict requirement to support the optimizer step. The sharded case is not an issue for 2D parallelism because sharded implies a local tensor, not a `DTensor`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89845
Approved by: https://github.com/zhaojuanmao

Author: Andrew Gu
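
A minimal sketch of the two idioms contrasted above. A plain tensor stands in for the unsharded view, since constructing a real `DTensor` requires a distributed setup; `module` and `unsharded_view` are illustrative names, not FSDP internals:

```python
import torch
import torch.nn as nn

# A plain linear layer; its weight stands in for an FSDP-managed parameter.
module = nn.Linear(4, 4, bias=False)

# Stand-in for the unsharded view of the parameter data. Under 2D
# parallelism this view would be a DTensor, which is what breaks idiom 1.
unsharded_view = torch.zeros(4, 4)

# Idiom 1: preserve the parameter variable by swapping its `.data`.
# This works for plain/local tensors (the sharded case), but it is
# exactly the assignment that `DTensor` does not support.
module.weight.data = unsharded_view

# Idiom 2 (the workaround): create a brand-new parameter variable that
# wraps the view, and rebind the module attribute to it.
module.weight = nn.Parameter(unsharded_view)
```

With a real `DTensor` view, only the second idiom succeeds, which is why the fix rebinds the attribute to a fresh `nn.Parameter` rather than mutating `param.data` in place.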