Multi-GPU VAE fix for Cosmos 3 (#13924)
fix(cosmos3): pin VAE latent norm buffers to encode output device
Under sharded placement (device_map="balanced"), vae.encode() runs on the
VAE's own device, but the mean/inv_std buffers were pinned to x.device,
which can be a different GPU. Normalizing the encode output against them
raised a cross-device RuntimeError. Compute raw_mu first, then pin the
normalization buffers to raw_mu.device so all tensors share one device.
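
A minimal sketch of the pattern, not the actual diff. The helper name
normalize_latents, the tensor shapes, and the (raw_mu - mean) * inv_std
formula are assumptions made for illustration; only the idea of moving
the buffers to the encode output's device comes from this change.

```python
import torch


def normalize_latents(raw_mu, mean, inv_std):
    """Hypothetical helper: normalize encoder output on its own device."""
    # Move the buffers to wherever the encode output lives, rather than to
    # the device of the original input tensor x.
    mean = mean.to(raw_mu.device)
    inv_std = inv_std.to(raw_mu.device)
    # Assumed normalization formula; every operand is now on one device.
    return (raw_mu - mean) * inv_std


# Stand-in tensors (CPU here); under sharded placement raw_mu would come
# from vae.encode() and could sit on a different GPU than x.
raw_mu = torch.ones(1, 4, 2, 2, 2)
mean = torch.full((1, 4, 1, 1, 1), 0.5)
inv_std = torch.full((1, 4, 1, 1, 1), 2.0)
latents = normalize_latents(raw_mu, mean, inv_std)
```

Keying the buffer placement on raw_mu rather than on x means the code
does not need to know how the device map distributed the modules.
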
Co-authored-by: Atharva Joshi <atjoshi@smc521ge-0036.ipp2a2.colossus.nvidia.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>