Fix fsdp2 load full state dict dtype mismatch (#4021)
* [FSDP2] Capture original_sd after upcast to fix dtype mismatch hang
`fsdp2_prepare_model` was capturing `original_sd = model.state_dict()`
before the upcast-to-fp32 loop introduced in #3985. With
`mixed_precision != "no"` and `cpu_ram_efficient_loading=True`,
`original_sd` ended up holding references to the old bf16/fp16 tensors
while the model parameters (and therefore `meta_sharded_sd` after
`fully_shard`) were fp32.
Inside `fsdp2_load_full_state_dict`, rank 0 broadcasts the bf16/fp16
tensor from `original_sd`, while every other rank allocates its receive
buffer with `torch.empty(..., dtype=sharded_param.dtype)`, which is fp32.
The element sizes do not match (2 bytes versus 4), so the
`dist.broadcast` deadlocks and training hangs forever.
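
For illustration, the size mismatch can be reproduced without a process group. This is a minimal sketch, not code from the patch: `payload_bytes` is a hypothetical helper and the shape is made up. It only compares how many bytes the sender and the receivers each expect for the same parameter.

```python
import math

import torch


def payload_bytes(shape, dtype):
    """Bytes occupied by a dense tensor of `shape` and `dtype` (hypothetical helper)."""
    # element_size() reports the per-element width in bytes for the dtype.
    return math.prod(shape) * torch.empty((), dtype=dtype).element_size()


shape = (1024, 1024)  # made-up parameter shape

# Rank 0 would send the stale half-precision tensor from `original_sd`.
sent = payload_bytes(shape, torch.bfloat16)

# The other ranks size their receive buffer from the fp32 sharded parameter.
expected = payload_bytes(shape, torch.float32)

print(f"sender: {sent} bytes, receivers: {expected} bytes")
```

The two sides disagree by a factor of two for every parameter, which is the condition the commit describes.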
Move the snapshot to after the upcast so its dtype matches the
post-`fully_shard` sharded state dict.
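
The effect of the ordering can be sketched on a toy module. This is an assumption-laden illustration, not the code in `fsdp2_prepare_model`: `snapshot_dtypes` is a hypothetical function, and `model.to(torch.float32)` stands in for the real upcast loop.

```python
import torch
from torch import nn


def snapshot_dtypes(snapshot_after_upcast: bool):
    """Return (snapshot dtype, live parameter dtype) for a toy bf16 module."""
    model = nn.Linear(4, 2, dtype=torch.bfloat16)

    if not snapshot_after_upcast:
        # Old ordering: the snapshot references the bf16 tensors.
        original_sd = model.state_dict()

    # Stand-in for the upcast-to-fp32 loop; the parameters get new fp32 data.
    model.to(torch.float32)

    if snapshot_after_upcast:
        # New ordering: the snapshot references the fp32 tensors.
        original_sd = model.state_dict()

    return original_sd["weight"].dtype, model.weight.dtype


before = snapshot_dtypes(snapshot_after_upcast=False)
after = snapshot_dtypes(snapshot_after_upcast=True)
print("snapshot before upcast:", before)
print("snapshot after upcast: ", after)
```

With the snapshot taken first, the state dict keeps the half-precision tensors even though the live parameters are fp32; taken afterwards, both agree.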
* Fix