[core] FreeNoise #8948

DN6 merged 40 commits into main from freenoise
a-r-r-o-w
a-r-r-o-w332 days ago (edited 328 days ago)

What does this PR do?

FreeNoise is a training-free ("free lunch") technique for existing short video diffusion models that enables longer video generation with almost no additional inference overhead.

Project: http://haonanqiu.com/projects/FreeNoise.html
Paper: https://arxiv.org/abs/2310.15169
Code: https://github.com/arthur-qiu/FreeNoise-AnimateDiff

Fixes #5576.

Results

All inputs can be found here.

AnimateDiff Text-to-Video

  • context_length=16, context_stride=4, shuffle=False: pipeline_animatediff_freenoise-shuffle_False-context_length_16-context_stride_4.webm
  • context_length=16, context_stride=4, shuffle=True: pipeline_animatediff_freenoise-shuffle_True-context_length_16-context_stride_4.webm
  • context_length=20, context_stride=4, shuffle=True: pipeline_animatediff_freenoise-shuffle_True-context_length_20-context_stride_4.webm
  • context_length=20, context_stride=8, shuffle=True: pipeline_animatediff_freenoise-shuffle_True-context_length_20-context_stride_8.webm
  • context_length=24, context_stride=4, shuffle=True: pipeline_animatediff_freenoise-shuffle_True-context_length_24-context_stride_4.webm
  • context_length=24, context_stride=8, shuffle=True: pipeline_animatediff_freenoise-shuffle_True-context_length_24-context_stride_8.webm
  • num_frames: 64
  • duration: ~50-60s (25 steps)
Code
import torch

from diffusers import AnimateDiffPipeline, DPMSolverMultistepScheduler, AutoencoderKL, MotionAdapter
from diffusers.utils import export_to_video

device = "cuda:0"

# Initialize models and pipeline
motion_adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to(device)
pipe = AnimateDiffPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=motion_adapter, vae=vae, torch_dtype=torch.float16,
).to(device)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, beta_schedule="linear", algorithm_type="dpmsolver++", use_karras_sigmas=True)

num_seconds = 8
fps = 8
num_frames = num_seconds * fps

for context_length in [16, 20, 24]:
    for context_stride in [4, 8]:
        print(f"Processing {context_length=}, {context_stride=}")
        
        # Enable FreeNoise for long video generation
        pipe.enable_free_noise(context_length=context_length, context_stride=context_stride, weighting_scheme="pyramid", shuffle=True)

        # Run inference
        video = pipe(
            prompt="a panda, playing a guitar, sitting in a boat, in the ocean, mountains in background, sunny day, realistic, high quality",
            negative_prompt="bad quality, worst quality",
            num_frames=num_frames,
            num_inference_steps=25,
            guidance_scale=8,
            generator=torch.Generator().manual_seed(1337),
        ).frames[0]

        export_to_video(video, f"animatediff_freenoise/pipeline_animatediff_freenoise-shuffle_True-context_length_{context_length}-context_stride_{context_stride}.mp4", fps=fps)

        # Disable FreeNoise shuffling
        # pipe.disable_free_noise() # optional
        pipe.enable_free_noise(context_length=context_length, context_stride=context_stride, shuffle=False)

        # Run inference
        video = pipe(
            prompt="a panda, playing a guitar, sitting in a boat, in the ocean, mountains in background, sunny day, realistic, high quality",
            negative_prompt="bad quality, worst quality",
            num_frames=num_frames,
            num_inference_steps=25,
            guidance_scale=8,
            generator=torch.Generator().manual_seed(1337),
        ).frames[0]

        export_to_video(video, f"animatediff_freenoise/pipeline_animatediff_freenoise-shuffle_False-context_length_{context_length}-context_stride_{context_stride}.mp4", fps=fps)

AnimateDiff ControlNet

  • context_length=16, context_stride=4: pipeline_animatediff_controlnet_freenoise-shuffle_True-context_length_16-context_stride_4.webm
  • context_length=16, context_stride=8: pipeline_animatediff_controlnet_freenoise-shuffle_True-context_length_16-context_stride_8.webm

Additionally, using the code here:

  • context_length=16, context_stride=4: animatediff_controlnet_long_1.webm
  • context_length=16, context_stride=8: animatediff_controlnet_long_2.webm
  • num_frames: 104
  • duration: ~1m 45s (10 steps, context_length=16, context_stride=4), ~1m 30s (10 steps, context_length=16, context_stride=8)
Code
import torch

from controlnet_aux.processor import LineartAnimeDetector, OpenposeDetector
from diffusers.pipelines.animatediff.pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline
from diffusers import ControlNetModel, LCMScheduler, AutoencoderKL, MotionAdapter
from diffusers.utils import export_to_video, load_video

device = "cuda:1"

# Initialize models and pipeline
controlnet1 = ControlNetModel.from_single_file("/raid/aryan/hub/models--lllyasviel--ControlNet-v1-1/snapshots/69fc48b9cbd98661f6d0288dc59b59a5ccb32a6b/control_v11p_sd15_openpose.pth", torch_dtype=torch.float16).to(device)
controlnet2 = ControlNetModel.from_single_file("/raid/aryan/hub/models--lllyasviel--ControlNet-v1-1/snapshots/69fc48b9cbd98661f6d0288dc59b59a5ccb32a6b/control_v11p_sd15_lineart.pth", torch_dtype=torch.float16)
motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to(device)
pipe = AnimateDiffControlNetPipeline.from_pretrained(
    "stablediffusionapi/darksushimixv225", controlnet=[controlnet1, controlnet2], motion_adapter=motion_adapter, vae=vae, torch_dtype=torch.float16,
).to(device)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")

pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
pipe.set_adapters(["lcm-lora"], [0.8])

# Video credits: https://stable-diffusion-art.com/animatediff-prompt-travel-video2video/
select_nth_frame = 2
video = load_video("https://stable-diffusion-art.com/wp-content/uploads/2023/10/man_dance_2to3_24fps_9s.mp4")[::select_nth_frame]
width = 512
height = 768

# Preprocess video
lineart_processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators").to(device)
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to(device)
conditioning_frames1 = []
conditioning_frames2 = []

with pipe.progress_bar(total=len(video)) as progress_bar:
    for frame in video:
        conditioning_frames1.append(openpose_processor(frame, include_body=True, include_hand=True, include_face=True))
        conditioning_frames2.append(lineart_processor(frame))
        progress_bar.update()

for context_length in [16, 24]:
    for context_stride in [4, 8]:
        print(f"Processing {context_length=}, {context_stride=}")

        # Enable FreeNoise for long video generation
        pipe.enable_free_noise(context_length=context_length, context_stride=context_stride, weighting_scheme="pyramid", shuffle=True)

        # Run inference
        video = pipe(
            prompt="man dancing, blue shirt, red shorts, psychedelic",
            negative_prompt="bad quality, worst quality, jpeg artifacts, ugly",
            conditioning_frames=[conditioning_frames1, conditioning_frames2],
            controlnet_conditioning_scale=[0.5, 0.4],
            width=width,
            height=height,
            num_frames=len(video),
            num_inference_steps=10,
            guidance_scale=2,
            generator=torch.Generator().manual_seed(42),
        ).frames[0]

        export_to_video(video, f"animatediff_freenoise/pipeline_animatediff_controlnet_freenoise-shuffle_True-context_length_{context_length}-context_stride_{context_stride}.mp4", fps=12)

AnimateDiff Video2Video

  • context_length=16, context_stride=4: pipeline_animatediff_vid2vid_freenoise-shuffle_True-context_length_16-context_stride_4.webm
  • context_length=16, context_stride=8: pipeline_animatediff_vid2vid_freenoise-shuffle_True-context_length_16-context_stride_8.webm
  • num_frames: 150
  • duration: ~5m (25 steps, context_length=16, context_stride=4), ~2m 30s (25 steps, context_length=16, context_stride=8)
Code
import torch

from diffusers import AnimateDiffVideoToVideoPipeline, DPMSolverMultistepScheduler, AutoencoderKL, MotionAdapter
from diffusers.utils import export_to_video, load_video

device = "cuda:0"

# Initialize models and pipeline
motion_adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to(device)
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=motion_adapter, vae=vae, torch_dtype=torch.float16,
).to(device)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, beta_schedule="linear", algorithm_type="dpmsolver++", use_karras_sigmas=True)

select_nth_frame = 2
video = load_video("racecar.mp4")[::select_nth_frame]
width = 512
height = 768

for context_length in [16, 24]:
    for context_stride in [4, 8]:
        # Enable FreeNoise for long video generation
        pipe.enable_free_noise(context_length=context_length, context_stride=context_stride, weighting_scheme="pyramid", shuffle=True)

        # Run inference (store the result in a separate variable so the input `video` is reused on every loop iteration)
        output = pipe(
            prompt="racecar, vaporwave style, cyberpunk, intricately detailed, bright colors, 8k resolution, photorealistic, masterpiece, cinematic lighting",
            negative_prompt="bad quality, worst quality, jpeg artifacts, ugly",
            video=video,
            strength=0.6,
            width=width,
            height=height,
            num_inference_steps=25,
            guidance_scale=8.5,
            generator=torch.Generator().manual_seed(42),
        ).frames[0]

        export_to_video(video, f"animatediff_freenoise/pipeline_animatediff_vid2vid_freenoise-shuffle_True-context_length_{context_length}-context_stride_{context_stride}.mp4", fps=24)

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6 @sayakpaul

cc @yiyixuxu as well for library design

a-r-r-o-w initial work draft for freenoise; needs massive cleanup
80e530fb

a-r-r-o-w
a-r-r-o-w332 days ago (edited 332 days ago)

The implementation here gives the following result:

Context length = 16, stride = 8 (two sample videos)

I'm not sure why the last few frames go berserk. I will take a better look soon since we need to rewrite this anyway.

The implementation you see here is horrible. Let me explain why it's done this way at the moment:

  • FreeNoise requires modifying the temporal forward pass: the [bhw, f, c] tensor is split into overlapping windows of shape [bhw, context_length, c], self/cross-attn is performed on each window, and the overlapping frames are then combined by a weighted average (see the sketch after this list).
  • In Diffusers, both the spatial and temporal forward passes are implemented with BasicTransformerBlock. This makes it very challenging to tell whether a given pass is spatial or temporal, which is exactly what frame-wise chunked inference needs to know. Currently, to make FreeNoise work, I determine this with some hardcoded logic that will be removed once we address the design problem.
  • I've retained the forward-pass code from the original implementation to highlight that the changes needed to support FreeNoise are quite minimal. However, it would be difficult to do the same within BasicTransformerBlock.
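
To make the mechanism concrete, here is a minimal sketch of the chunked temporal pass with weighted averaging (my own illustration of the idea, not the code in this PR; `temporal_attention` is a hypothetical stand-in for the block's self/cross-attention, and the windows are assumed to tile num_frames exactly):

import torch

def freenoise_temporal_pass(hidden_states, temporal_attention, context_length=16, context_stride=4):
    # hidden_states: [batch * height * width, num_frames, channels]
    _, num_frames, _ = hidden_states.shape

    # Overlapping windows of `context_length` frames, advanced by `context_stride`
    frame_indices = [(i, i + context_length) for i in range(0, num_frames - context_length + 1, context_stride)]

    # Pyramid weights, e.g. [1, 2, ..., 8, 8, ..., 2, 1] for context_length=16
    half = list(range(1, context_length // 2 + 1))
    pyramid = half + half[::-1] if context_length % 2 == 0 else half + [context_length // 2 + 1] + half[::-1]
    weights = torch.tensor(pyramid, device=hidden_states.device, dtype=hidden_states.dtype).view(1, -1, 1)

    accumulated = torch.zeros_like(hidden_states)
    counts = torch.zeros_like(hidden_states)

    for start, end in frame_indices:
        # Attention runs on each window independently ...
        chunk = temporal_attention(hidden_states[:, start:end])
        # ... and overlapping frames are blended according to their pyramid weights
        accumulated[:, start:end] += chunk * weights
        counts[:, start:end] += weights

    # Weighted average over all windows that cover each frame
    return accumulated / counts
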
Code to replicate
import torch

from diffusers import AnimateDiffPipeline
from diffusers.models import AutoencoderKL, MotionAdapter
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image


model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-2"
vae_id = "stabilityai/sd-vae-ft-mse"
device = "cuda"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)

pipe.enable_free_noise(context_length=16, context_stride=4, shuffle=True)

prompt = "a racoon playing a guitar, sitting in a boat, floating in the ocean, high quality, realistic"
negative_prompt = "bad quality, worst quality"

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=512,
    height=512,
    # num_frames=16,
    num_frames=80, # must be 80 for the current hardcoded logic to work
    num_inference_steps=25,
    guidance_scale=8,
    generator=torch.Generator().manual_seed(1337),
).frames[0]

export_to_gif(video, "animatediff_freenoise.gif")

I would like to hear your thoughts on the following.

  • Is the FreeNoiseMixin implementation a good way to go? We would somehow need to pass the mixin parameters from the pipeline to the model when FreeNoise is enabled. I'm thinking added_cond_kwargs could help here (see the rough sketch after this list).
  • What do you think about a BasicTemporalTransformerBlock? It would be a copy of BasicTransformerBlock containing only the features specific to the AnimateDiff temporal blocks. It will require some rewriting of existing code in a non-breaking manner. Any alternative suggestions are welcome.
  • FreeNoise could potentially be applied to any video model that has a spatial block - temporal block - spatial block - ... pattern (free lunch). We should try it on our models after addressing the previous issues, no?
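
To ground the first question, here's a rough sketch of what a pipeline-level mixin could look like (purely illustrative: walking down_blocks/mid_block/up_blocks matches what the PR does later, but the attribute and method names on the motion-module blocks are assumptions, not the final API):

class FreeNoiseMixin:
    def enable_free_noise(self, context_length: int = 16, context_stride: int = 4, weighting_scheme: str = "pyramid", shuffle: bool = True) -> None:
        # Remember the settings so the pipeline can shuffle latents in prepare_latents
        self._free_noise_context_length = context_length
        self._free_noise_context_stride = context_stride
        self._free_noise_weighting_scheme = weighting_scheme
        self._free_noise_shuffle = shuffle

        # Walk every UNet block that owns motion modules and switch its temporal
        # transformer blocks into chunked (FreeNoise) mode
        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
        for block in blocks:
            for motion_module in getattr(block, "motion_modules", []):
                for transformer_block in motion_module.transformer_blocks:
                    transformer_block.set_free_noise_properties(context_length, context_stride, weighting_scheme)

    @property
    def free_noise_enabled(self) -> bool:
        return getattr(self, "_free_noise_context_length", None) is not None
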
sayakpaul
sayakpaul332 days ago

Just out of curiosity, what is the limit of "long" here? To me, 16 frames is not long enough. Is it fair to expect a minute-long video with FreeNoise? If not, perhaps it might be better to wait, as the added value is not that evident. Of course, I could be looking at it entirely wrong, so happy to stand corrected.

a-r-r-o-w
a-r-r-o-w331 days ago (edited 331 days ago)

Just out of curiosity, what is the limit of "long" here? To me, 16 frames is not long enough. Is it fair to expect a minute-long video with FreeNoise? If not, perhaps it might be better to wait, as the added value is not that evident. Of course, I could be looking at it entirely wrong, so happy to stand corrected.

For text-to-video, I would say 5-30 seconds at 8-12 fps is currently considered "long" enough if the video is consistent. For control-based vid2vid, there is no limit really and you can animate really long videos (1+ mins), especially with Comfy.

The value added may not be evident from the example I showed. Keep in mind that I want to get FreeNoise working first and iron out design/implementation bugs, then apply it to other pipelines. For text-to-video, it is indeed really hard to achieve good videos even with FreeNoise. Where it really shines is naive vid2vid, controlnet vid2vid, and combinations with other tricks, and I think we should really support it because it opens up the potential for many workflows within Diffusers, such as the following. It is widely considered the best open-source method for long generations (in vid2vid at least, from what I've seen). Essentially, all that's needed to support it is chunked frame-wise inference in the BasicTransformerBlock and weighted averaging of latents.

1, 2, 3, 4, 5, 6, 7 and 8 off the top of my saved reddit posts.

I don't have good workflows set up yet (WIP), so I stole this to generate videos with Comfy for the time being. You can find the results here.

  • The FreeNoise generation at 768x432 (AnimateDiff_00003.gif) takes 3 minutes 12 seconds and has a somewhat consistent background.
  • The non-FreeNoise generation at 768x432 (AnimateDiff_00003_nofreenoise.gif) takes 9 minutes; the background keeps changing and there are a few more artifacts on the body. This uses the Context Scheduler approach, for which I have an open PR, but we're still deciding whether that's worth adding in comparison to just FreeNoise.

cc @DN6

sayakpaul
sayakpaul331 days ago

Thanks for explaining.

My take is that we first make it work and convince ourselves that we feel good about the results. We can then work through this PR to reach a design. Before that, I think achieving those results is the right thing to optimize for.

sayakpaul
sayakpaul331 days ago

The FreeNoise generation at 768x432 (AnimateDiff_00003.gif) takes 3 minutes 12 seconds and has a somewhat consistent background.
The non-FreeNoise generation at 768x432 (AnimateDiff_00003_nofreenoise.gif) takes 9 minutes; the background keeps changing and there are a few more artifacts on the body. This uses the Context Scheduler approach, for which I have an open PR, but we're still deciding whether that's worth adding in comparison to just FreeNoise.

Which results are these in the videos you shared?

a-r-r-o-w
a-r-r-o-w331 days ago

Thanks for explaining.

My take is that we first make it work and convince ourselves that we feel good about the results. We can then work through this PR to reach a design. Before that, I think achieving those results is the right thing to optimize for.

I see what you mean. Alright, I'll get it to work with our current community AnimateDiff controlnet implementation (which I really think should now be in core because of how broadly the Comfy equivalent is used) and SparseCtrl once it's merged.

a-r-r-o-w
a-r-r-o-w331 days ago (edited 331 days ago)

Which results are these in the videos you shared?

Check this. AnimateDiff_00003.gif is the FreeNoise version. AnimateDiff_00003-nofreenoise.gif is the non-FreeNoise version. The other files are input and different settings with FreeNoise.

a-r-r-o-w fix freeinit bug
441d3211
a-r-r-o-w add animatediff controlnet implementation
5d0f4c34
a-r-r-o-w Merge branch 'main' into freenoise
2e97ba7c
a-r-r-o-w
a-r-r-o-w331 days ago (edited 331 days ago)

Here are some results to demonstrate the effectiveness of FreeNoise in vid2vid settings:

Code
import requests
from io import BytesIO

import imageio
import torch
from controlnet_aux.processor import LineartDetector, OpenposeDetector
from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, DPMSolverMultistepScheduler, LCMScheduler
from diffusers.pipelines.animatediff.pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline
from diffusers.utils import export_to_gif, export_to_video
from PIL import Image


# model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
model_id = "stablediffusionapi/darksushimixv225"
# model_id = "emilianJR/epiCRealism"
# motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
motion_adapter_id = "wangfuyun/AnimateLCM"
controlnet1_id = "/raid/aryan/hub/models--lllyasviel--ControlNet-v1-1/snapshots/69fc48b9cbd98661f6d0288dc59b59a5ccb32a6b/control_v11p_sd15_openpose.pth"
controlnet2_id = "/raid/aryan/hub/models--lllyasviel--ControlNet-v1-1/snapshots/69fc48b9cbd98661f6d0288dc59b59a5ccb32a6b/control_v11p_sd15_lineart.pth"
vae_id = "stabilityai/sd-vae-ft-mse"
device = "cuda:0"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id)
controlnet1 = ControlNetModel.from_single_file(controlnet1_id, torch_dtype=torch.float16)
controlnet2 = ControlNetModel.from_single_file(controlnet2_id, torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16)
pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=[controlnet1, controlnet2],
    vae=vae,
).to(device=device, dtype=torch.float16)
# pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
#     model_id,
#     subfolder="scheduler",
#     timestep_spacing="linspace",
#     beta_schedule="linear",
#     algorithm_type="dpmsolver++",
#     use_karras_sigmas=True,
#     steps_offset=1,
# )
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")

pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
pipe.set_adapters(["lcm-lora"], [0.8])

def load_video(file_path: str):
    images = []

    if file_path.startswith(('http://', 'https://')):
        # If the file_path is a URL
        response = requests.get(file_path)
        response.raise_for_status()
        content = BytesIO(response.content)
        reader = imageio.get_reader(content)
    else:
        # Assuming it's a local file path
        reader = imageio.get_reader(file_path)

    for frame in reader:
        pil_image = Image.fromarray(frame)
        images.append(pil_image)

    return images


# skip_nth_frame = 1
# max_frames = 16
# video = load_video("input.gif")[::skip_nth_frame][:max_frames]
# width = 512
# height = 768

skip_nth_frame = 2
max_frames = 80
assert max_frames == 80 # Must be 80 for FreeNoise to work because of the hardcoded implementation at the moment
video = load_video("vid2vid_input2.mov")[::skip_nth_frame][:max_frames]
width = 768
height = 432

p1 = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to(device)
p2 = LineartDetector.from_pretrained("lllyasviel/Annotators").to(device)
cn1, cn2 = [], []

with pipe.progress_bar(total=len(video)) as progress_bar:
    for frame in video:
        cn1.append(p1(frame, include_body=True, include_hand=True, include_face=True))
        cn2.append(p2(frame))
        progress_bar.update()

prompt = "girl dancing, blue hair, high quality, surreal"
negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"

# pipe.enable_free_init(use_fast_sampling=True)
# pipe.enable_free_init(use_fast_sampling=False)

pipe.enable_free_noise(context_length=16, context_stride=4, shuffle=True)

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=len(video),
    num_inference_steps=10,
    guidance_scale=2.0,
    conditioning_frames=[cn1, cn2],
    controlnet_conditioning_scale=[0.5, 0.8],
    decode_batch_size=16,
    generator=torch.Generator().manual_seed(42),
).frames[0]

export_to_video(video, "animatediff_controlnet_long_1.mp4", fps=8)

Both videos are generated with FreeNoise (context_length=16, context_stride=4). The left video has FreeInit (a different method) disabled (10 inference steps in total), whereas the right video has it enabled in fast mode (3 + 6 + 10 inference steps).

Viewing the GIF in your browser will show various artifacts, so the mp4 versions, along with the inputs, can be found here. Ignore the last 8 frames in the output - they are caused by an implementation bug that I'll look into soon.

cc @asomoza since we can write a few guides to improve long video generation quality once we have a more stable implementation merged

sayakpaul
sayakpaul331 days ago

Looks pretty good. Good luck with the bug hunting!

a-r-r-o-w Merge branch 'main' into freenoise
690dad69
a-r-r-o-w revert attention changes
610f433d
a-r-r-o-w add freenoise
10b65b31
a-r-r-o-w remove old helper functions
a41f843d
a-r-r-o-w add decode batch size param to all pipelines
f6897ae4
a-r-r-o-w make style
024e2da8
a-r-r-o-w fix copied from comments
1bb09845
a-r-r-o-w make fix-copies
1b7bc007
a-r-r-o-w make style
dc96a8d5
a-r-r-o-w copy animatediff controlnet implementation from #8972
691facfc
a-r-r-o-w add experimental support for num_frames not perfectly fitting context…
5a60a62c
a-r-r-o-w make unet motion model lora work again based on #8995
58c2ddcb
a-r-r-o-w copy load video utils from #8972
70001864
a-r-r-o-w copied from AnimateDiff::prepare_latents
c5db39f8
a-r-r-o-w address the case where last batch of frames does not match length of …
594d2d2c
a-r-r-o-w decode_batch_size->vae_batch_size; batch vae encode support in animat…
fb9ca347
a-r-r-o-w revert sparsectrl and sdxl freenoise changes
77ee296a
a-r-r-o-w revert pia
52884b3e
a-r-r-o-w add freenoise tests
1e2ef4df
a-r-r-o-w marked this pull request as ready for review 328 days ago
a-r-r-o-w requested a review from DN6 328 days ago
a-r-r-o-w requested a review from yiyixuxu 328 days ago
a-r-r-o-w requested a review from sayakpaul 328 days ago
a-r-r-o-w
a-r-r-o-w328 days ago

This PR requires #8972 to be merged so that the AnimateDiffControlNet changes made here (included only so the demo code runs if someone wants to replicate the results) can be removed. While #8979 shouldn't really cause any problems for the inference implementation here, it would be better for that to be merged as well and the merge conflicts handled here.

a-r-r-o-w
a-r-r-o-w commented on 2024-07-28
a-r-r-o-w327 days ago

Left some comments on why certain changes were made and what will get removed as some of the linked PRs are merged.

Big thanks to @DN6 for the design suggestions and for helping me integrate this. It's been long overdue and I can't wait to cook up some good tutorials on long video generation with Diffusers.

src/diffusers/models/attention.py
         attention_out_bias: bool = True,
     ):
         super().__init__()
+        self.dim = dim
a-r-r-o-w327 days ago

These changes were made to initialize the FreeNoiseTransformerBlock correctly. I'm not sure how else we could determine these attributes in a "simple" way without accessing the internal PyTorch dimensions, which adds many extra LOC after make style.

src/diffusers/models/unets/unet_motion_model.py
 from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
-from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
a-r-r-o-w327 days ago

From #8995 to fix LoRA in UNetMotionModel. Once that's merged, these changes shouldn't be visible here

src/diffusers/models/unets/unet_motion_model.py
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
a-r-r-o-w327 days ago

I've made some assumptions here and removed all the branches handling different configurations of layernorms and attention. We support multiple norm_type values in BasicTransformerBlock, but I think the SD1.5 checkpoints used in AnimateDiff always use LayerNorm (I'm only 99% sure). LMK if any other norm types must be handled.

src/diffusers/models/unets/unet_motion_model.py
+        return frame_indices
+
+    def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
a-r-r-o-w327 days ago

The original FreeNoise implementation proposes a pyramid weighted averaging (see Eq. 9 of the paper). However, the diffusion community has found other weighting schemes that also seem to work well in practice. While I haven't tested this deeply, I would like to keep the implementation open to extension in the future. For now, let's roll with the original unless we can test different methods qualitatively before the next release.
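
For reference, here's a quick sketch of how the weighting hook could stay open to extension (illustrative only; the pyramid follows Eq. (9) of the paper, while the "flat" scheme is just my own example of an alternative, not something in this PR):

from typing import List

def get_frame_weights(num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
    if weighting_scheme == "flat":
        # Plain uniform averaging of overlapping windows
        return [1.0] * num_frames
    if weighting_scheme == "pyramid":
        # Weights rise linearly towards the middle of the window and fall back down,
        # e.g. num_frames=6 -> [1, 2, 3, 3, 2, 1] and num_frames=5 -> [1, 2, 3, 2, 1]
        half = list(range(1, num_frames // 2 + 1))
        mid = [] if num_frames % 2 == 0 else [num_frames // 2 + 1]
        return [float(w) for w in half + mid + half[::-1]]
    raise ValueError(f"Unsupported weighting_scheme: {weighting_scheme}")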

src/diffusers/models/unets/unet_motion_model.py
+            window_start = i
+            window_end = min(num_frames, i + self.context_length)
+            frame_indices.append((window_start, window_end))
a-r-r-o-w327 days ago
Suggested change
src/diffusers/models/unets/unet_motion_model.py
+        return weights
+
+    def set_free_noise_properties(
a-r-r-o-w327 days ago

Not sure what to name it; feel free to suggest. It's a helper function to change properties dynamically at inference time, from FreeNoiseMixin, on already-initialized FreeNoiseTransformerBlocks without redoing the entire initialization.

src/diffusers/models/unets/unet_motion_model.py
+        frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
+        is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
+
+        # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
+        # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
+        #     [(0, 16), (4, 20), (8, 24), (10, 26)]
+        if not is_last_frame_batch_complete:
+            if num_frames < self.context_length:
+                raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
+            last_frame_batch_length = num_frames - frame_indices[-1][1]
+            frame_indices.append((num_frames - self.context_length, num_frames))
a-r-r-o-w327 days ago

The original implementation here does not seem to support this case. Essentially, if we select a context_length and context_stride, or have to process a number of frames, such that perfect frame-wise batching is not possible, we still process a full context_length worth of frames BUT only accumulate the results for the otherwise-uncovered frames. I tested it and it works well on cases like num_frames=26, context_length=16, context_stride=4.
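
A small sketch of that window computation, including the fallback (my own illustration, under the assumption that the final window is simply re-anchored to end at num_frames and only its previously uncovered tail is accumulated):

from typing import List, Tuple

def get_frame_indices(num_frames: int, context_length: int, context_stride: int) -> Tuple[List[Tuple[int, int]], int]:
    if num_frames < context_length:
        raise ValueError(f"Expected {num_frames=} to be greater or equal than {context_length=}")

    # Regular overlapping windows
    frame_indices = [(i, i + context_length) for i in range(0, num_frames - context_length + 1, context_stride)]

    # If the windows leave a tail uncovered, append one last full-length window ending at num_frames;
    # during accumulation only its trailing `last_frame_batch_length` frames contribute
    last_frame_batch_length = num_frames - frame_indices[-1][1]
    if last_frame_batch_length > 0:
        frame_indices.append((num_frames - context_length, num_frames))

    return frame_indices, last_frame_batch_length

print(get_frame_indices(26, 16, 4))  # ([(0, 16), (4, 20), (8, 24), (10, 26)], 2)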

src/diffusers/models/unets/unet_motion_model.py
+                )
+                hidden_states_chunk = attn_output + hidden_states_chunk
+
+                if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
+                    accumulated_values[:, -last_frame_batch_length:] += (
+                        hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
+                    )
+                    num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length]
+                else:
+                    accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
+                    num_times_accumulated[:, frame_start:frame_end] += weights
a-r-r-o-w327 days ago

Specialized logic to handle the case where the last frame batch doesn't fit, as described above. LMK if this needs to be more readable, and feel free to suggest improvements.

src/diffusers/pipelines/animatediff/pipeline_animatediff.py
         return ip_adapter_image_embeds

-    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    def decode_latents(self, latents, vae_batch_size: int = 16):
a-r-r-o-w327 days ago

We need to support frame-wise chunking in all intermediate layers, including the ResNet blocks, if we want to save memory; otherwise this blows up. For now, we can process 64 frames on a 24 GB card by taking care of the VAE encode/decode. I suggest we get the FreeNoise functionality in first and take care of optimizing memory in the different internal blocks later.
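
For illustration, a hedged sketch of frame-batched VAE decoding (written against the public AutoencoderKL.decode / config.scaling_factor API; the vae_batch_size handling is how I'd expect the pipeline change to look, not the literal diff):

import torch

def decode_latents_in_batches(vae, latents, vae_batch_size: int = 16):
    # latents: [batch, channels, num_frames, height, width] in latent space
    batch_size, channels, num_frames, height, width = latents.shape
    latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
    latents = latents / vae.config.scaling_factor

    frames = []
    for i in range(0, latents.shape[0], vae_batch_size):
        # Decode only `vae_batch_size` frames at a time to keep peak memory bounded
        frames.append(vae.decode(latents[i : i + vae_batch_size]).sample)
    frames = torch.cat(frames, dim=0)

    # Back to [batch, channels, num_frames, height, width] in pixel space
    return frames.reshape(batch_size, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)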

src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+            # If FreeNoise is enabled, shuffle latents in every window as described in Equation (7) of
+            # [FreeNoise](https://arxiv.org/abs/2310.15169)
+            if self.free_noise_enabled and self._free_noise_shuffle:
a-r-r-o-w327 days ago

I hope this is self-explanatory. I've tried making it as readable as possible but LMK if changes are needed.

DN6327 days ago

When we initialise latents with the shape (batch, num_frames, channels, h, w), each latent is already independent of the others right? How does shuffling within windows help?

If it is needed, I think we can move it to a method inside the AnimateDiffFreeNoiseMixin.

DN6327 days ago

Isn't the shuffling done by creating an init noise tensor of shape (batch, context_size, channels, h, w) and then shuffling these tensors over the full video length?

a-r-r-o-w327 days ago (edited 327 days ago)

Isn't the shuffling done by creating an init noise tensor of shape (batch, context_size, channels, h, w) and then shuffling these tensors over the full video length?

You are absolutely right. I was thinking about rewriting this as well to make the intent clearer.

The original implementation first creates latents of shape (batch_size, num_latent_channels, num_frames, h, w), but then, in the for loop that follows, it fills the remaining frames with shuffled copies of the first context_length frames. I did it the same way because, if FreeNoise is not enabled, we can just return the full-length latents already created above.

Just to elaborate more, assume context_length=16, context_stride=4, num_frames=32:

latents[0:16] <- rand tensor of shape (b, num_channel_latents, context_length, h, w)
latents[16:20] <- shuffle(latents[0:4])
latents[20:24] <- shuffle(latents[4:8])
latents[24:28] <- shuffle(latents[8:12])
latents[28:32] <- shuffle(latents[12:16])
... and so on if there are more. This is what we're doing here as well
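
A compact, self-contained sketch of that window-reuse scheme (my paraphrase of the description above; using torch.randperm for the shuffle and the exact argument names are assumptions, not the pipeline's actual prepare_latents):

import torch

def prepare_freenoise_latents(batch_size, num_channel_latents, num_frames, height, width,
                              context_length=16, context_stride=4, generator=None,
                              device="cpu", dtype=torch.float32):
    # Full-length latents are created first so that, with FreeNoise disabled,
    # they can be returned as-is
    latents = torch.randn(
        (batch_size, num_channel_latents, num_frames, height, width),
        generator=generator, device=device, dtype=dtype,
    )

    # Reuse shuffled copies of the first `context_length` frames for the rest of
    # the video, `context_stride` frames at a time (Eq. (7) of the paper)
    for i in range(context_length, num_frames, context_stride):
        src = torch.arange(i - context_length, i - context_length + context_stride)
        src = src[torch.randperm(context_stride, generator=generator)]
        end = min(i + context_stride, num_frames)
        latents[:, :, i:end] = latents[:, :, src[: end - i]]

    return latents
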
src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
+# Copyright 2024 The HuggingFace Team. All rights reserved.
a-r-r-o-w327 days ago

This file will be removed from here once my animatediff controlnet PR is merged :)

src/diffusers/pipelines/free_noise_utils.py
+            weighting_scheme (`str`, defaults to `4`):
+                TODO(aryan)
+            shuffle (`str`, defaults to `True`):
+                TODO(aryan): decide if this is even needed
a-r-r-o-w327 days ago

Latent shuffling (better explained as reusing context_length latent frames instead of num_frames in the pipeline) is very much required to improve temporal consistency. In my initial pass over the paper, I misunderstood what it meant and implemented it incorrectly. Now it's done correctly and you can see that the text2vid quality has significantly improved.

We can either remove this parameter and always do shuffling, or leave it in for more experimental freedom in different settings. Either is okay with me.

src/diffusers/pipelines/free_noise_utils.py
+        self._free_noise_weighting_scheme = weighting_scheme
+        self._free_noise_shuffle = shuffle
+
+        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
a-r-r-o-w327 days ago

I've not added FreeNoise to AnimateDiff SparseCtrl yet because it fails somewhere in the mid block. It seems that it will require more time, and there are more important things to attend to at the moment, so I propose we roll with this and revisit it in the near future.

src/diffusers/utils/loading_utils.py
     return image


+def load_video(
a-r-r-o-w327 days ago

This will automatically be removed after my animatediff controlnet PR is merged :)

tests/pipelines/animatediff/test_animatediff.py
             "Enabling of FreeInit should lead to results different from the default pipeline results",
         )

+    def test_free_noise_blocks(self):
a-r-r-o-w327 days ago

I've added two tests for FreeNoise based on what I think are the most important parts. LMK if anything else is needed

DN6
DN6 commented on 2024-07-29
src/diffusers/loaders/peft.py
 _SET_ADAPTER_SCALE_FN_MAPPING = {
     "UNet2DConditionModel": _maybe_expand_lora_scales,
     "SD3Transformer2DModel": lambda model_cls, weights: weights,
+    "UNetMotionModel": lambda model_cls, weights: weights,
DN6327 days ago

Pull from upstream once #8995 is merged.

DN6
DN6 commented on 2024-07-29
src/diffusers/pipelines/animatediff/pipeline_animatediff.py
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        vae_batch_size: int = 16,
DN6327 days ago

Let's use naming/logic similar to SVD for batch decoding.

frames = self.decode_latents(latents, num_frames, decode_chunk_size)

a-r-r-o-w327 days ago

This is also used in the VAE encode for animatediff_video2video, by the way, but we can rename it to that.

DN6
DN6 commented on 2024-07-29
src/diffusers/pipelines/free_noise_utils.py
31 r"""Helper function to enable FreeNoise in transformer blocks."""
32
33 for motion_module in block.motion_modules:
34
motion_module: TransformerTemporalModel
DN6327 days ago

This line isn't needed right?

a-r-r-o-w Merge branch 'main' into freenoise
5d5a7ea6
a-r-r-o-w make fix-copies
3d9b1833
a-r-r-o-w improve docstrings
44e40a28
a-r-r-o-w add freenoise tests to animatediff controlnet
a61ffffe
a-r-r-o-w update tests
d82228ea
a-r-r-o-w Update src/diffusers/models/unets/unet_motion_model.py
037ee07d
a-r-r-o-w Merge branch 'main' into freenoise
ac3d8c63
a-r-r-o-w add freenoise to animatediff pag
d19ddb48
a-r-r-o-w address review comments
12cc84a8
a-r-r-o-w make style
6f483562
a-r-r-o-w update tests
1f0ccfdd
a-r-r-o-w make fix-copies
6a4aab8c
a-r-r-o-w a-r-r-o-w requested a review from DN6 DN6 322 days ago
DN6 update
2f77c69c
a-r-r-o-w fix error message
8564dc32
a-r-r-o-w remove copied from comment
b32b1d7f
a-r-r-o-w fix imports in tests
045ae36f
DN6 update
2d9aa42c
DN6
DN6 approved these changes on 2024-08-07
DN6 merged 16a93f1a into main 318 days ago
a-r-r-o-w deleted the freenoise branch 318 days ago
