diffusers
CogVideoX-5b-I2V support
#9418
Merged


zRzRzRzRzRzRzR commented 349 days ago (edited 347 days ago)

The purpose of this PR is to adapt our upcoming CogVideoX-5B-I2V model to the diffusers framework:

  1. The model takes an image and text as input and outputs a video.
  2. The transformer's input channels have been increased to 32 (the image latents are concatenated with the video latents), while the rest of the model structure is similar to the 5B T2V model.
  3. A new pipeline, CogVideoXImageToVideoPipeline, has been created, and the documentation has been updated accordingly. A minimal usage sketch follows below.

@a-r-r-o-w @zRzRzRzRzRzRzR
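For reference, a minimal usage sketch of the new pipeline (assuming the final class name `CogVideoXImageToVideoPipeline` and the `THUDM/CogVideoX-5b-I2V` checkpoint id referenced later in this thread):

```python
import torch

from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# The checkpoint was not yet public at review time; the id below is the one the PR targets.
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

prompt = "An astronaut hatching from an egg, on the surface of the moon."
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)

# use_dynamic_cfg=True follows the authors' recommendation discussed below.
video = pipe(image=image, prompt=prompt, use_dynamic_cfg=True).frames[0]
export_to_video(video, "output.mp4", fps=8)
```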

zRzRzRzRzRzRzR draft Init
6e3ae045
zRzRzRzRzRzRzR draft
ad78738a
zRzRzRzRzRzRzR vae encode image
8966671c
zRzRzRzRzRzRzR Merge branch 'huggingface:main' into cogvideox-5b-i2v
a56c5106
a-r-r-o-w make style
c238fe28
a-r-r-o-w image latents preparation
c1f7a800
a-r-r-o-w remove image encoder from conversion script
3df95b2c
a-r-r-o-w fix minor bugs
677a5530
a-r-r-o-w make pipeline work
4f518298
a-r-r-o-w make style
33c7cd6b
a-r-r-o-w remove debug prints
bc07f9f0
a-r-r-o-w fix imports
98f10238
a-r-r-o-w update example
aa12e1b5
a-r-r-o-w make fix-copies
1970f4fa
a-r-r-o-w add fast tests
e044850c
a-r-r-o-w Merge branch 'main' into cogvideox-5b-i2v
f7d8e37c
a-r-r-o-w requested a review from yiyixuxu 348 days ago
a-r-r-o-w requested a review from sayakpaul 348 days ago
a-r-r-o-w fix import
9f6f3f64
HuggingFaceDocBuilderDev commented 348 days ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

sayakpaul commented on 2024-09-13

I left a few comments, but all of them are very minor in nature. Basically, this PR looks solid to me, and it shouldn't take much time to merge.

Off to @yiyixuxu.

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

```diff
             return self.tiled_encode(x)

         frame_batch_size = self.num_sample_frames_batch_size
+        num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
         enc = []
-        for i in range(num_frames // frame_batch_size):
+        for i in range(num_batches):
```
sayakpaul 348 days ago

Better, nice!
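To spell out what the fix guards against, a standalone sketch of the arithmetic (the `frame_batch_size` value is illustrative):

```python
def num_vae_batches(num_frames: int, frame_batch_size: int) -> int:
    # With the old `range(num_frames // frame_batch_size)` loop, a single
    # conditioning image (num_frames == 1) produced zero iterations and an
    # empty `enc`; the `else 1` branch guarantees one encoding pass.
    return num_frames // frame_batch_size if num_frames > 1 else 1


assert num_vae_batches(1, 8) == 1   # I2V: a lone image frame is still encoded
assert num_vae_batches(49, 8) == 6  # multi-frame video input
```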

src/diffusers/models/transformers/cogvideox_transformer_3d.py

```diff
         hidden_states = self.proj_out(hidden_states)

         # 5. Unpatchify
+        # Note: we use `-1` instead of `channels`:
+        # - It is okay for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
+        # - However, for CogVideoX-5b-I2V, the input image latents are concatenated in, so the number of input channels is twice the number of output channels
```
sayakpaul 348 days ago

I think this is sufficiently supplemented with a comment; it should be fine!
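To illustrate why `-1` is safe here, a small sketch (sizes are illustrative; `p` is the patch size): the channel count is inferred from the `proj_out` output, so the reshape works whether or not the input channels equal the output channels.

```python
import torch

batch_size, num_frames, height, width, p, out_channels = 2, 4, 16, 16, 2, 16

# proj_out yields (B, F, (H/p) * (W/p), C_out * p * p); `-1` recovers C_out.
hidden_states = torch.randn(
    batch_size, num_frames, (height // p) * (width // p), out_channels * p * p
)
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
assert output.shape[4] == out_channels  # inferred, even when in_channels == 2 * out_channels
```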

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

```diff
+        >>> pipe.to("cuda")
+
+        >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+        >>> image = load_image("astronaut.jpg")  # TODO: Add link to 720x480 image from HF Docs repo
```
sayakpaul 348 days ago

To update before merge.

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

```diff
+        extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        negative_prompt,
+        callback_on_step_end_tensor_inputs,
+        video=None,
+        latents=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+    ):
```
sayakpaul 348 days ago

If the input image needs to follow any constraints, we could check for them and error out accordingly.

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

```diff
+        if video is not None and latents is not None:
+            raise ValueError("Only one of `video` or `latents` should be provided")
+
+    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
+    def fuse_qkv_projections(self) -> None:
+        r"""Enables fused QKV projections."""
+        self.fusing_transformer = True
+        self.transformer.fuse_qkv_projections()
+
+    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
+    def unfuse_qkv_projections(self) -> None:
+        r"""Disable QKV projection fusion if enabled."""
+        if not self.fusing_transformer:
+            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+        else:
+            self.transformer.unfuse_qkv_projections()
+            self.fusing_transformer = False
```
sayakpaul 348 days ago

@yiyixuxu

I think it'd be okay to add this because @a-r-r-o-w has already verified that QKV fusion provides speed benefits when combined with quantization and torch.compile(). Since the transformer for this pipeline isn't changing, I'd expect to see similar speedups here.

But LMK if you think otherwise.
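For context, the fast path being referred to would look roughly like this (a sketch reusing `pipe`, `image`, and `prompt` from the usage example above; the fusion methods mirror the T2V pipeline they were copied from):

```python
import torch

# Fold the separate Q, K, V projections into single fused matmuls, then compile.
pipe.fuse_qkv_projections()
pipe.transformer = torch.compile(pipe.transformer)

video = pipe(image=image, prompt=prompt).frames[0]

# Optionally restore the original, unfused projections afterwards.
pipe.unfuse_qkv_projections()
```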

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

```diff
+                    image_rotary_emb=image_rotary_emb,
+                    return_dict=False,
+                )[0]
+                noise_pred = noise_pred.float()
```
sayakpaul 348 days ago

This seems interesting. Why do we have to manually perform the upcasting here?

a-r-r-o-w 348 days ago (edited 348 days ago)

I think @yiyixuxu would be better able to answer this since it was copied over from the other Cog pipelines. IIRC, the original codebase had an upcast here too, which is why we kept it.

docs/source/en/api/pipelines/cogvideox.md

```diff
   - all
   - __call__

+## CogVideoXImageToVideoPipeline
+
+[[autodoc]] CogVideoXImageToVideoPipeline
+  - all
+  - __call__
```
sayakpaul 348 days ago

If there are any restrictions on precision (e.g., fp16 shouldn't be used), we could add them to the docs.

tests/pipelines/cogvideo/test_cogvideox_image2video.py

```diff
+            "VAE tiling should not affect the inference results",
+        )
+
+    @unittest.skip("xformers attention processor does not exist for CogVideoX")
+    def test_xformers_attention_forwardGenerator_pass(self):
+        pass
```
sayakpaul 348 days ago

It should already work without this modification. We just have to set the `test_xformers_attention` attribute of `CogVideoXPipelineFastTests` to `False`.

https://github.com/huggingface/diffusers/blob/6dc6486565ea1d8d1be567eefc1094e9185560a1/tests/pipelines/test_pipelines_common.py#L1648C21-L1648C35
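i.e., something like the following (a sketch; imports and the rest of the class as in the existing test module):

```python
class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = CogVideoXImageToVideoPipeline

    # Instead of overriding the test with @unittest.skip, flip the attribute
    # that the common xformers test consults before running.
    test_xformers_attention = False
```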

tests/pipelines/cogvideo/test_cogvideox_image2video.py

```diff
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."
```
sayakpaul 348 days ago

Would you maybe like to add a slow integration test as well, with a @unittest.skip marker on top, so that we know what kind of slices to expect?

a-r-r-o-w 348 days ago

I think it would be better to add a slow test once the model is public, in a follow-up PR, because it would fail after this is merged into main, no?

sayakpaul 348 days ago

As mentioned, if we always skip it (with the marker), it shouldn't matter but the test still remains for our convenience.
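Something along these lines (a hedged sketch of an always-skipped placeholder; the expected slice is a dummy value to be recorded once the weights are public):

```python
import unittest

import numpy as np
import torch

from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import load_image


@unittest.skip("Checkpoint not public yet; unskip and record real slices after release.")
class CogVideoXImageToVideoPipelineIntegrationTests(unittest.TestCase):
    def test_cogvideox_5b_i2v(self):
        pipe = CogVideoXImageToVideoPipeline.from_pretrained(
            "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
        ).to("cuda")
        image = load_image(
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
        )
        video = pipe(
            image=image,
            prompt="An astronaut hatching from an egg",
            num_inference_steps=2,
            output_type="np",
        ).frames[0]
        expected_slice = np.zeros(9)  # placeholder until the model is released
        assert np.allclose(video[0, -3:, -3:, -1].flatten(), expected_slice, atol=1e-2)
```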

yiyixuxu approved these changes on 2024-09-13

thanks! left some minor comments, feel free to merge once addressed!

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

```diff
+            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+        )
+
+        assert image.ndim == 4
```

yiyixuxu 348 days ago

Suggested change:

```diff
-        assert image.ndim == 4
```

It is not a method users would directly use, is it? So, we don't need an assertion here, but we can make a note about the expected shape.

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

```diff
+        assert image.ndim == 4
+        image = image.unsqueeze(2)  # [B, C, F, H, W]
+
+        if isinstance(generator, list):
+            if len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
```

yiyixuxu 348 days ago

Suggested change:

```diff
-        if isinstance(generator, list):
-            if len(generator) != batch_size:
-                raise ValueError(
-                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                )
```

I think these lines are duplicates; we did this at line 357.

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

```diff
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
+                latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
```

yiyixuxu 348 days ago

interesting, they don't add noise to the image
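In isolation, the conditioning path looks like this (a sketch with illustrative shapes; `dim=2` is the channel axis in the `(B, F, C, H, W)` latent layout):

```python
import torch

do_classifier_free_guidance = True

latents = torch.randn(1, 13, 16, 60, 90)        # noisy video latents being denoised
image_latents = torch.randn(1, 13, 16, 60, 90)  # clean VAE-encoded image, padded over frames

latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents

# The image latents are never noised: they are re-concatenated, clean, at every
# step, which is why the transformer's in-channels doubled (16 -> 32).
model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
assert model_input.shape[2] == 32
```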

a-r-r-o-w update vae
877cdc0c
a-r-r-o-w update docs
29f10070
a-r-r-o-w update image link
0c1358c4
a-r-r-o-w apply suggestions from review
8222a55f
a-r-r-o-w Merge branch 'main' into cogvideox-5b-i2v
61831bd3
a-r-r-o-w commented on 2024-09-13
scripts/convert_cogvideox_to_diffusers.py

```diff
     "freqs_cos": remove_keys_inplace,
     "position_embedding": remove_keys_inplace,
+    # TODO zRzRzRzRzRzRzR: really need to remove?
+    "pos_embedding": remove_keys_inplace,
```

a-r-r-o-w 348 days ago (edited 348 days ago)

Note that CogVideoX-5b-I2V uses learned positional embeddings after the patch embedding as well as RoPE embeddings on QK.

TODO: As discussed with Yuxuan, we are able to generate videos without the learned positional embeddings, but ideally they are needed even if their role is minimal. This will limit multi-resolution or multi-frame generation, since we can't dynamically generate the learned embeddings on-the-fly. Will work together on adding this.
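The distinction, sketched (illustrative sizes): sincos embeddings are a pure function of the token grid and can be recomputed for any resolution, whereas the I2V checkpoint ships a fixed-size learned table that cannot be extended on-the-fly.

```python
import torch
import torch.nn as nn


# T2V (2B): functional sincos PE -- regenerable for any number of tokens.
def sincos_pos_embed(num_tokens: int, dim: int) -> torch.Tensor:
    position = torch.arange(num_tokens, dtype=torch.float32).unsqueeze(1)
    freqs = torch.exp(-torch.log(torch.tensor(10000.0)) * torch.arange(0, dim, 2) / dim)
    pe = torch.zeros(num_tokens, dim)
    pe[:, 0::2] = torch.sin(position * freqs)
    pe[:, 1::2] = torch.cos(position * freqs)
    return pe


# I2V (5B): a learned table of fixed size, checkpointed with the model. Token
# grids beyond the training resolution have no learned rows, hence the
# multi-resolution limitation described above. (Sizes here are illustrative.)
learned_pos_embedding = nn.Parameter(torch.zeros(1, 17776, 3072))
```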

a-r-r-o-w apply suggestions from review
2d8dce9d
a-r-r-o-w add slow test
4f894269
a-r-r-o-w make use of learned positional embeddings
21a6f79b
a-r-r-o-w a-r-r-o-w requested a review from sayakpaul sayakpaul 347 days ago
sayakpaul commented on 2024-09-13
scripts/convert_cogvideox_to_diffusers.py

```diff
     "mixins.final_layer.norm_final": "norm_out.norm",
     "mixins.final_layer.linear": "proj_out",
     "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
+    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",  # Specific to CogVideoX-5b-I2V
```

sayakpaul 347 days ago

Should we have any if/else to guard that accordingly?

a-r-r-o-w 347 days ago

This layer is actually absent in the T2V models. It's called `position_embedding` in T2V, which is just a sincos PE, while it's `pos_embedding` here. I think it's safe, but I'm verifying it now.

a-r-r-o-w 347 days ago

Yep, this is safe and should not affect the T2V checkpoints since they follow different layer naming conventions

sayakpaul commented on 2024-09-13
src/diffusers/models/embeddings.py

```diff
+        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+                raise ValueError(
+                    "It is currently not possible to generate videos at a resolution different from the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
+                )
```

sayakpaul 347 days ago

In other words, the 2b variant supports it?

a-r-r-o-w 347 days ago

Yes, we had some success with multi-resolution inference quality on the 2B T2V model. The reason for allowing this is to avoid confining LoRA training to 720x480 videos on the 2B model. 5B T2V will skip this entire branch. 5B I2V uses learned positional embeddings, so we can't generate them on-the-fly like the sincos embeddings for the 2B T2V model.

sayakpaul commented on 2024-09-13
src/diffusers/models/transformers/cogvideox_transformer_3d.py

```diff
         spatial_interpolation_scale: float = 1.875,
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
+        use_learned_positional_embeddings: bool = False,
```

sayakpaul 347 days ago

Can both be true? If not, I would maybe add a check to error as early as possible.

a-r-r-o-w 347 days ago

Both are true in the case of CogVideoX-5b-I2V.

sayakpaul 347 days ago

Okay. But which combinations of values are accepted here? Not saying we should test all of them, but I think it'd be good to be aware.

a-r-r-o-w 347 days ago

The accepted combinations would be `False, False` (for 2B T2V), `True, False` (for 5B T2V), and `True, True` (for 5B I2V). Do you mean I should explicitly document this here, or add an error check for the missing case `False, True`? (I don't think it's needed, though, tbh.)

sayakpaul 347 days ago

Either should be fine but I would prefer an error.

a-r-r-o-w 347 days ago

Added an error
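Roughly (a sketch of the guard inside the transformer's constructor; the exact message wording is assumed):

```python
if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
    raise ValueError(
        "There are no CogVideoX checkpoints that use learned positional embeddings "
        "without rotary embeddings. Accepted combinations are: (False, False) for "
        "2B T2V, (True, False) for 5B T2V, and (True, True) for 5B I2V."
    )
```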

sayakpaul commented on 2024-09-13
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

```diff
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+        ... )
+        >>> video = pipe(image, prompt, use_dynamic_cfg=True)
```
sayakpaul 347 days ago

Should we let users know why to use dynamic CFG?

a-r-r-o-w 347 days ago

Ah, this is just a remnant from my script. One can generate without it as well for similar quality, but the Cog folks recommend dynamic CFG. WDYT we should do?

sayakpaul 347 days ago

Let's follow the recommendation of the authors then!

sayakpaul commented on 2024-09-13
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

```diff
+
+                # perform guidance
+                if use_dynamic_cfg:
+                    self._guidance_scale = 1 + guidance_scale * (
+                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+                    )
```
sayakpaul 347 days ago

(can revisit later)

This can introduce graph breaks because we are combining non-torch operations with torch tensors. `.item()` is a data-dependent call and can also lead to performance issues.

Just noting so that we can revisit if need be.
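One possible later fix (a sketch, not part of this PR): drive the cosine schedule from the integer loop step instead of the tensor timestep, which avoids the data-dependent `.item()` sync. Note this reparameterizes the schedule from the timestep value to the step fraction, so it is not bit-identical to the current formula.

```python
import math


def dynamic_guidance_scale(guidance_scale: float, step: int, num_inference_steps: int) -> float:
    # `step` is the plain-int index from `enumerate(timesteps)`: no tensor
    # sync, no torch.compile graph break.
    progress = (step + 1) / num_inference_steps
    return 1 + guidance_scale * ((1 - math.cos(math.pi * progress**5.0)) / 2)
```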

sayakpaul commented on 2024-09-13
tests/pipelines/cogvideo/test_cogvideox_image2video.py

```diff
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        transformer = CogVideoXTransformer3DModel(
```
sayakpaul 347 days ago

Should we check for learned position embeddings too?

a-r-r-o-w 347 days ago

Oh, nice catch! Yes, we should have that parameter set to True for the I2V tests
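i.e., in the dummy components (a sketch; `t2v_dummy_kwargs` is a hypothetical stand-in for the tiny config already used by the T2V fast tests, minus the two flags below):

```python
torch.manual_seed(0)
transformer = CogVideoXTransformer3DModel(
    **t2v_dummy_kwargs,  # hypothetical: the existing tiny test config
    use_rotary_positional_embeddings=True,
    use_learned_positional_embeddings=True,  # exercise the I2V-specific PE path
)
```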

sayakpaul approved these changes on 2024-09-13

Looks good. My comments are minor, not blockers at all.

a-r-r-o-w apply suggestions from review
6ce07784
zRzRzRzRzRzRzR Merge branch 'huggingface:main' into cogvideox-5b-i2v
7e637d6c
zRzRzRzRzRzRzR doc change
6f313e85
zRzRzRzRzRzRzR changed the title from "Cogvideox 5b i2v draft" to "CogVideoX-5b-I2V support" 347 days ago
a-r-r-o-w Merge branch 'main' into cogvideox-5b-i2v
ed8bda96
zRzRzRzRzRzRzR Update convert_cogvideox_to_diffusers.py
c8ec68ca
a-r-r-o-w make style
33056c54
a-r-r-o-w final changes
6dc9bdb5
a-r-r-o-w commented 345 days ago

Will be merging after CI turns green. Will take up any changes in follow-up PRs.

a-r-r-o-w make style
edeb626f
a-r-r-o-w fix tests
380a820c
a-r-r-o-w merged 8336405e into main 345 days ago
tin2tin commented 344 days ago

`OSError: THUDM/CogVideoX-5b-I2V is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'`

a-r-r-o-w commented 344 days ago

The model release is planned for some time in the next few days, when the CogVideoX team is ready. Until then, we will be preparing a Diffusers patch release to ship the pipeline.

zRzRzRzRzRzRzR commented 344 days ago

> The model release is planned for some time in the next few days, when the CogVideoX team is ready. Until then, we will be preparing a Diffusers patch release to ship the pipeline.

Thank you for your support! We expect to open-source the project next week. If the patch release can be published before then, it would be a great help to us.

zRzRzRzRzRzRzR deleted the cogvideox-5b-i2v branch 225 days ago
