diffusers
e377c0a4 - Fix fp16 LoRA unscale crash after validation in train_dreambooth_lora.py (#13895)

Commit
24 days ago
Fix fp16 LoRA unscale crash after validation in train_dreambooth_lora.py (#13895) When training with `--mixed_precision="fp16"` and `--validation_prompt`, the first optimizer step after a validation run fails with `ValueError: Attempting to unscale FP16 gradients`. Under fp16, `cast_training_params` keeps the trainable LoRA params in fp32. The in-loop validation pipeline is built with the same live `unet` object, and `log_validation` then calls `pipeline.to(device, dtype=torch_dtype)`, which downcasts those fp32 LoRA params back to fp16. The next backward therefore produces fp16 grads and `GradScaler.unscale_` raises. Drop the dtype cast from that `.to(...)` so the shared `unet` keeps its fp32 LoRA params. This matches train_dreambooth_lora_sdxl.py, which moves the validation pipeline with `.to(accelerator.device)` only. Fixes #13124 Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author
Parents
Loading