add a `from_pipe` method to `DiffusionPipeline` #7241
Merged

yiyixuxu merged 32 commits into main from from_pipe

yiyixuxu 1 year ago (edited)

motivated by #6531

Create a Stable Diffusion pipeline with `from_pretrained`:

from diffusers import DiffusionPipeline, StableDiffusionSAGPipeline, AnimateDiffPipeline, MotionAdapter, DDIMScheduler
from diffusers.utils import export_to_gif, load_image
import torch
import gc

def flush():
    gc.collect()
    torch.cuda.empty_cache()

def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024

base_repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
num_inference_steps = 50
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
prompt = "bear eating pizza"
negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality"

# test1
print(" ")
print("test1: pipe_sd")
pipe_sd = DiffusionPipeline.from_pretrained(base_repo, torch_dtype=torch.float16)
pipe_sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_sd.set_ip_adapter_scale(0.6)
pipe_sd.to("cuda")

generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
out_sd.save("yiyi_test_4_out_1_sd.png")

flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Max memory allocated: 4.408166408538818 GB

yiyi_test_4_out_1_sd

test2: SD -> SAG

# test2
print(" ")
print("test2: pipe_sd -> pipe_sag")

pipe_sag = StableDiffusionSAGPipeline.from_pipe(
    pipe_sd,
    safety_checker=None,
)
# pipe_sag already has the IP-Adapter loaded (components are shared with pipe_sd)
generator = torch.Generator(device="cpu").manual_seed(33)
out_sag = pipe_sag(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
    guidance_scale=1.0,
    sag_scale=0.75,
).images[0]
out_sag.save("yiyi_test_4_out_2_sag.png")
flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Max memory allocated: 4.408166408538818 GB

yiyi_test_4_out_2_sag

test3: run SD again

# test3
print(" ")
print("test3: run pipe_sd again (should have same output as before)")
generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
out_sd.save("yiyi_test_4_out_3_sd.png")
flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Max memory allocated: 4.408166408538818 GB

yiyi_test_4_out_3_sd

test4: run pipe_sd after pipe_sag.unload_ip_adapter()

# test4
print("")
print(f" test4: run pipe_sd after unload ip_adapter from pipe_sag (should get an error)")
pipe_sag.unload_ip_adapter()
try:
    generator = torch.Generator(device="cpu").manual_seed(33)
    out_sd = pipe_sd(
        prompt=prompt,
        negative_prompt=negative_prompt,
        ip_adapter_image=image,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]
except Exception as e:
    print(f"error: {e}")
error: 'NoneType' object has no attribute 'image_projection_layers'

test5: SD -> AnimateDiff

# test5
print(" ")
print("test5: pipe_sd -> pipe_animate")

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)

pipe_animate = AnimateDiffPipeline.from_pipe(pipe_sd, motion_adapter=adapter)
pipe_animate.scheduler = DDIMScheduler.from_config(pipe_animate.scheduler.config, beta_schedule="linear")
# load the IP-Adapter again and load the LoRA weights
pipe_animate.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_animate.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
pipe_animate.to("cuda")

generator = torch.Generator(device="cpu").manual_seed(33)
pipe_animate.set_adapters("zoom-out", adapter_weights=0.75)
out = pipe_animate(
    prompt=prompt,
    num_frames=16,
    num_inference_steps=num_inference_steps,
    ip_adapter_image=image,
    generator=generator,
).frames[0]

export_to_gif(out, "yiyi_test_4_out_5_animate.gif")
flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Max memory allocated: 15.185057640075684 GB

yiyi_test_4_out_5_animate

test6: SD -> LPW

# test6
print(" ")
print("test6: pipe_sd -> pipe_lpw (community pipeline)")

pipe_lpw = DiffusionPipeline.from_pipe(
    pipe_sd,
    custom_pipeline="lpw_stable_diffusion",
).to("cuda")

prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms"
neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry"
generator = torch.Generator(device="cpu").manual_seed(33)
out_lpw = pipe_lpw.text2img(
    prompt,
    negative_prompt=neg_prompt,
    width=512,
    height=512,
    max_embeddings_multiples=3,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
out_lpw.save("yiyi_test_4_out_6_lpw.png")

flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Max memory allocated: 15.185057640075684 GB

yiyi_test_4_out_6_lpw

test7: run SD again

# test7
print(" ")
print("test7: pipe_sd")
generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=generator,
    num_inference_steps=num_inference_steps,
).images[0]
out_sd.save("yiyi_test_4_out_7_sd.png")
flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Max memory allocated: 15.185057640075684 GB

yiyi_test_4_out_7_sd

add from_pipe
99cdecc9
up
f664d171
up
daab15dc
style
16392631
yiyixuxu 1 year ago

cc @vladmandic here
still WIP, but let me know what you think about the API and the use cases it covers.
do you have any other specific use cases in mind that I did not cover here?

yiyixuxu commented on 2024-03-06
src/diffusers/models/unets/unet_motion_model.py

    # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
-   config = unet.config
+   config = dict(unet.config)

yiyixuxu 1 year ago

@DN6
currently we would modify the original 2D UNet's config - even though we do not use it here, we create a new UNetMotionModel instead
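
For illustration, a minimal sketch of the aliasing issue the dict() copy avoids (the edited key is just an example, not the actual code):

config = unet.config                        # aliases the UNet's own config object
config["_class_name"] = "UNetMotionModel"   # sketch: this edit would leak back into unet.config

config = dict(unet.config)                  # shallow copy: local edits stay local
config["_class_name"] = "UNetMotionModel"   # unet.config is untouched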

HuggingFaceDocBuilderDev 1 year ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

yiyixuxu commented on 2024-03-06
src/diffusers/pipelines/pipeline_utils.py

+   if hasattr(pipeline, "_all_hooks") and len(pipeline._all_hooks) > 0:

yiyixuxu 1 year ago

this is the part where we make sure that if enable_model_cpu_offload was previously called on the pipeline, it will still work properly with from_pipe
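
For illustration, a rough sketch of the idea (simplified, not the exact code in this PR):

import accelerate
import torch

# strip accelerate's offload hooks from every model component so the
# components can be shared safely with the new pipeline
for _, model in pipeline.components.items():
    if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
        accelerate.hooks.remove_hook_from_module(model, recurse=True)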

vladmandic 1 year ago

thanks yiyi!
from a high level, it seems to cover all the main use cases. if there is something borderline, we can think about that later.
my two comments are:

  • target pipeline should inherit pipeline settings, not just components. e.g. model_cpu_offload and all other enable_*() methods should be applied on the target if they were applied on the source.
  • testing: make sure that the target pipeline actually works when there are additional components added to it (e.g. a different vae or pretty much anything). in my experience, this is where accelerate often breaks, as it doesn't pull model components in time and you end up with a RuntimeError (CUDA vs CPU).
yiyixuxu Merge branch 'main' into from_pipe
29cc455a
yiyixuxu add test
a94ef035
yiyixuxu fix tests
0b9d579e
yiyixuxu add a test for offload
507dc3cc
yiyixuxu fix tests
afd53542
yiyixuxu fix
e17504a7
yiyixuxu make get_signature_types a class method + find out which pipelines ca…
20829e0a
yiyixuxu style
6f977587
yiyixuxu up
f7220063
yiyixuxu up
cac53a0e
yiyixuxu fix tests!
ada8d042
yiyixuxu update test name
02f0527b
yiyixuxu Merge branch 'main' into from_pipe
e4a785c6
yiyixuxu separate out the cpu offload test
312cce40
yiyixuxu Merge branch 'from_pipe' of github.com:huggingface/diffusers into fro…
a496e8c2
yiyixuxu 1 year ago

thanks for the feedback! @vladmandic

> to make sure that target pipeline actually works when there are additional components added to it

I have not run into any issues using enable_model_cpu_offload with additional components, and my tests are pretty extensive (we added a fast test covering all diffusers official pipelines that can use from_pipe). I think it is not a concern here because I remove all the hooks and reset the offload device at the beginning.

> target pipeline should inherit pipeline settings

I'm not so sure about this because:

  1. We allow adding and subtracting components with the from_pipe API, so the new pipeline may have different memory requirements, and the user may want different settings. I think it would be simpler to reset instead of inheriting the settings, unless users always want the same settings for the new pipeline.
  2. Not every pipeline has implemented all of these methods; e.g. in my testing, the LPW pipeline did not have enable_model_cpu_offload working correctly. This would more likely be an issue with community pipelines.
  3. I agree it is less convenient if you have to re-apply settings, but I don't think it makes too much difference.

with this being said, I think it won't be hard to implement, and I'm open to it if you all think it's more intuitive and convenient to let the new pipelines inherit settings. cc @pcuenca here too, let me know what you think!

yiyixuxu requested a review from DN6 1 year ago
yiyixuxu requested a review from sayakpaul 1 year ago
yiyixuxu 1 year ago

cc @DN6 @sayakpaul for a final review
let me know what you think about this #7241 (comment) too
I'm slightly in favor of resetting the pipeline settings but I don't feel strongly either way

vladmandic 1 year ago

thanks @yiyixuxu

re: pipeline settings inheritance - IMO it would be more convenient and expected, since it's a pipeline switch using loaded model components (all or some), but it's not a deal breaker - from_pipe has massive value either way.

yiyixuxu 1 year ago (edited)

actually, now I think most of the enable_* methods make stateful changes to the model components, and these changes are already naturally carried over to the new pipeline (e.g. the ones on StableDiffusionMixin), so we should probably make it consistent with enable_model_cpu_offload

on the other hand, these methods may not work properly with the potential addition or override of new components
e.g.

if we have enable_vae_slicing enabled on the pipeline, and create a new pipeline with a new vae component, it won't work:

pipe1.enable_vae_slicing()
vae = AutoencoderKL.from_pretrained(...)
pipe2 = NewPipelineClass.from_pipe(pipe1, vae=vae)

should we handle this on our end or let the user address it? if they just re-apply the settings, it would work as expected; if we are going to handle it on our end, it would be pretty complicated, I think
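
For illustration, the re-apply workaround on the snippet above would just be (sketch):

pipe2.enable_vae_slicing()  # the new vae picks up slicing only after re-applying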

sayakpaul commented on 2024-03-22

Very nice PR. Let's document this more formally, no?

src/diffusers/pipelines/pipeline_loading_utils.py

    return class_obj, class_candidates

+   def _get_custom_pipeline_class(

sayakpaul 1 year ago

Very nice cleanup here!

src/diffusers/pipelines/pipeline_utils.py

    elif get_origin(v.annotation) == Union:
        signature_types[k] = get_args(v.annotation)
    else:
+       logger.warn(f"cannot get type annotation for Parameter {k} of {cls}.")

sayakpaul 1 year ago

logger.warning() please.

src/diffusers/pipelines/pipeline_utils.py

    @classmethod
    def from_pipe(cls, pipeline: "DiffusionPipeline", **kwargs) -> "DiffusionPipeline":
        r"""
+       Create a new pipeline from a given pipeline. This method is useful to create a new pipeline with the same
+       weights and configurations without reallocating additional memory.

sayakpaul 1 year ago
Suggested change:
-       Create a new pipeline from a given pipeline. This method is useful to create a new pipeline with the same
-       weights and configurations without reallocating additional memory.
+       Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing pipeline components without reallocating additional memory.
src/diffusers/pipelines/pipeline_utils.py

    pipeline_is_offloaded = True if hasattr(pipeline, "_all_hooks") and len(pipeline._all_hooks) > 0 else False
    if pipeline_is_offloaded:
+       # `enable_model_cpu_offload` has be called on the pipeline, offload model and remove hook from model

sayakpaul 1 year ago
Suggested change:
-       # `enable_model_cpu_offload` has be called on the pipeline, offload model and remove hook from model
+       # `enable_model_cpu_offload` has been called on the pipeline, offload model and remove hooks from model

sayakpaul 1 year ago

Why is this step needed?

yiyixuxu 1 year ago

otherwise, if you call enable_model_cpu_offload again on the new pipeline, it won't work correctly.
however, this is no longer needed now that we have PR #7448.
I can remove this now :)

tests/pipelines/test_pipelines_common.py

    ## for pipelines such as animtediff and pia, the unet changes from UNet2DConditionModel to UNetMotionModel during __init__
    changed_components = {}
    original_signature_types = original_pipeline_class._get_signature_types()
+   for k, v in pipe2.components.items():

sayakpaul 1 year ago
Suggested change:
-   for k, v in pipe2.components.items():
+   for name, component in pipe2.components.items():
tests/pipelines/test_pipelines_common.py

    # components expected in the original pipeline
    original_expected_modules, _ = original_pipeline_class._get_signature_keys(original_pipeline_class)

+   for k, v in components.items():

sayakpaul 1 year ago
Suggested change:
-   for k, v in components.items():
+   for name, component in components.items():
tests/pipelines/test_pipelines_common.py

    if k in original_expected_modules:
        original_pipe_components[k] = v
    else:
+       additional_components[k] = v

sayakpaul 1 year ago

Could be nice to denote this as current_pipe_additional_components.

tests/pipelines/test_pipelines_common.py

    changed_components = {}
    original_signature_types = original_pipeline_class._get_signature_types()
    for k, v in pipe2.components.items():
+       if (
+           k in original_expected_modules
+           and v is not None
+           and isinstance(v, torch.nn.Module)
+           and type(v) not in original_signature_types[k]
+       ):
+           changed_components[k] = original_pipe_components[k]

sayakpaul 1 year ago

I am really having trouble reading this condition.

  • Why should the original_pipe_additional_components go into initializing a different pipeline that might not share the same additional components?

sayakpaul 1 year ago

Maybe fully spelling out what k and v are denoting will be helpful.

yiyixuxu 1 year ago (edited)

> Why should the original_pipe_additional_components go into initializing a different pipeline that might not share the same additional components?

yeah, I know - not the most straightforward code to read here
so we only have two pipelines here: the original pipeline would be something like SD, and the current pipeline would be something like SAG

Basically, in the test we do something like this:

    SD (`pipe_original`) -> SAG (`pipe2`) -> SD (`pipe_original_2`)

so you can see that pipe_original_2 and pipe_original are actually the same pipeline class. original_pipe_additional_components would be the components in pipe_original that are not accepted in pipe2, e.g. image_encoder (I'm just making it up here); naturally they go back in when we instantiate pipe_original_2, since it is the same pipeline class.

maybe I should give them better names!
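
For illustration, a condensed sketch of that round trip (hypothetical, not the actual test code):

# SD -> SAG -> SD: components SAG does not accept are carried back at the end
pipe_original = StableDiffusionPipeline(**components)
pipe2 = StableDiffusionSAGPipeline.from_pipe(pipe_original)
pipe_original_2 = StableDiffusionPipeline.from_pipe(
    pipe2, **original_pipe_additional_components
)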

tests/pipelines/test_pipelines_common.py

    inputs = get_dummy_inputs_pipe_original(torch_device)
    output_original2 = pipe_original_2(**inputs)[0]

+   assert pipe1_config == pipe2_config
+   assert pipe_original_2_config == original_config

sayakpaul 1 year ago

Could be a little stricter to iterate through the keys and make sure their corresponding values match.
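
For illustration, one way to apply that suggestion (a sketch, not the committed test code):

# compare per key so a failure reports exactly which entry differs
for key, value in original_config.items():
    assert pipe_original_2_config[key] == value, f"config mismatch for {key}"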

vladmandic 1 year ago

> actually, now I think most of the enable_* methods make stateful changes to the model components, and these changes are already naturally carried over to the new pipeline, so we probably should make it consistent with enable_model_cpu_offload
> on the other hand, these methods may not work properly with the potential addition or override of new components
> e.g. if we have enable_vae_slicing enabled on the pipeline, and create a new pipeline with a new vae component, it won't work

that is exactly the problem with model offload compatibility i was referring to earlier in the conversation.
for example, a very common use case is to load feature_extractor and image_encoder once they are needed by ipadapter.
and then we end up with a split-brain pipe: parts want to do offloading, parts do not.

even worse, what if i want to now load a different base model and just reuse those two previously loaded components so i don't have to load them again (they are not tiny by any means)?

if we're not going to handle those internally, then at least we need the opposite of enable_model_cpu_offload to force-disable it before loading components, and then we call enable_model_cpu_offload again once the pipeline is reconstructed.
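
For illustration, a hypothetical sketch of the reuse described above (the repo id and pipe_old are made up):

# reuse the already-loaded IP-Adapter components with a different base model
pipe_new = DiffusionPipeline.from_pretrained(
    "some/other-base-model",                       # hypothetical repo id
    feature_extractor=pipe_old.feature_extractor,  # reuse instead of re-loading
    image_encoder=pipe_old.image_encoder,
    torch_dtype=torch.float16,
)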

yiyixuxu only check type for modules
af4f18f7
yiyixuxu 1 year ago

Thanks @vladmandic! I didn't think it through before.
I updated the model offload methods in a separate PR here. So now, both enable_model_cpu_offload and enable_sequential_cpu_offload will work properly when you re-apply them, either from the same pipeline or from a different one.

now all these enable_* methods have consistent behavior in from_pipe: we do not do anything to guarantee these pipeline settings are inherited from the previous pipeline; they may not work as-is, but if you re-apply them, they will work correctly on the new pipeline
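
For illustration, the re-apply pattern on the pipelines from the tests above (sketch):

# offload settings are not inherited by from_pipe; re-apply them on the new pipeline
pipe_sd.enable_model_cpu_offload()
pipe_sag = StableDiffusionSAGPipeline.from_pipe(pipe_sd)
pipe_sag.enable_model_cpu_offload()  # works correctly after re-applying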

I'm open to inheriting the settings too! let me know what you all think too @sayakpaul @pcuenca @DN6

vladmandic 1 year ago

thanks - if re-apply now works, that's imo sufficient.

DN6 commented on 2024-03-26
src/diffusers/pipelines/pipeline_utils.py

    passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

    original_class_obj = {}
+   for k, v in pipeline.components.items():

DN6 1 year ago
Suggested change:
-   for k, v in pipeline.components.items():
+   for name, component in pipeline.components.items():
yiyixuxu Merge branch 'main' into from_pipe
7c9bfbe9
yiyixuxu commented on 2024-03-29
src/diffusers/pipelines/pipeline_utils.py

    elif get_origin(v.annotation) == Union:
        signature_types[k] = get_args(v.annotation)
    else:
+       logger.warn(f"cannot get type annotation for Parameter {k} of {cls}.")

yiyixuxu 1 year ago
Suggested change:
-       logger.warn(f"cannot get type annotation for Parameter {k} of {cls}.")
+       logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
yiyixuxu commented on 2024-03-29
tests/pipelines/test_pipelines_common.py

    original_expected_modules, _ = original_pipeline_class._get_signature_keys(original_pipeline_class)

    for k, v in components.items():
+       if k in original_expected_modules:

yiyixuxu 1 year ago
Suggested change:
-       if k in original_expected_modules:
+       if name in original_expected_modules:
yiyixuxu commented on 2024-03-29
tests/pipelines/test_pipelines_common.py

    for k, v in components.items():
        if k in original_expected_modules:
+           original_pipe_components[k] = v

yiyixuxu 1 year ago
Suggested change:
-           original_pipe_components[k] = v
+           original_pipe_components[name] = component
yiyixuxu commented on 2024-03-29
tests/pipelines/test_pipelines_common.py

        if k in original_expected_modules:
            original_pipe_components[k] = v
        else:
+           additional_components[k] = v

yiyixuxu 1 year ago
Suggested change:
-           additional_components[k] = v
+           current_pipe_additional_components[name] = component
yiyixuxu commented on 2024-03-29
tests/pipelines/test_pipelines_common.py

    # additional components that are not in the pipeline, but expected in the original pipeline
    original_pipe_additional_components = {}
    # additional components that are in the pipeline, but not expected in the original pipeline
+   additional_components = {}

yiyixuxu 1 year ago
Suggested change:
-   additional_components = {}
+   current_pipe_additional_components = {}
yiyixuxu commented on 2024-03-29
tests/pipelines/test_pipelines_common.py

    output1 = pipe1(**inputs)[0]

    # pipe2 (created with `from_pipe`): original pipeline(sd/sdxl) -> pipeline
+   pipe2 = self.pipeline_class.from_pipe(pipe_original, **additional_components)

yiyixuxu 1 year ago
Suggested change:
-   pipe2 = self.pipeline_class.from_pipe(pipe_original, **additional_components)
+   pipe2 = self.pipeline_class.from_pipe(pipe_original, **current_pipe_additional_components)
yiyixuxu commented on 2024-03-29
tests/pipelines/test_pipelines_common.py

    original_signature_types = original_pipeline_class._get_signature_types()
    for k, v in pipe2.components.items():
        if (
+           k in original_expected_modules

yiyixuxu 1 year ago
Suggested change:
-           k in original_expected_modules
+           name in original_expected_modules
yiyixuxu commented on 2024-03-29
tests/pipelines/test_pipelines_common.py

    for k, v in pipe2.components.items():
        if (
            k in original_expected_modules
+           and v is not None

yiyixuxu 1 year ago
Suggested change:
-           and v is not None
+           and component is not None
yiyixuxu commented on 2024-03-29
tests/pipelines/test_pipelines_common.py

        if (
            k in original_expected_modules
            and v is not None
+           and isinstance(v, torch.nn.Module)

yiyixuxu 1 year ago
Suggested change:
-           and isinstance(v, torch.nn.Module)
+           and isinstance(component, torch.nn.Module)
yiyixuxu commented on 2024-03-29
tests/pipelines/test_pipelines_common.py

            k in original_expected_modules
            and v is not None
            and isinstance(v, torch.nn.Module)
+           and type(v) not in original_signature_types[k]

yiyixuxu 1 year ago
Suggested change:
-           and type(v) not in original_signature_types[k]
+           and type(component) not in original_signature_types[k]
yiyixuxu commented on 2024-03-29
tests/pipelines/test_pipelines_common.py

            and isinstance(v, torch.nn.Module)
            and type(v) not in original_signature_types[k]
        ):
+           changed_components[k] = original_pipe_components[k]

yiyixuxu 1 year ago
Suggested change:
-           changed_components[k] = original_pipe_components[k]
+           changed_components[name] = original_pipe_components[name]
yiyixuxu commented on 2024-03-29
src/diffusers/pipelines/pipeline_utils.py

    original_class_obj = {}
    for k, v in pipeline.components.items():
+       if k in expected_modules and k not in passed_class_obj:

yiyixuxu 1 year ago (edited)
Suggested change:
-       if k in expected_modules and k not in passed_class_obj:
+       if name in expected_modules and name not in passed_class_obj:
yiyixuxu commented on 2024-03-29
src/diffusers/pipelines/pipeline_utils.py

    for k, v in pipeline.components.items():
        if k in expected_modules and k not in passed_class_obj:
            if (
+               not isinstance(v, torch.nn.Module)

yiyixuxu 1 year ago
Suggested change:
-               not isinstance(v, torch.nn.Module)
+               not isinstance(component, torch.nn.Module)
yiyixuxu commented on 2024-03-29
src/diffusers/pipelines/pipeline_utils.py

        if k in expected_modules and k not in passed_class_obj:
            if (
                not isinstance(v, torch.nn.Module)
+               or type(v) in signature_types[k]

yiyixuxu 1 year ago
Suggested change:
-               or type(v) in signature_types[k]
+               or type(component) in signature_types[name]
yiyixuxu commented on 2024-03-29
src/diffusers/pipelines/pipeline_utils.py

            if (
                not isinstance(v, torch.nn.Module)
                or type(v) in signature_types[k]
+               or (v is None and k in cls._optional_components)

yiyixuxu 1 year ago
Suggested change:
-               or (v is None and k in cls._optional_components)
+               or (component is None and name in cls._optional_components)
yiyixuxu commented on 2024-03-29
src/diffusers/pipelines/pipeline_utils.py

                or type(v) in signature_types[k]
                or (v is None and k in cls._optional_components)
            ):
+               original_class_obj[k] = v

yiyixuxu 1 year ago
Suggested change:
-               original_class_obj[k] = v
+               original_class_obj[name] = component
yiyixuxu commented on 2024-03-29
src/diffusers/pipelines/pipeline_utils.py

                original_class_obj[k] = v
            else:
                logger.warn(
+                   f"component {k} is not switched over to new pipeline because type does not match the expected."
+                   f" {k} is {type(v)} while the new pipeline expect {signature_types[k]}."
+                   f" please pass the correct type of component to the new pipeline. `from_pipe(..., {k}={k})`"

yiyixuxu 1 year ago
Suggested change:
-                   f"component {k} is not switched over to new pipeline because type does not match the expected."
-                   f" {k} is {type(v)} while the new pipeline expect {signature_types[k]}."
-                   f" please pass the correct type of component to the new pipeline. `from_pipe(..., {k}={k})`"
+                   f"component {name} is not switched over to new pipeline because type does not match the expected."
+                   f" {name} is {type(component)} while the new pipeline expect {signature_types[name]}."
+                   f" please pass the correct type of component to the new pipeline. `from_pipe(..., {name}={name})`"
yiyixuxu Apply suggestions from code review
28cf55b5
yiyixuxu up
f21b7384
yiyixuxu up
6660203e
yiyixuxu Merge branch 'main' into from_pipe
b3fb32ed
yiyixuxu made a texter mixin for from_pipe
154bdbca
yiyixuxu fix tests: make sure pipelines does not make stateful changes to the …
e5010446
yiyixuxu add doc
f2e2ffe5
DN6 commented on 2024-04-01
src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
    self.maybe_free_model_hooks()
+   # make sure to set the original attention processors back
+   self.unet.set_attn_processor(original_attn_proc)

DN6 1 year ago (edited)

Are the attn processors changed here? Doesn't seem like it?

DN6 1 year ago

Oh nvm. It's done through self.register_attention_control()

DN6 commented on 2024-04-01
src/diffusers/pipelines/pipeline_utils.py

    if name in expected_modules and name not in passed_class_obj:
        # for model components, we will not switch over if the class does not matches the type hint in the new pipeline's signature
        if (
+           not isinstance(component, torch.nn.Module)

DN6 1 year ago (edited)

Maybe change this to see if it subclasses ModelMixin here?

DN6 approved these changes on 2024-04-01
DN6 1 year ago

LGTM 👍🏽

yiyixuxu commented on 2024-04-01
src/diffusers/pipelines/pipeline_utils.py

    if name in expected_modules and name not in passed_class_obj:
        # for model components, we will not switch over if the class does not matches the type hint in the new pipeline's signature
        if (
+           not isinstance(component, torch.nn.Module)

yiyixuxu 1 year ago
Suggested change:
-           not isinstance(component, torch.nn.Module)
+           not isinstance(component, ModelMixin)
yiyixuxu Update src/diffusers/pipelines/pipeline_utils.py
26535f66
yiyixuxu up
3bf58a43
yiyixuxu move guide
5e393600
yiyixuxu Merge branch 'main' into from_pipe
e088ac36
yiyixuxu merged 7956c36a into main 1 year ago
yiyixuxu deleted the from_pipe branch 1 year ago
