diffusers
[ip-adapter] refactor `prepare_ip_adapter_image_embeds` and skip load `image_encoder`
#7016
Merged

[ip-adapter] refactor `prepare_ip_adapter_image_embeds` and skip load `image_encoder` #7016

yiyixuxu merged 29 commits into main from ip-adapter-no-image-encoder
yiyixuxu
yiyixuxu1 year ago (edited 1 year ago)

fix #6925

allow pass ip_adapter_image_embeds directly and skip loading image_encoder

import torch
from diffusers import StableDiffusionXLPipeline, AutoencoderKL, DPMSolverMultistepScheduler
from diffusers.utils.testing_utils import load_pt

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
).to("cuda")

pipeline = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    vae=vae
).to('cuda')

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.scheduler.config.use_karras_sigmas = True


pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl_vit-h.safetensors",
    image_encoder_folder=None,
)
pipeline.set_ip_adapter_scale(0.6)

print(f" pipeline.image_encoder: {pipeline.image_encoder}")

prompt = "a horse, highly detailed, 4k, professional"
negative_prompt="blurry"

# diffusers embeds [(2,2,1024)], cfg 
image_embeds =  load_pt("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/diffusers_style_test.ipadpt")
generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
    prompt=prompt,
    ip_adapter_image_embeds=image_embeds,
    negative_prompt=negative_prompt,
    guidance_scale=7.5,
    num_inference_steps=20,
    num_images_per_prompt=2,
    generator=generator,
).images

image[0].save("yiyi_test_4_out_diffusers_cfg.png")

# diffusers embeds, no cfg 
image_embeds_no_cfg = [single_image_embeds[1] for single_image_embeds in image_embeds]
generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
    prompt=prompt,
    ip_adapter_image_embeds=image_embeds_no_cfg,
    negative_prompt=negative_prompt,
    guidance_scale=0,
    num_inference_steps=20,
    num_images_per_prompt=2,
    generator=generator,
).images

image[0].save("yiyi_test_4_out_diffusers_no_cfg.png")


# comfyui embeds (2,2,1024), cfg 
image_embeds =  load_pt("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/comfyui_style_test.ipadpt").to(device="cuda", dtype=torch.float16)
image_embeds, negative_image_embeds = image_embeds.chunk(2)

generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
    prompt=prompt,
    ip_adapter_image_embeds=[torch.cat([negative_image_embeds, image_embeds], dim=0)],
    negative_prompt=negative_prompt,
    guidance_scale=7.5,
    num_inference_steps=20,
    num_images_per_prompt=2,
    generator=generator,
).images

image[0].save("yiyi_test_4_out_comfy_cfg.png")

# comfyui embeds, no cfg 
generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
    prompt=prompt,
    ip_adapter_image_embeds=[image_embeds],
    negative_prompt=negative_prompt,
    guidance_scale=0,
    num_inference_steps=20,
    num_images_per_prompt=2,
    generator=generator,
).images

image[0].save("yiyi_test_4_out_comfy_no_cfg.png")

we also allow specifying a different subfolder for image_encoder_folder

e.g. if the ip-adapter checkpoint needs to use image_encoder that's not from the same subfolder, you do not need to load it explicitly as described here in the doc https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters#ip-adapter-plus you can specify the image_encoder_folder instead

from diffusers import AutoPipelineForText2Image
import torch
from diffusers.utils import load_image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)

pipeline.load_ip_adapter(
    "h94/IP-Adapter", 
    subfolder="sdxl_models", 
    weight_name="ip-adapter-plus_sdxl_vit-h.safetensors", 
    image_encoder_folder="models/image_encoder")
pipeline.enable_model_cpu_offload()


image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt='best quality, high quality, wearing sunglasses', 
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=50,
    generator=generator,
).images[0]
images.save("yiyi_test_out.png")
add
d6825d67
up
b7e82f4f
merge
0082b2ef
up
98552b96
add
7e969c9f
up
45018c65
yiyixuxu yiyixuxu changed the title [ip-adapter] allow pass `ip_hidden_states` directly and skip load image_encoder [ip-adapter] allow pass `ip_hidden_states` directly and skip load `image_encoder` 1 year ago
HuggingFaceDocBuilderDev
HuggingFaceDocBuilderDev1 year ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

fix-copies
8b707cdf
yiyixuxu
yiyixuxu1 year ago

cc @asomoza feel free to give a review:)

yiyixuxu yiyixuxu requested a review from sayakpaul sayakpaul 1 year ago
sayakpaul
sayakpaul1 year ago

If you could also describe how was "diffusers_style_test.ipadpt" generated, I think that'd be a useful reference for the community.

sayakpaul
sayakpaul commented on 2024-02-19
src/diffusers/loaders/ip_adapter.py
188 if image_encoder_folder is not None:
189 if not isinstance(pretrained_model_name_or_path_or_dict, dict):
190 logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
191
if image_encoder_folder.count("/") == 0:
sayakpaul1 year ago

🧠

sayakpaul
sayakpaul commented on 2024-02-19
Conversation is marked as resolved
Show resolved
src/diffusers/loaders/ip_adapter.py
205 )
195206 else:
196 raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
207 logger.warning(
208
"image_encoder is not loaded since `image_encoder_folder=None` passed. you will not be able to use `ip_adapter_image` for ip-adapter."
209
" use `ip_adapter_image_embedding` to pass pre-geneated image embedding instead."
210
)
sayakpaul1 year ago
Suggested change
"image_encoder is not loaded since `image_encoder_folder=None` passed. you will not be able to use `ip_adapter_image` for ip-adapter."
" use `ip_adapter_image_embedding` to pass pre-geneated image embedding instead."
)
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
"Use `ip_adapter_image_embedding` to pass pre-geneated image embedding instead."
)
sayakpaul
sayakpaul commented on 2024-02-19
Conversation is marked as resolved
Show resolved
src/diffusers/loaders/ip_adapter.py
5252 pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
5353 subfolder: Union[str, List[str]],
5454 weight_name: Union[str, List[str]],
55
image_encoder_folder: Optional[str] = "image_encoder",
sayakpaul1 year ago

Let's add this to the docstring too.

sayakpaul1 year ago

@yiyixuxu I think this is yet to be resolved?

sayakpaul
sayakpaul approved these changes on 2024-02-19
sayakpaul1 year ago👍 1

Looking nice!

I think it'd be nice to add:

sayakpaul Merge branch 'main' into ip-adapter-no-image-encoder
d349b229
asomoza
asomoza1 year ago (edited 1 year ago)

Nice that this also solves the problem with the different image encoders for SDXL.

I'm still testing and everything works except when I try to load the embeddings from comfyui, I get a shape error. I'm trying to figure out where is the difference between the image embeddings form diffusers and the ones from comfyui.

edit: This is not a problem with this PR though, probably need some kind of conversion.

asomoza
asomoza1 year ago

If you could also describe how was "diffusers_style_test.ipadpt" generated, I think that'd be a useful reference for the community.

Just did a torch.save after the prepare_ip_adapter_image_embeds with your example in #6868:

image_embeds = prepare_ip_adapter_image_embeds(
    unet=pipeline.unet,
    image_encoder=pipeline.image_encoder,
    feature_extractor=pipeline.feature_extractor,
    ip_adapter_image=[[image_one, image_two]],
    do_classifier_free_guidance=True,
    device="cuda",
    num_images_per_prompt=1,
)

torch.save(image_embeds, "diffusers_style_test.ipadpt")
asomoza
asomoza1 year ago

I found another issue while doing some tests, I didn't notice it before since I don't use multiple images per prompt, but the code right now expects the embeddings to match this argument, for example in the SDXL pipeline:

single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)

This ties the embeddings to the number of images per prompt which is not ideal unless you keep track of this.

I see two solutions:

1.- Document this and leave it to the users
2.- Move this part of the code to after loading the embeddings so they will be the same regardless of the number of images per prompt.

I'll wait to see your resolution since this makes it hard to reuse the embeddings or make them compatible with other apps.

As an example, if I use the same code that generates the horse but with the num_images_per_prompt=4 I get this result:

20240219121344

sayakpaul
sayakpaul1 year ago

Thanks for reporting!

Move this part of the code to after loading the embeddings so they will be the same regardless of the number of images per prompt.

I think we should catch these things early and report to the users so that they can call the pipeline with the appropriate values. We'd want the num_images_per_prompt argument to be consistent with how it's generally treated across the library, IMO. WDYT? Also @yiyixuxu? I think this way, we wouldn't have to document it separately.

As an example, if I use the same code that generates the horse but with the num_images_per_prompt=4 I get this result:

What is the expected result?

asomoza
asomoza1 year ago

I think we should catch these things early and report to the users so that they can call the pipeline with the appropriate values. We'd want the num_images_per_prompt argument to be consistent with how it's generally treated across the library, IMO. WDYT? Also @yiyixuxu? I think this way, we wouldn't have to document it separately.

AFAIK there's no way to catch this, this doesn't produce any errors is just that the final result is different and kind of worse than the original.

The only way to make it work as intended, is if you match the number of images per prompt with the embeddings, so you'll need to also remember this value, put it in the filename or save it as a metadata in the file, but this also makes it only usable for that specific case.

This is not a critical issue, specially for people that just use one image per prompt, but with it, I don't see the need to put any more effort into making it compatible with saved embeddings from other apps or libraries.

What is the expected result?

The same as the first image in this PR:

20240219011826

sayakpaul
sayakpaul1 year ago

I think it does warrant a deeper investigation then why it’s the case. We should fix the root cause here IMO.

asomoza
asomoza commented on 2024-02-19
Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
12761276 encoder_hidden_states = self.encoder_hid_proj(image_embeds)
12771277 elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1278 if "image_embeds" not in added_cond_kwargs:
1278 if cross_attention_kwargs is not None and cross_attention_kwargs.get("ip_hidden_states", None) is not None:
1279
ip_hidden_states = cross_attention_kwargs.pop("ip_hidden_states")
asomoza1 year ago

if you pop the ip_hidden_states, the next step throws the ValueError since there's no image_embeds and no ip_hidden_states. If you don't do it though, there's a heavy spam in the console, this is an expected issue with not going with the kwargs route.

yiyixuxu1 year ago

umm i see. it make sense!
do you think we should pass it to ip_adapter_image_embeds and decide whether it has been projected or not based on the shape?

asomoza1 year ago👍 1

yeah, that's a good solution, it works for me.

yiyixuxu
yiyixuxu1 year ago (edited 1 year ago)

@asomoza @sayakpaul

about the other issue

I see two solutions:
1.- Document this and leave it to the users
2.- Move this part of the code to after loading the embeddings so they will be the same regardless of the number of images per prompt.

i think solution 2 is an easy answer, no? any downside to it that I missed?

asomoza
asomoza1 year ago

i think solution 2 is an easy answer, no? any downside to it that I missed?

yeah, for me it is the solution I would choose but sometimes there's an underlying reason in diffusers that I don't know which prevents it.

yiyixuxu
yiyixuxu1 year ago (edited 1 year ago)👍 1

@asomoza

ok, I will make changes based on solution 2. I think it's fine because that's consistent with prompt_embeds (we do not expect prompt_embeds to match num_images_per_prompt )

prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)

but we will wait @sayakpaul back next week to make a final decision on what we do

asomoza
asomoza1 year ago

@yiyixuxu @sayakpaul

There's another parameter that gets saved with the image embeddings and ties it up, the do_classifier_free_guidance which means that the embeddings will only work if they match with what is saved (CFG > 0 or CFG < 1)

So if we're going to move the num_images_per_prompt we should also move the check for do_classifier_free_guidance to after saving/loading the embeddings. This is also consistent with prompt_embeds

asomoza
asomoza1 year ago

I can now make my saved embeddings work with diffusers and comfyui depending on which point in my code I save the embeddings, but right now there isn't a solution to make both of them compatible.

I was wrong about the image projection being done before the saving though, so the comfyui embeds works passing them with ip_adapter_image_embeds, I do mine with ip_hidden_states and works ok too.

I can load and run the comfyui embeds like this:

comfy_image_embeds = torch.load("comfyui_style_test.ipadpt")
embeds = torch.unbind(comfy_image_embeds )

single_image_embeds = embeds[0]
single_negative_image_embeds = embeds[1]

single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)

image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])

but there's no way to make the diffusers embeds work with comfyui with all the arguments saved with the embeddings.

yiyixuxu
yiyixuxu1 year ago

@asomoza
I don't quiet understand the last comment
is this an example of comfyui embedding? https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/comfyui_style_test.ipadpt?download=true

it is a 2 x 2 x 1024 tensor, I would assume it is a the embedding after image_encoder and before the projection layer, no?

asomoza
asomoza1 year ago

it is a 2 x 2 x 1024 tensor, I would assume it is a the embedding after image_encoder and before the projection layer, no?

@yiyixuxu

yes, I didn't check that deep before, but now I found that they are saved before the projection layer so they must be passed with the ip_adapter_image_embeds.

Still it will be nice to have the option to skip it with cross_attention_kwargs since with that I don't have to do any conversion to load the embeddings (with my code).

yiyixuxu
yiyixuxu1 year ago

@asomoza
Ohh,, in that case, I don't think we should support ip_hidden_states - it will be too much of an edge case. But we will make sure the ip_adapter_image_embedding works as expected

the shape 2 x 2 x 1024 is batch_size x num_images x emb_dim - so it has CFG implied and num_images_per_prompt = 2?

asomoza
asomoza1 year ago

@yiyixuxu

Ohh,, in that case, I don't think we should support ip_hidden_states - it will be too much of an edge case.

I agree and don't have any problems with it, I can adapt and its a lot easier to do it on my side.

the shape 2 x 2 x 1024 is batch_size x num_images x emb_dim - so it has CFG implied and num_images_per_prompt = 2?

no, the batch size in this case refers to the image_embeds and the negative_image_embeds, the node in comfyui does this before saving them:

output = torch.stack((clip_embed, clip_embed_zeroed))

that's why I use:

embeds = torch.unbind(comfy_image_embeds )

after loading them and also the second dimension is the number of images (multiple images, I used two) for the IP Adapter, the nodes in ComfyUI don't manage the number of images per prompt.

For reference and so that you don't have to search for it, this is the code for the embeddings in the node:

https://github.com/cubiq/ComfyUI_IPAdapter_plus/blob/6a411dcb2c6c3b91a3aac97adfb080a77ade7d38/IPAdapterPlus.py#L1023-L1070

The output is what is saved as a file.

For the loading it just gets passed as an embeds argument:

https://github.com/cubiq/ComfyUI_IPAdapter_plus/blob/6a411dcb2c6c3b91a3aac97adfb080a77ade7d38/IPAdapterPlus.py#L706-L708

and then does the rest.

fabiorigano
fabiorigano1 year ago❤ 1

@yiyixuxu nice addition

I am adding this comment to stay tuned about the progress
This feature is very useful when loading IP Adapter FaceID, as it has no image encoder

up
d537f321
up
d0356c35
update docstring
e50484ff
yiyixuxu
yiyixuxu commented on 2024-02-25
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
582582 image_embeds.append(single_image_embeds)
583583 else:
584 image_embeds = ip_adapter_image_embeds
584
image_embeds = []
yiyixuxu1 year ago

@asomoza

I updated the prepare_ip_adapter_image_embeds method now we expect the ip_adapter_image_embeds to be a list of the same length as the number of IP adapters; each tensor in the list should have shape batch_size, num_images, embed_dim (e.g. 2, 2, 1024 in our case); the num_image_per_prompt is always 1 for the ip_adapter_image_embeds user passed, we duplicate it inside prepare_ip_adapter_image_embeds method; however we expected the embedding has CFG applied if we are going to use CFG in the inference (e.g. both our comfyui and diffusers style embedding has CFG applied, however, the order is different in diffusers vs comfy UI so we had to swap);

might be a little bit confusing but it is easier to understand if you take a look at the testing script here
#7016 (comment)

asomoza1 year ago

I tested it and it works perfect for the diffusers and comfyui embeddings, also tested it with cfg and no cfg. with num_per_images > 0 and 1, it's all good, thank you for you hard work.

asomoza1 year ago

with the Save IPAdapter Embeds node:

image

sayakpaul Merge branch 'main' into ip-adapter-no-image-encoder
df5fce41
sayakpaul
sayakpaul commented on 2024-02-26
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
582582 image_embeds.append(single_image_embeds)
583583 else:
584 image_embeds = ip_adapter_image_embeds
584 image_embeds = []
585
for single_image_embeds in ip_adapter_image_embeds:
586
if self.do_classifier_free_guidance:
587
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
588
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
589
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
590
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
591
else:
592
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
sayakpaul1 year ago

@yiyixuxu do we need to catch any conditions, throw warnings here to avoid silent bugs?

yiyixuxu1 year ago

ok. going to check on the check_inputs

sayakpaul
sayakpaul approved these changes on 2024-02-26
sayakpaul1 year ago👍 1

Seems like we have covered quite a bit of grounds here and went through multiple hoops of design.

This is looking pretty solid. I think it's okay if we're covering the most common use cases enabled by this change. We don't need to cover extreme edge cases which can be better handled otherwise, IMO.

Let's ship this with tests and docs!

yiyixuxu Update src/diffusers/loaders/ip_adapter.py
cf47f357
update check input
b6c0a375
update check inputs
b74aa54f
fix copies
1784c635
update docstring
f7e05151
update docstring for image encoder
ddfa055f
add a cfg test
57a4b6a0
fix tests
46a47b3e
add do_classifier_free_guidance arg to prepare_ip_adapter_image_embeds
56a136c4
fix
10e91511
add a section on doc about ip_adapter_image_embeds
26c5b5f7
add a note about comfy ui embedding
40fc2514
yiyixuxu
yiyixuxu commented on 2024-02-29
src/diffusers/pipelines/animatediff/pipeline_animatediff.py
390390 )
391391
392 if self.do_classifier_free_guidance:
392
if do_classifier_free_guidance:
yiyixuxu1 year ago (edited 1 year ago)👍 1

adding do_classifier_free_guidance as an argument to prepare_ip_adapter_imabe_embeds so we can use this pipeline method to save image embeddings

image_embeds = pipeline.prepare_ip_adapter_image_embeds(
    ip_adapter_image=image,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

torch.save(image_embeds, "image_embeds.ipadpt")
update ldm3d
0118fe21
yiyixuxu yiyixuxu changed the title [ip-adapter] allow pass `ip_hidden_states` directly and skip load `image_encoder` [ip-adapter] refactor `prepare_ip_adapter_image_embeds` and skip load `image_encoder` 1 year ago
yiyixuxu
yiyixuxu commented on 2024-02-29
docs/source/en/using-diffusers/ip_adapter.md
234234> [!TIP]
235235> While calling `load_ip_adapter()`, pass `low_cpu_mem_usage=True` to speed up the loading time.
236236
237
All the pipelines supporting IP-Adapter accept a `ip_adapter_image_embeds` argument. If you need to run the IP-Adapter multiple times with the same image, you can encode the image once and save the embedding to the disk.
yiyixuxu1 year ago

cc @stevhliu here for awareness

I added a section to ip-adapter guide here. Let me know if you have any comments. If editing in a separate PR is easier, feel free to do so!

yiyixuxu1 year ago👍 2❤ 1

another very good use case for ip_adapter_image_embeds is probably the multi-ip-adapter https://huggingface.co/docs/diffusers/main/en/using-diffusers/ip_adapter#multi-ip-adapter

a common practice is to use a folder of 10+ images for styling, and you would use the same styling images everywhere to create a consistent style, so it would be nice to create an image embedding for these style images, so you don't have to load a bunch of same images from a folder and encode them each time

sayakpaul1 year ago👍 1

I think we should definitely add that example motivating the use case. WDYT @asomoza?

stevhliu1 year ago

I'll edit it in a separate PR, and I can also make a mention of ip_adapter_image_embeds in the multi IP-Adapter section 🙂

asomoza1 year ago❤ 1

I think we should definitely add that example motivating the use case. WDYT @asomoza?

yeah this is specially helpful when you use a lot of images and multiple ip adapters, you just need to save the embeddings making it a lot easier to replicate and saves a lot of space if you use high quality images.

I'll try to do one with a style and a character and see how it goes, but to see the real potential of this we'll also need controlnet and ip adapter masking so the best use case would be a full demo with all of this incorporated.

update test, only check shape
dd1ff56b
yiyixuxu
yiyixuxu1 year ago

finally finishing up this PR now. I refactored some more feel free to give a final review

cc @sayakpaul @asomoza

sayakpaul Merge branch 'main' into ip-adapter-no-image-encoder
e72d8434
sayakpaul
sayakpaul commented on 2024-02-29
Conversation is marked as resolved
Show resolved
docs/source/en/using-diffusers/ip_adapter.md
251Load the image embedding and pass it to the pipeline as `ip_adapter_image_embeds`
252
253> [!TIP]
254
> ComfyUI image embeddings are fully compatible with IP-Adapter in diffusers and will work out-of-box.
sayakpaul1 year ago
Suggested change
> ComfyUI image embeddings are fully compatible with IP-Adapter in diffusers and will work out-of-box.
> ComfyUI image embeddings for IP-Adapters are fully compatible in Diffusers and should work out-of-box.
sayakpaul
sayakpaul commented on 2024-02-29
docs/source/en/using-diffusers/ip_adapter.md
254> ComfyUI image embeddings are fully compatible with IP-Adapter in diffusers and will work out-of-box.
255
256```py
257
image_embeds = torch.load("image_embeds.ipadpt")
sayakpaul1 year ago

We don't know where this is coming from. Let's include a snippet to download that and explicitly mention that it's coming from ComfyUI.

sayakpaul
sayakpaul commented on 2024-02-29
Conversation is marked as resolved
Show resolved
docs/source/en/using-diffusers/ip_adapter.md
265```
266
267> [!TIP]
268
> If you use IP-Adapter with image embedding instead of image, you can choose not to load a image encoder by passing `image_encoder_folder=None` to `load_ip_adapter()`
sayakpaul1 year ago
Suggested change
> If you use IP-Adapter with image embedding instead of image, you can choose not to load a image encoder by passing `image_encoder_folder=None` to `load_ip_adapter()`
> If you use IP-Adapter with `ip_adapter_image_embedding` instead of `ip_adapter_image`, you can choose not to load an image encoder by passing `image_encoder_folder=None` to `load_ip_adapter()`.
sayakpaul
sayakpaul commented on 2024-02-29
sayakpaul1 year ago

Looking pretty solid. I left a couple of suggestions to the docs. I reviewed the changes made to ip_adapter.py and the changes in prepare_ip_adapter_image_embeds() and check_inputs from SDXL pipeline script. I think the rest of the pipelines share these changes?

I thought we were also supporting passing the image embedding projection as well. Are we not doing so?

src/diffusers/pipelines/animatediff/pipeline_animatediff.py
494504 "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
495505 )
496506
507
if ip_adapter_image_embeds is not None:
508
if not isinstance(ip_adapter_image_embeds, list):
509
raise ValueError(
510
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
511
)
512
elif ip_adapter_image_embeds[0].ndim != 3:
513
raise ValueError(
514
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
515
)
sayakpaul1 year ago

Do we need any checks on the shapes to conform to what's needed for classifier-free guidance?

yiyixuxu Update docs/source/en/using-diffusers/ip_adapter.md
7bd572fd
yiyixuxu Update docs/source/en/using-diffusers/ip_adapter.md
07b4e21e
yiyixuxu
yiyixuxu1 year ago (edited 1 year ago)❤ 1

@sayakpaul

I thought we were also supporting passing the image embedding projection as well. Are we not doing so?

so it turns out comfyUI embedding is created before the image projection layer - so we don't need to support passing the projection output directly anymore since it is too small an use case

yiyixuxu yiyixuxu merged 06b01ea8 into main 1 year ago
yiyixuxu yiyixuxu deleted the ip-adapter-no-image-encoder branch 1 year ago

Login to write a write a comment.

Login via GitHub

Assignees
No one assigned
Labels
Milestone