[FSDP][3/N] Register `flat_param` to wrapped module (#87086)
This PR registers each `FlatParameter` to the wrapped module, eliminating `FlattenParamsWrapper` usage completely from FSDP.
Registering each `FlatParameter` to the wrapped module is preferred over registering to the `FullyShardedDataParallel` instance for both functional-like and non-recursive wrapping. It simplifies the `FlatParameter` naming to be a function of the number of `FlatParameter`s per wrapped module instead of the number of `FlatParameter`s per FSDP instance. For now, we assume 1 `FlatParameter` per wrapped module, so we can simply use a single name `FLAT_PARAM = _flat_param`.
From an implementation perspective, we raise some methods from `FlattenParamsWrapper` directly up to `FullyShardedDataParallel`. There will need to be further refactoring for functional-like and non-recursive wrapping. For example, the property `self._has_params -> bool` may need to change to a method `self._has_params(wrapped_module) -> bool`. Such changes are out of scope for this PR and will be done in follow-ups.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87086
Approved by: https://github.com/zhaojuanmao