[FSDP] Fix wrapped module changing after ctor (#87837)
Recently, I retired `FlattenParamsWrapper`, which meant that FSDP registers its `FlatParameter` on the wrapped module instead of the `FlattenParamsWrapper` instance. This is only relevant for `use_orig_params=False`.
If the user changes an FSDP instance's wrapped module after the FSDP constructor, then the `FlatParameter` is no longer registered on the wrapped module. This can cause issues for full state dict, which checks if the `FlatParameter` is currently registered as an early return condition for `rank0_only=True`.
The solution in this PR is to re-establish the wrapped module in `_lazy_init()`, de-registering from the old wrapped module and re-registering to the new wrapped module, where the assumption is that the user should not modify the module structure upon `_lazy_init()`.
The direct access to the private attribute `_parameters` from `nn.Module` is not ideal, but we already rely on it for the dynamic `FlatParameter` registration. The tradeoff is whether we want an additional `nn.Module` wrapper (`FlattenParamsWrapper`) and use `delattr` plus a singleton list to do the dynamic registration or we want to access `_parameters`. If this becomes a problem, we can work with Core team on a solution.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87837
Approved by: https://github.com/zhaojuanmao