diffusers
[refactor] move positional embeddings to patch embed layer for CogVideoX #9263
Merged

a-r-r-o-w merged 12 commits into main from cogvideox/pipeline-followups
a-r-r-o-w commented 266 days ago (edited 257 days ago)

What does this PR do?

  • removes the 49-frame limit, since CogVideoX-5B generalizes better than 2B and is able to generate more frames

  • moves the positional embedding creation logic to the pipeline, similar to rotary embeddings

  • moves the positional embedding logic to the patch embed layer

As a side effect of this PR, one can generate more than 49 frames with CogVideoX-2b, which will produce bad results, but we can add a recommendation about this.
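For illustration, a minimal sketch of generating more than 49 frames once the limit is removed (the model ID is real; the frame count and settings here are placeholders):

```python
import torch
from diffusers import CogVideoXPipeline

# The 5B model generalizes well beyond the 49-frame training setup.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# With the frame limit removed, num_frames > 49 no longer raises in the pipeline.
# Quality is not guaranteed for the 2b checkpoint at these lengths.
video = pipe(
    prompt="A panda playing guitar in a bamboo forest",
    num_frames=81,
    num_inference_steps=50,
).frames[0]
```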

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu

a-r-r-o-w remove frame limit in cogvideox
7fa2bde6
a-r-r-o-w remove debug prints
22311d1c
a-r-r-o-w changed the title from [refactor] removes the frame limititation in CogVideoX to [refactor] removes the frame limitation in CogVideoX 266 days ago
a-r-r-o-w changed the title from [refactor] removes the frame limitation in CogVideoX to [refactor] remove the frame limitation in CogVideoX 266 days ago
a-r-r-o-w requested a review from yiyixuxu 266 days ago

a-r-r-o-w commented on 2024-08-24
Conversation marked as resolved

src/diffusers/models/transformers/cogvideox_transformer_3d.py

    # 3. Position embedding
    text_seq_length = encoder_hidden_states.shape[1]

    if not self.config.use_rotary_positional_embeddings:

a-r-r-o-w 266 days ago

Suggested change:
- if not self.config.use_rotary_positional_embeddings:
+ if not self.config.use_rotary_positional_embeddings and positional_emb is not None:
a-r-r-o-w Update src/diffusers/models/transformers/cogvideox_transformer_3d.py
f8f03a1e
a-r-r-o-w Merge branch 'main' into cogvideox/pipeline-followups
49b804cf
yiyixuxu commented on 2024-08-25 (265 days ago)

thanks!
I'm aware there is some discrepancy between the rotary and sincos embeddings, but moving it to the pipeline will create more discrepancy with the rest of the codebase, so I think we can either leave it as it is for now, or maybe consider adding a positional embedding in the transformer for the rotary embedding too (and we can apply the same pattern across all relevant pipelines in a follow-up PR)
let me know what you think!

Conversation marked as resolved

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

    self.transformer.unfuse_qkv_projections()
    self.fusing_transformer = False

    def _prepare_normal_positional_embeddings(
yiyixuxu 265 days ago

Positional embeddings are always part of the model, though, not the pipeline, so I think we should not move the positional embedding to the pipeline here.

for rotary embeddings, there is indeed some discrepancy: hunyuan-dit, lumina, stable audio, and now cogvideox all create them in the pipeline, but flux makes them in the transformer:

self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

design-wise, I think the flux method is the nicest (making a positional embedding that works for both rotary and sincos). It might have a slowdown because we recompute it every step, but that should be negligible, and we do it for other conditions anyway - should we consider that design if we want to make a refactor here?
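For illustration, a rough sketch of what such a module could look like for CogVideoX sincos embeddings (the class name and defaults here are hypothetical; get_3d_sincos_pos_embed is the existing diffusers helper, which returned a numpy array at the time):

```python
import torch
import torch.nn as nn

from diffusers.models.embeddings import get_3d_sincos_pos_embed


class CogVideoXSincosPosEmbed(nn.Module):
    # Hypothetical module that computes sincos embeddings per forward pass,
    # analogous to how FluxPosEmbed computes rotary embeddings on the fly.
    def __init__(self, embed_dim, spatial_interpolation_scale=1.875, temporal_interpolation_scale=1.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale

    def forward(self, num_frames, height, width, device, dtype):
        # Shapes are known in the forward pass, so no sample_frames limit applies.
        pos_embed = get_3d_sincos_pos_embed(
            self.embed_dim,
            (width, height),
            num_frames,
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
        )  # numpy array of shape [num_frames, height * width, embed_dim]
        pos_embed = torch.from_numpy(pos_embed).flatten(0, 1)
        return pos_embed.to(device=device, dtype=dtype)
```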

a-r-r-o-w 264 days ago (edited 264 days ago)

Hmm... you are right that we do it for other conditions, so saving some time here isn't going to make a difference. My reasoning for this change was that I don't want a limit on the number of frames imposed by the pipeline (previously it was limited by CogVideoXTransformer3D, which created the sincos embeds based on the sample_frames parameter, so I put this bit of code here just to give a cleaner error). I also know that you can change the parameter at initialization with .from_pretrained(..., sample_frames=<WHATEVER MAX LIMIT YOU WANT>) to get around this, but it seems like an extra step for anyone wanting to generate a different number of frames, and not something you'd expect a user to know.
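For reference, the workaround mentioned above looks roughly like this (the kwarg-override pattern on from_pretrained is standard diffusers behavior; the limit value is arbitrary):

```python
import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel

# Override sample_frames at load time so the sincos table is built large
# enough for longer generations; the value below is arbitrary.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b",
    subfolder="transformer",
    sample_frames=145,
    torch_dtype=torch.bfloat16,
)
pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", transformer=transformer, torch_dtype=torch.bfloat16
)
```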

I think the FluxPosEmbed method is the best to go forward with, since you would know the shapes in the forward pass. Just to confirm: this would remove image_rotary_emb from the pipeline, and, based on the transformer config use_rotary_positional_embeddings, we would make CogVideoXPosEmbed decide whether to return positional or rotary embeds?

yiyixuxu 264 days ago (edited 264 days ago)

for the positional embedding, I think we can move it inside CogVideoXPatchEmbed (see how our PatchEmbed works here)

from there, we have a few choices:

  • you can generate the pos_embed during initialization based on sample_frames and then re-generate it at inference time when the num_frames the user passes is larger than sample_frames (similar to here; see the sketch after this comment) - I don't like this very much. I think if we want to keep the sample_frames config, it is better to throw an error at inference time and ask the user to override the config with the from_pretrained(...) method, because it is not a config you need to change on every generation; it is just a maximum, i.e. if you set it to 100, you can generate any number of frames <= 100

  • if you want to change it to generate the positional embedding on the fly for each generation based on num_frames, I'm okay with that too! As long as it is tested and we're sure that it will not result in worse performance in terms of memory and speed.

We can do the rotary embedding in a separate PR if we want to refactor it, because the current implementation is consistent with most pipelines, and if we want to change it, we would need to change it for all of them.
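As a rough illustration of the first option above (a sketch only, not the merged implementation; the function name is made up, while get_3d_sincos_pos_embed is the existing helper):

```python
import torch

from diffusers.models.embeddings import get_3d_sincos_pos_embed


def get_or_regenerate_pos_embedding(
    cached,                  # precomputed tensor [1, max_frames * height * width, embed_dim]
    embed_dim,
    num_frames,              # latent frames requested at inference time
    height,                  # latent height in patches
    width,                   # latent width in patches
    max_frames,              # frames the cached table was built for (from sample_frames)
    spatial_interpolation_scale=1.875,
    temporal_interpolation_scale=1.0,
):
    # Reuse the cached table when it covers the request...
    if num_frames <= max_frames:
        return cached[:, : num_frames * height * width]
    # ...otherwise rebuild a larger one on the fly (the alternative is to raise
    # and point the user at from_pretrained(..., sample_frames=...)).
    pos_embed = get_3d_sincos_pos_embed(
        embed_dim,
        (width, height),
        num_frames,
        spatial_interpolation_scale,
        temporal_interpolation_scale,
    )
    return torch.from_numpy(pos_embed).flatten(0, 1).unsqueeze(0)
```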

a-r-r-o-w 263 days ago

I don't like the idea of re-generating it per forward pass either, as it seems like unnecessary extra time. I think the from_pretrained(..., sample_frames=X) method is the best way to do it then.

Let's hold off a bit on the RoPE refactor and just use this PR to remove the num_frames limitation and move the pos embeds to CogVideoXPatchEmbed to make it consistent. I think what we currently have in the pipeline is nice, since it saves the recomputation and is consistent with the rest. A FluxPosEmbed-like method could be used in all pipelines eventually - I have a design idea to introduce something at the DiffusionPipeline level that would allow us to cache conditions that are re-computed at each denoising step (encoder_hidden_states intermediates, for example) even though they don't have to be. I will try to open a PoC soon that should not be intrusive to any of the existing modeling code but should help make inference speed go brrrr
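A toy sketch of what such a condition cache might look like (entirely hypothetical; this is not an existing diffusers API):

```python
from typing import Any, Callable, Dict


class ConditionCache:
    # Hypothetical per-generation cache for conditions that would otherwise be
    # recomputed at every denoising step even though they are step-invariant.
    def __init__(self):
        self._store: Dict[str, Any] = {}

    def get_or_compute(self, key: str, compute: Callable[[], Any]) -> Any:
        if key not in self._store:
            self._store[key] = compute()
        return self._store[key]

    def clear(self):
        # Call between generations so stale shapes are not reused.
        self._store.clear()


# Inside a denoising loop, a step-invariant condition would be computed once:
# rotary_emb = cache.get_or_compute("rotary_emb", lambda: prepare_rotary_embeds(...))
```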

yiyixuxu 263 days ago

> I have a design idea to introduce something at the DiffusionPipeline level that would allow us to cache conditions that are re-computed at each denoising step (encoder_hidden_states intermediates, for example) even though they don't have to be. I will try to open a PoC soon that should not be intrusive to any of the existing modeling code but should help make inference speed go brrrr

would be interesting to see some experiments on speed improvement if you have time, but:

  1. we tested it before and found the speed difference is very insignificant. That's why we went with the current design
  2. if the difference is indeed significant, we have the option to make all the projection layers a separate module so they are only computed once - so no need for a cache, I think (see the sketch after this comment)
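To illustrate the second point, a minimal sketch of factoring a step-invariant projection into its own module so the pipeline runs it once before the denoising loop (the names here are made up):

```python
import torch
import torch.nn as nn


class TextProjection(nn.Module):
    # Hypothetical standalone projection: running it once per generation,
    # outside the transformer forward, removes the need for a cache.
    def __init__(self, text_dim: int, inner_dim: int):
        super().__init__()
        self.proj = nn.Linear(text_dim, inner_dim)

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(encoder_hidden_states)
```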
a-r-r-o-w 263 days ago

thanks! is there any discussion/PR for reference on this, so I can better understand what's already been tried and set some expectations? I am particularly curious about deeply stacked models like Flux and Cog, which might really benefit from this. I'll try to get some numbers by the end of the week.

a-r-r-o-w 263 days ago

Hmm, so I gave it a shot for Cog. I forgot that the joint text+video attention recomputes the encoder_hidden_states, which is a must at each step, so we can't save any time in the projection layers there. The only place where this is usable is the patch embed projection and the first transformer layer norm, and fwiw I tried it out and saved a mere ~1 second :( I see what you mean about having tested this in the past and found it insignificant.

a-r-r-o-w Merge branch 'main' into cogvideox/pipeline-followups
c464655d
a-r-r-o-w revert pipeline; remove frame limitation
92a2f7e6
a-r-r-o-w revert transformer changes
392f726a
a-r-r-o-w address review comments
431ad60c
a-r-r-o-w add error message
555ed913
a-r-r-o-w requested a review from yiyixuxu 263 days ago
yiyixuxu commented on 2024-08-28 (262 days ago)

I left a comment!

a-r-r-o-w commented on 2024-08-28

yiyixuxu commented on 2024-08-28
a-r-r-o-w apply suggestions from review
b3b9ecca
a-r-r-o-w requested a review from yiyixuxu 257 days ago
a-r-r-o-w Merge branch 'main' into cogvideox/pipeline-followups
bf2907f4
a-r-r-o-w changed the title from [refactor] remove the frame limitation in CogVideoX to [refactor] move positional embeddings to patch embed layer for CogVideoX 257 days ago
yiyixuxu approved these changes on 2024-09-02

yiyixuxu 257 days ago (edited 257 days ago)

thanks for the PR!

a-r-r-o-w Merge branch 'main' into cogvideox/pipeline-followups
b31dcb88
a-r-r-o-w merged 9d49b45b into main 256 days ago
a-r-r-o-w deleted the cogvideox/pipeline-followups branch 256 days ago
