accelerate
d1d2bba4 - Fix fsdp2 load full state dict dtype mismatch (#4021)

Commit
37 days ago
Fix fsdp2 load full state dict dtype mismatch (#4021) * [FSDP2] Capture original_sd after upcast to fix dtype mismatch hang `fsdp2_prepare_model` was capturing `original_sd = model.state_dict()` before the upcast-to-fp32 loop introduced in #3985. With `mixed_precision != "no"` and `cpu_ram_efficient_loading=True`, `original_sd` ended up holding bf16/fp16 references while the model parameters (and therefore `meta_sharded_sd` after `fully_shard`) were fp32. Inside `fsdp2_load_full_state_dict`, rank 0 broadcasts the bf16/fp16 tensor from `original_sd` while rank N allocates `torch.empty(..., dtype=sharded_param.dtype)` which is fp32. The element sizes do not match, so the `dist.broadcast` deadlocks and training hangs forever. Move the snapshot to after the upcast so its dtype matches the post-`fully_shard` sharded state dict. * Fix
Author
Parents
Loading