diffusers
62bfa5a2 - fix(freeu): run FFT in float32 for float16 inputs to avoid ComplexHalf (#13503)

Commit
18 days ago
fix(freeu): run FFT in float32 for float16 inputs to avoid ComplexHalf (#13503) * fix(freeu): run FFT in float32 for float16 inputs to avoid ComplexHalf `fourier_filter` already upcasts `bfloat16` inputs to `float32` before calling `torch.fft.fftn`, because PyTorch's FFT does not support bf16. The same is true for `float16`: depending on the PyTorch version, `fftn` either - produces the experimental `torch.complex32` (ComplexHalf) dtype and emits a `UserWarning: ComplexHalf support is experimental`, or - raises `RuntimeError: Unsupported dtype Half` outright. Both paths were reachable from FreeU with half-precision models (e.g. `sd-turbo` + `fp16` + `enable_freeu`) as reported in #12504. Extend the existing upcast branch to cover `float16` too. The function already downcasts the result back to `x_in.dtype` at the end, so the externally observable dtype is unchanged. Closes #12504. * Address review: generalize upcast to non-float32 + fix ruff F821 - Apply @sayakpaul's suggestion: use `elif x.dtype != torch.float32:` so any non-float32 dtype (bf16, fp16, and future half-precision dtypes) is upcast to float32 before the FFT. - Drop the `"torch.Tensor"` return annotation on the test helper that triggered ruff F821 in CI (torch is imported inside the method body, not at module scope).
Author
Parents
Loading