The implementation here gives the following result:
[Result GIFs: context length = 16, stride = 8]
I'm not sure why the last few frames go berserk. I will take a better look soon since we need to rewrite this anyway.
The implementation you see here is horrible. Let me explain why it's done this way at the moment:

- FreeNoise works by chunking the `[bhw, f, c]` tensors into `k = num_frames / context_length` chunks of shape `[bhw, k, c]`, performing self/cross-attn on each chunk, followed by a weighted averaging of all frames.
- The chunked attention has to happen inside `BasicTransformerBlock`, which is shared between spatial and temporal passes. This makes it very challenging to do frame-wise chunked inference and to determine if a pass is spatial or temporal. Currently, to make FreeNoise work, I determine this with some hardcoded logic which will be removed later once we address the design problem.

```python
import torch
from diffusers import AnimateDiffPipeline
from diffusers.models import AutoencoderKL, MotionAdapter
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-2"
vae_id = "stabilityai/sd-vae-ft-mse"
device = "cuda"
motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
beta_schedule="linear",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
pipe = AnimateDiffPipeline.from_pretrained(
model_id,
motion_adapter=motion_adapter,
vae=vae,
scheduler=scheduler,
torch_dtype=torch.float16,
).to(device)
pipe.enable_free_noise(context_length=16, context_stride=4, shuffle=True)
prompt = "a racoon playing a guitar, sitting in a boat, floating in the ocean, high quality, realistic"
negative_prompt = "bad quality, worst quality"
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=512,
height=512,
# num_frames=16,
num_frames=80, # must be 80 for the current hardcoded logic to work
num_inference_steps=25,
guidance_scale=8,
generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "animatediff_freenoise.gif")
```
I would like to hear your thoughts on the following.
Just out of curiosity, what is the limit of "long" here? To me, 16 is not long enough. Is it fair to expect a minute-long video with FreeNoise? If not, perhaps it might be better to wait, as the value-add is not that evident. Of course, I could be looking at it entirely wrong, so happy to stand corrected.
> Just out of curiosity, what is the limit of "long" here? To me, 16 is not long enough. Is it fair to expect a minute-long video with FreeNoise? If not, perhaps it might be better to wait, as the value-add is not that evident. Of course, I could be looking at it entirely wrong, so happy to stand corrected.
For text-to-video, I would say 5-30 seconds at 8-12 fps is currently considered "long" enough if the video is consistent. For control-based vid2vid, there is no limit really and you can animate really long videos (1+ mins), especially with Comfy.
The value added may not be evident from the example I showed. Keep in mind that I want to get FreeNoise to work first and iron out design/implementation bugs, and then apply it to other pipelines. For text-to-video, it is indeed really hard to achieve good videos even with FreeNoise. The place where it really shines is naive vid2vid, controlnet vid2vid, and combinations with other tricks, and I think we should really support it because it opens up the potential for many workflows, such as the following, within Diffusers. It is widely considered the best open-source method for long generations (in vid2vid at least, from what I've seen). Essentially, all that's needed to support it is chunked frame-wise inference in the BasicTransformerBlock and weighted averaging of latents.
1, 2, 3, 4, 5, 6, 7 and 8 off the top of my saved reddit posts.
I don't have good workflows set up yet (WIP), so I stole this to generate videos with Comfy for the time being. You can find the results here.
cc @DN6
Thanks for explaining.
My take is that we first make it work to convince ourselves about the results we feel good about. We can then work through this PR to reach a design. I think, before that, achieving those results would be a nice thing to optimize for.
The FreeNoise generation at 768x432 (AnimateDiff_00003.gif) takes 3 minutes 12 seconds and has somewhat consistent backgrounds.
The non-FreeNoise generation at 768x432 (AnimateDiff_00003_nofreenoise.gif) takes 9 minutes, and the background keeps changing, with a few more artifacts on the body. This is using the Context Scheduler approach, for which I have an open PR, but we're still deciding if that's worth adding in comparison to just FreeNoise.
Which results are these in the videos you shared?
> Thanks for explaining.
>
> My take is that we first make it work to convince ourselves about the results we feel good about. We can then work through this PR to reach a design. I think, before that, achieving those results would be a nice thing to optimize for.
I see what you mean. Alright, I'll get it to work with our current community AnimateDiff controlnet implementation (which I really think should now be in core because of how broadly the Comfy equivalent is used) and SparseCtrl once it's merged.
> Which results are these in the videos you shared?
Check this. `AnimateDiff_00003.gif` is the FreeNoise version. `AnimateDiff_00003-nofreenoise.gif` is the non-FreeNoise version. The other files are the inputs and different settings with FreeNoise.
Here are some results to demonstrate the effectiveness of FreeNoise in vid2vid settings:
```python
import requests
from io import BytesIO
import imageio
import torch
from controlnet_aux.processor import LineartDetector, OpenposeDetector
from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, DPMSolverMultistepScheduler, LCMScheduler
from diffusers.pipelines.animatediff.pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline
from diffusers.utils import export_to_gif, export_to_video
from PIL import Image
# model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
model_id = "stablediffusionapi/darksushimixv225"
# model_id = "emilianJR/epiCRealism"
# motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
motion_adapter_id = "wangfuyun/AnimateLCM"
controlnet1_id = "/raid/aryan/hub/models--lllyasviel--ControlNet-v1-1/snapshots/69fc48b9cbd98661f6d0288dc59b59a5ccb32a6b/control_v11p_sd15_openpose.pth"
controlnet2_id = "/raid/aryan/hub/models--lllyasviel--ControlNet-v1-1/snapshots/69fc48b9cbd98661f6d0288dc59b59a5ccb32a6b/control_v11p_sd15_lineart.pth"
vae_id = "stabilityai/sd-vae-ft-mse"
device = "cuda:0"
motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id)
controlnet1 = ControlNetModel.from_single_file(controlnet1_id, torch_dtype=torch.float16)
controlnet2 = ControlNetModel.from_single_file(controlnet2_id, torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16)
pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained(
model_id,
motion_adapter=motion_adapter,
controlnet=[controlnet1, controlnet2],
vae=vae,
).to(device=device, dtype=torch.float16)
# pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# timestep_spacing="linspace",
# beta_schedule="linear",
# algorithm_type="dpmsolver++",
# use_karras_sigmas=True,
# steps_offset=1,
# )
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
pipe.set_adapters(["lcm-lora"], [0.8])
def load_video(file_path: str):
    images = []
    if file_path.startswith(('http://', 'https://')):
        # If the file_path is a URL
        response = requests.get(file_path)
        response.raise_for_status()
        content = BytesIO(response.content)
        reader = imageio.get_reader(content)
    else:
        # Assuming it's a local file path
        reader = imageio.get_reader(file_path)
    for frame in reader:
        pil_image = Image.fromarray(frame)
        images.append(pil_image)
    return images
# skip_nth_frame = 1
# max_frames = 16
# video = load_video("input.gif")[::skip_nth_frame][:max_frames]
# width = 512
# height = 768
skip_nth_frame = 2
max_frames = 80
assert max_frames == 80 # Must be 80 for FreeNoise to work because of the hardcoded implementation at the moment
video = load_video("vid2vid_input2.mov")[::skip_nth_frame][:max_frames]
width = 768
height = 432
p1 = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to(device)
p2 = LineartDetector.from_pretrained("lllyasviel/Annotators").to(device)
cn1, cn2 = [], []
with pipe.progress_bar(total=len(video)) as progress_bar:
    for frame in video:
        cn1.append(p1(frame, include_body=True, include_hand=True, include_face=True))
        cn2.append(p2(frame))
        progress_bar.update()
prompt = "girl dancing, blue hair, high quality, surreal"
negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
# pipe.enable_free_init(use_fast_sampling=True)
# pipe.enable_free_init(use_fast_sampling=False)
pipe.enable_free_noise(context_length=16, context_stride=4, shuffle=True)
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
num_frames=len(video),
num_inference_steps=10,
guidance_scale=2.0,
conditioning_frames=[cn1, cn2],
controlnet_conditioning_scale=[0.5, 0.8],
decode_batch_size=16,
generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_video(video, "animatediff_controlnet_long_1.mp4", fps=8)
```
Viewing the GIF in your browser will show various artifacts, so the mp4 versions, along with the inputs, can be found here. Ignore the last 8 frames in the output - they are caused by an implementation bug that I'll look into soon.
cc @asomoza since we can write a few guides to improve long video generation quality once we have a more stable implementation merged
Looks pretty good. Good luck with the bug hunting!
This PR requires #8972 to be merged so that the AnimateDiffControlNet changes, which were only added here so the demo code runs if someone wants to replicate the results, can be removed. While #8979 shouldn't really cause any problems with the implementation here for inference, it would be better for that to be merged as well and to handle the merge conflicts here.
Left some comments on why certain changes were made and what will get removed as some of the linked PRs are merged.
Big thanks to @DN6 for the design suggestions and for helping me integrate this. It's been long overdue, and I can't wait to cook up some good tutorials on long video generation with Diffusers.
```diff
         attention_out_bias: bool = True,
     ):
         super().__init__()
+        self.dim = dim
```
These changes were made to initialize the FreeNoiseTransformerBlock correctly. I'm not sure how else we could determine these attributes in a "simple" way without accessing the internal PyTorch dimensions, which adds many extra LOC after `make style`.
```diff
 from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
-from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
```
From #8995, to fix LoRA in UNetMotionModel. Once that's merged, these changes shouldn't be visible here.
```diff
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
```
I've made some assumptions here and removed all the branches handling different configurations of layernorms and attention. We support multiple `norm_type`s in BasicTransformerBlock. I think the SD15 checkpoints used in AnimateDiff always use LayerNorm, but I am only 99% sure. LMK if any other norm types must be handled.
```diff
+        return frame_indices
+
+    def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
```
The original FreeNoise implementation proposes using a pyramid weighted averaging (see Eq. 9 of the paper). However, the diffusion community has found different weighting schemes that also seem to work well in practice. While I haven't tested this deeply, I would like to keep the implementation open to extension in the future. For now, let's roll with the original unless we can test different methods qualitatively before the next release.
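For concreteness, here is a minimal sketch of what a pyramid weighting could look like (the helper name and exact values are illustrative assumptions, not necessarily what `_get_frame_weights` computes):

```python
from typing import List


def pyramid_frame_weights(num_frames: int) -> List[float]:
    # Weights ramp up linearly to the middle of the window and back down,
    # e.g. num_frames=4 -> [1, 2, 2, 1], num_frames=5 -> [1, 2, 3, 2, 1].
    if num_frames % 2 == 0:
        half = list(range(1, num_frames // 2 + 1))
        weights = half + half[::-1]
    else:
        half = list(range(1, num_frames // 2 + 2))
        weights = half + half[-2::-1]
    return [float(w) for w in weights]
```

Only the relative magnitudes matter here, since the overlapping window outputs are later divided by the accumulated weights.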
```diff
+        return weights
+
+    def set_free_noise_properties(
```
Not sure what to name it, feel free to suggest. It's a helper function, called from FreeNoiseMixin, for dynamically changing properties at inference time on already-initialized FreeNoiseTransformerBlocks, without redoing the entire initialization.
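Conceptually it is just an attribute setter along these lines (a hedged sketch; the attribute names are illustrative):

```python
def set_free_noise_properties(
    self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
) -> None:
    # Update FreeNoise parameters on an already-initialized block without
    # re-creating any of its attention/norm submodules.
    self.context_length = context_length
    self.context_stride = context_stride
    self.weighting_scheme = weighting_scheme
```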
```diff
+        frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
+        is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
+
+        # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
+        # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
+        # [(0, 16), (4, 20), (8, 24), (10, 26)]
+        if not is_last_frame_batch_complete:
+            if num_frames < self.context_length:
+                raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
+            last_frame_batch_length = num_frames - frame_indices[-1][1]
+            frame_indices.append((num_frames - self.context_length, num_frames))
```
The original implementation does not seem to support this case. Essentially, if we select a context_length and context_stride, or have to process a number of frames, such that perfect frame-wise batching is not possible, we still process a full context_length worth of frames BUT only accumulate the results for the previously unaccommodated frames. I tested it, and it works well on cases like num_frames=26, context_length=16, context_stride=4.
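A standalone sketch of how the window ranges could be computed under this scheme (the function name and structure are illustrative, not the exact PR code):

```python
from typing import List, Tuple


def get_frame_indices(num_frames: int, context_length: int, context_stride: int) -> List[Tuple[int, int]]:
    if num_frames < context_length:
        raise ValueError("num_frames must be >= context_length")

    # Sliding windows of `context_length` frames, advancing `context_stride` frames at a time.
    frame_indices = [
        (start, start + context_length)
        for start in range(0, num_frames - context_length + 1, context_stride)
    ]

    # If the last window does not end exactly at num_frames, append one extra
    # full-length window aligned to the end; only the trailing, not-yet-covered
    # frames are accumulated from it.
    if frame_indices[-1][1] != num_frames:
        frame_indices.append((num_frames - context_length, num_frames))
    return frame_indices


print(get_frame_indices(26, 16, 4))  # [(0, 16), (4, 20), (8, 24), (10, 26)]
```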
```diff
+                )
+                hidden_states_chunk = attn_output + hidden_states_chunk
+
+                if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
+                    accumulated_values[:, -last_frame_batch_length:] += (
+                        hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
+                    )
+                    num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length]
+                else:
+                    accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
+                    num_times_accumulated[:, frame_start:frame_end] += weights
```
Specialized logic to handle the non-fitting frame batch case described above. LMK if this needs to be more readable, and feel free to suggest improvements.
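In torch-like pseudocode, the weighted averaging over overlapping windows boils down to something like this (a hedged sketch that omits the end-aligned-window special case shown above):

```python
from typing import List, Tuple

import torch


def weighted_average_windows(
    chunk_outputs: List[torch.Tensor],  # each [batch, context_length, channels]
    frame_indices: List[Tuple[int, int]],
    weights: torch.Tensor,  # [1, context_length, 1]
    num_frames: int,
) -> torch.Tensor:
    batch, _, channels = chunk_outputs[0].shape
    accumulated = torch.zeros(batch, num_frames, channels)
    counts = torch.zeros(batch, num_frames, 1)
    for (start, end), out in zip(frame_indices, chunk_outputs):
        accumulated[:, start:end] += out * weights
        counts[:, start:end] += weights
    # Each frame becomes the weighted mean of every window that covered it.
    return accumulated / counts
```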
```diff
         return ip_adapter_image_embeds

-    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    def decode_latents(self, latents, vae_batch_size: int = 16):
```
We need to support frame-wise chunking in all intermediate layers, including ResNet blocks, if we want to save memory; otherwise this blows up. For now, we can process 64 frames on a 24 GB card by taking care of the VAE encode/decode. I suggest we get the FreeNoise functionality in first and take care of optimizing memory in the different internal blocks later.
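For illustration, the frame-batched decode amounts to something like the following (a sketch assuming an `AutoencoderKL`-style VAE and latents flattened to `[num_frames, channels, h, w]`; not the exact pipeline code):

```python
import torch


def decode_latents_in_batches(vae, latents: torch.Tensor, vae_batch_size: int = 16) -> torch.Tensor:
    # latents: [num_frames, latent_channels, h, w]; decode a few frames at a time
    # so peak memory stays bounded even for long videos.
    latents = latents / vae.config.scaling_factor
    frames = []
    for i in range(0, latents.shape[0], vae_batch_size):
        frames.append(vae.decode(latents[i : i + vae_batch_size]).sample)
    return torch.cat(frames, dim=0)
```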
```diff
+        # If FreeNoise is enabled, shuffle latents in every window as described in Equation (7) of
+        # [FreeNoise](https://arxiv.org/abs/2310.15169)
+        if self.free_noise_enabled and self._free_noise_shuffle:
```
I hope this is self-explanatory. I've tried making it as readable as possible but LMK if changes are needed.
When we initialise latents with the shape (batch, num_frames, channels, h, w), each latent is already independent of the others, right? How does shuffling within windows help?
If it is needed, I think we can move it to a method inside the AnimateDiffFreeNoiseMixin.
Isn't the shuffling done by creating an init noise tensor of shape (batch, context_size, channels, h, w) and then shuffling these tensors over the full video length?
> Isn't the shuffling done by creating an init noise tensor of shape (batch, context_size, channels, h, w) and then shuffling these tensors over the full video length?
You are absolutely right. I was thinking about rewriting this as well to make the intent clearer.
The original implementation first creates latents of shape (batch_size, num_latent_channels, num_frames, h, w), but then, in the for loop that follows, they fill the frames beyond the first `context_length` with (shuffled) values taken from the first `context_length` frames. I did it the same way as them because, if FreeNoise is not enabled, we can just return the full-length latents already created above.
Just to elaborate more, assume context_length=16, context_stride=4, num_frames=32:
latents[0:16] <- rand tensor of shape (b, num_channel_latents, context_length, h, w)
latents[16:20] <- shuffle(latents[0:4])
latents[20:24] <- shuffle(latents[4:8])
latents[24:28] <- shuffle(latents[8:12])
latents[28:32] <- shuffle(latents[12:16])
... and so on if there are more. This is what we're doing here as well
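A minimal sketch of that filling/shuffling scheme, assuming latents of shape `(batch, channels, num_frames, h, w)` (the helper name is illustrative):

```python
import torch


def fill_free_noise_latents(
    latents: torch.Tensor, context_length: int, context_stride: int, generator: torch.Generator = None
) -> torch.Tensor:
    # latents: [batch, channels, num_frames, h, w]. Only the first `context_length`
    # frames are kept as fresh noise; later frames are shuffled copies of earlier windows.
    latents = latents.clone()
    num_frames = latents.shape[2]
    for start in range(context_length, num_frames, context_stride):
        end = min(start + context_stride, num_frames)
        src_start = start - context_length
        src = latents[:, :, src_start : src_start + (end - start)]
        # Shuffle the frame order within the copied window (Equation (7) of FreeNoise)
        perm = torch.randperm(src.shape[2], generator=generator)
        latents[:, :, start:end] = src[:, :, perm]
    return latents
```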
```diff
+# Copyright 2024 The HuggingFace Team. All rights reserved.
```
This file will be removed from here once my animatediff controlnet PR is merged :)
```diff
+            weighting_scheme (`str`, defaults to `4`):
+                TODO(aryan)
+            shuffle (`str`, defaults to `True`):
+                TODO(aryan): decide if this is even needed
```
Latent shuffling (better explained as reusing `context_length` latent frames instead of `num_frames` in the pipeline) is very much required to improve temporal consistency. In my initial pass of the paper, I misunderstood what it meant and implemented it incorrectly. Now it's done correctly, and you can see that text2vid quality has significantly improved.
We can either remove this as a parameter and always do shuffling, or leave it in for more experimental freedom in different settings. Either is okay with me.
```diff
+        self._free_noise_weighting_scheme = weighting_scheme
+        self._free_noise_shuffle = shuffle
+
+        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
```
I've not added FreeNoise to AnimateDiff SparseCtrl yet because it fails somewhere in the mid block. It seems it will require more time, and there are more important things to attend to at the moment, so I propose we roll with this and revisit it in the near future.
```diff
             "Enabling of FreeInit should lead to results different from the default pipeline results",
         )

+    def test_free_noise_blocks(self):
```
I've added two tests for FreeNoise based on what I think are the most important parts. LMK if anything else is needed
```diff
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        vae_batch_size: int = 16,
```
Let's use naming/logic similar to SVD for batch decoding.
This is also used in the VAE encode for animatediff_video2video btw, but we can rename it to that.
What does this PR do?
FreeNoise is a free lunch for existing short-video diffusion models: it allows longer video generation without additional training and with almost no inference overhead.
Project: http://haonanqiu.com/projects/FreeNoise.html
Paper: https://arxiv.org/abs/2310.15169
Code: https://github.com/arthur-qiu/FreeNoise-AnimateDiff
Fixes #5576.
Results
All inputs can be found here.
AnimateDiff Text-to-Video
pipeline_animatediff_freenoise-shuffle_False-context_length_16-context_stride_4.webm
pipeline_animatediff_freenoise-shuffle_True-context_length_16-context_stride_4.webm
pipeline_animatediff_freenoise-shuffle_True-context_length_20-context_stride_4.webm
pipeline_animatediff_freenoise-shuffle_True-context_length_20-context_stride_8.webm
pipeline_animatediff_freenoise-shuffle_True-context_length_24-context_stride_4.webm
pipeline_animatediff_freenoise-shuffle_True-context_length_24-context_stride_8.webm
Code
AnimateDiff ControlNet
pipeline_animatediff_controlnet_freenoise-shuffle_True-context_length_16-context_stride_4.webm
pipeline_animatediff_controlnet_freenoise-shuffle_True-context_length_16-context_stride_8.webm
Additionally, using the code here:
animatediff_controlnet_long_1.webm
animatediff_controlnet_long_2.webm
Code
AnimateDiff Video2Video
pipeline_animatediff_vid2vid_freenoise-shuffle_True-context_length_16-context_stride_4.webm
pipeline_animatediff_vid2vid_freenoise-shuffle_True-context_length_16-context_stride_8.webm
Code
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@DN6 @sayakpaul
cc @yiyixuxu as well for library design