[FSDP2] Added initial meta-device init support (#120351)
This PR adds initial support for meta-device initialization for pre-training without loading from a state dict. The idea is to allow `fully_shard(module)` to return and still have sharded parameters on meta device. Then, the user is free to initialize them as they please, e.g. using `to_empty()`.
We override `_apply` to achieve the following:
- Reshard the parameters to ensure that sharded parameters are registered (for correctness) -- we will always need this
- Pad new local tensors and use the padded local tensors (to handle uneven sharding) -- we will remove this once `DTensor` pads its local tensor
We use the `swap_tensors` path in `_apply`. For now, this requires setting `torch.__future__.set_swap_module_params_on_conversion(True)`; however, in the future, this may be enabled by default for wrapper subclasses and will not need any explicit API call. If requiring this call is too intrusive in the short term, we can also call it in `_apply` or when importing `fully_shard`.
```
# Pre-training flow (no checkpoint)
global_mesh = init_device_mesh(..., mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
with torch.device("meta"):
model = ...
parallelize_module(model, tp_mesh, ...)
fully_shard(model, mesh=dp_mesh, ...)
for param in model.parameters():
assert param.device.type == "meta"
model.to_empty(device="cuda")
random.manual_seed(42, global_mesh)
for module in model.modules():
if hasattr(module, "reset_parameters"):
module.reset_parameters()
```
This PR includes some minor changes to allow the user to similarly cast the module to a different dtype after construction time but before forward.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120351
Approved by: https://github.com/wanchaol