[FSDP] Retire `self.device_id`; clean up ctor (#83663)
### Overview
This PR retires `self.device_id` by coalescing it with `self.compute_device` and more generally cleans up the FSDP constructor.
### Existing FSDP Constructor Semantics (In Order)
1. Compute the ignored parameters/modules from `ignored_modules` and the buffer names (to avoid cloning in `state_dict()`)
2. Recursively auto wrap if needed
3. Define process group attributes
4. Determine `device_id`
5. Materialize the wrapped module if using meta device or `torchdistX` deferred initialization
6. Move the module if needed (based on `self.device_id`)
7. Determine `compute_device`
8. Define `training_state`, gradient divide factors, FSDP feature-related attributes (`cpu_offload`, `forward_prefetch`, `backward_prefetch`, `sharding_strategy`, `mixed_precision`), `_orig_buffer_dtypes`
9. Determine the parameters to flatten
10. Sync module states if `sync_module_states`
11. Initialize the `FlattenParamsWrapper` with the parameters to flatten and the wrapped module, which constructs the `FlatParameter`
12. Shard the `FlatParameter` (in-place)
13. Define `_is_root`, shared attributes (`_streams`, `_fsdp_graph_order`), prefetching attributes (`_my_fsdp_idx_in_graph`, `_pre_backward_hook_full_params_prefetched`, `_forward_full_params_prefetched`), `reshard_after_forward` -- all of this is done in `_reset_lazy_init()`
14. Define `_require_backward_grad_sync` to configure `no_sync()`
15. Define state dict attributes (`_state_dict_type`, `_state_dict_config`) and register state dict hooks
16. Define backward pass flags (`_pre_backward_hook_has_run`, `_need_rebuild_full_params`)
17. Move `FlatParameter`s to CPU if `cpu_offload.offload_params`
18. Define `_exec_order_data` for execution order validation
19. Define communication hook attributes (`communication_hook`, `communication_hook_state`, `_hook_registered`)
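To make the ordering above concrete, here is a minimal single-rank usage sketch (not code from this PR; it assumes one CUDA device, an NCCL build of PyTorch, and a free local port for rendezvous):

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload

# Single-rank process group so the example can run standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

module = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
# The constructor resolves the compute device from `device_id`, moves the
# module if needed, flattens the parameters, shards the FlatParameter in
# place, and (with offload_params=True) offloads the shard back to CPU.
fsdp_module = FSDP(
    module,
    device_id=torch.cuda.current_device(),
    cpu_offload=CPUOffload(offload_params=True),
)
# After this PR there is no `self.device_id`; the resolved device lives only
# in the (internal) `compute_device` attribute.
print(fsdp_module.compute_device)  # e.g. cuda:0

dist.destroy_process_group()
```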
### Notable Changes
- `self.mixed_precision`
  - **Before:** `self.mixed_precision` itself could be `None`, or equivalently `MixedPrecision(None, None, None)`; both disabled mixed precision completely.
  - **After:** `self.mixed_precision` itself is never `None`; only `MixedPrecision(None, None, None)` (the default construction of the `dataclass`) disables mixed precision. This change surfaced an issue in `test_summon_full_params.py`, where we were passing `MixedPrecision(None, None, None)` when we actually wanted mixed precision enabled (see the sketch after this list).
- `cpu_offload.offload_params=True` + `device_id`
  - **Before:** For nested FSDP instances with `device_id` specified, `FlatParameter`s that were already offloaded to CPU were moved back to GPU and not re-offloaded to CPU.
  - **After:** The nested `FlatParameter`s are re-offloaded to CPU. This is a temporary hack; the ideal solution removes the `module = module.to(<GPU device>)` call in the first place and only moves the relevant parameters. Because the `module.to()` implementation has some complexity, I did not want to remove that call in this PR.
- `device_id` and `compute_device`
  - **Before:** `self.device_id` was either `None` or equal to `self.compute_device`, and it was not used after the FSDP constructor.
  - **After:** `self.device_id` is removed and coalesced into `self.compute_device`. The only semantic change is that `test_module_device_mismatches_device_id()` errors earlier (but importantly, still errors).
  - This PR also uses a helper method `_get_orig_params()`, which is more robust and may avoid issues like https://github.com/pytorch/pytorch/issues/82891 without having to gate higher-level logic.
- `_reset_lazy_init()` attributes
  - **Before:** Some attributes were being _defined_ in `_reset_lazy_init()`, which may not be obvious to all devs.
  - **After:** For this PR, we define these attributes in the constructor but leave `_reset_lazy_init()` as is. This gets refactored further in the follow-ups.
- Otherwise, I moved some logic into dedicated helper methods and reorganized the attribute definitions into logical groups.
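As an illustration of the `self.mixed_precision` point above (a sketch of the intended usage, not the test code itself): after this PR, disabling mixed precision means passing the default-constructed dataclass or omitting the argument, and enabling it requires setting at least one dtype explicitly.

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# Default construction leaves every dtype as None, which disables mixed
# precision; it is equivalent to MixedPrecision(None, None, None).
disabled = MixedPrecision()

# To actually enable mixed precision, set the dtypes explicitly.
enabled = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)
print(disabled)
print(enabled)
```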
### Follow-Ups
1. What should the specification be for `device_id` + `ignored_modules`?
2. Investigate removing the `module = module.to(<GPU device>)` in favor of moving per parameter.
3. Should we call `_reset_lazy_init()` in `register_comm_hook()`?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83663
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma