diffusers
bd2c9195 - multi-GPU VAE Fix for Cosmos 3 (#13924)

Commit
3 days ago
fix(cosmos3): pin VAE latent norm buffers to encode output device

Under sharded placement (device_map="balanced"), vae.encode() runs on the VAE's own device while the mean/inv_std buffers were pinned to x.device, causing a cross-device RuntimeError. Compute raw_mu first, then pin the normalization buffers to its device so all tensors share one device.

Co-authored-by: Atharva Joshi <atjoshi@smc521ge-0036.ipp2a2.colossus.nvidia.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
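The ordering described in the commit message can be illustrated with a minimal sketch. This is not the actual diffusers code: the function name `encode_and_normalize`, the `encode_fn` callable, and the tensor shapes are hypothetical, and only the names `raw_mu`, `mean`, and `inv_std` come from the message. The point is that the buffers follow the device of the encode output rather than the device of the input `x`.

```python
import torch


def encode_and_normalize(encode_fn, x, mean, inv_std):
    """Hypothetical helper: normalize latents on the encoder's output device."""
    # Run the encoder first; under sharded placement its output may live on
    # a different device than the input x.
    raw_mu = encode_fn(x)
    # Move the normalization buffers to the device of raw_mu (not x.device),
    # so every operand of the arithmetic below shares one device.
    mean = mean.to(raw_mu.device)
    inv_std = inv_std.to(raw_mu.device)
    return (raw_mu - mean) * inv_std


# CPU-only usage with a stand-in encoder (assumption: any callable that
# returns a tensor works for the purposes of this sketch).
x = torch.ones(1, 4, 2, 2)
mean = torch.full((1, 4, 1, 1), 0.5)
inv_std = torch.full((1, 4, 1, 1), 2.0)
out = encode_and_normalize(lambda t: t * 3.0, x, mean, inv_std)
```

On a single device the result is the same whichever device the buffers are pinned to; the difference only shows up when the encoder output and `x` sit on different devices.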