diffusers
IP-Adapter attention masking
#6847
Merged


fabiorigano
fabiorigano1 year ago❀ 4πŸ‘€ 1

What does this PR do?

Fixes #6802

Who can review?

@yiyixuxu @asomoza

fabiorigano Add attention masking to attn processors
063e0856
fabiorigano
fabiorigano1 year ago (edited 1 year ago)

This is a work in progress; I am not satisfied with the results yet (maybe I am doing something wrong).

Mask preprocessing is done outside of this PR. I extract masks from an RGB image by selecting the unique colors and discarding the background (black). Here is a code snippet to get the list of masks from the following image:

import numpy as np
import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from diffusers.utils import load_image

noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "SG161222/Realistic_Vision_V4.0_noVAE",
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    feature_extractor=None,
    safety_checker=None
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
pipeline.set_ip_adapter_scale(0.7)


# Load image
mask = load_image("./mask.png")
# Use image processor registered in the pipeline
iproc = pipeline.image_processor
mask = iproc.pil_to_numpy(mask)[0]
# Find unique colors
colors = np.unique(mask.reshape(-1, 3), axis=0)
# Discard background
# Keep every color that is not pure black (the background)
unique = [colors[i] for i in range(colors.shape[0]) if np.any(colors[i] != 0)]
# Extract masks
masks = [np.expand_dims(np.where(mask == u, 1, 0)[:, :, 0], axis=0) for u in unique]
masks = [iproc.numpy_to_pt(mask)[0] for mask in masks]

mask

Input images are:

https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png
image1

https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png
image2

Then I called the pipeline as follows:

generator = torch.Generator(device="cpu").manual_seed(33)
num_images=1

images = pipeline(
      prompt="A photo of two girls wearing black dresses, holding red roses in hand, upper body, behind is the Eiffel Tower",
      ip_adapter_image=[[image1, image2]],
      negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
      num_inference_steps=20, num_images_per_prompt=num_images, width=704, height=512,
      generator=generator, cross_attention_kwargs={"masks": masks},
      #output_type= "np"
  ).images

Result without masks:
without

Result with masks:
masked

asomoza
asomoza1 year ago❀ 5

Nice work. You're covering the one use case that I didn't code: IP Adapters with multiple images and multiple masks. It is the same as two IP Adapters with one image and one mask each, but with the added benefit that you can manage the weight of each one separately. In my tests it would be like this:

Result 1 Result 2
20240204223825 20240204223854

I use SDXL only, but they should be comparable. I really recommend not using multiple masks for multiple images and instead using one mask per IP Adapter. I haven't seen anyone using this, but I could be wrong.

The problem you see in your example is more noticeable with SDXL:

Result 1 Result 2
20240205030442 20240205030523

What's happening is that you're matching the batch with the masks, but the batch is doubled or not depending on classifier-free guidance. So what you're really doing is applying only one mask if the negative prompt is empty, or deleting one if the CFG scale is less than 1. Also, you're applying the mask to the ip_hidden_states of multiple images, which is why you can see the faces combined into one where the mask is applied.

There are some more minor issues, but I'll wait and see which approach you take.
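For illustration, here is a minimal, hypothetical sketch of the batch bookkeeping being described (the function and variable names are mine, not from the PR):

import torch

def expand_mask_to_batch(mask: torch.Tensor, batch_size: int, do_classifier_free_guidance: bool) -> torch.Tensor:
    # With classifier-free guidance the UNet processes the negative and positive
    # prompts together, so the effective batch is batch_size * 2 and the mask
    # must be repeated to match it instead of being zipped against batch_size.
    effective_batch = batch_size * 2 if do_classifier_free_guidance else batch_size
    if mask.shape[0] < effective_batch:
        mask = mask.repeat(effective_batch // mask.shape[0], 1, 1)
    return mask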

fabiorigano Move latent image masking
d753eec6
fabiorigano
fabiorigano1 year ago (edited 1 year ago)πŸ‘ 1

hi @asomoza, thanks for the suggestion. I updated the for loop and now the results look pretty good.

Also you're applying the mask to the ip_hidden_states of multiple images, so you can also see that the faces are combined into one where the mask is applied.

I am not sure what you mean here. The image after the mask in the first comment is the result of generation without applying masks, so it is correct that it combines the two faces.

I changed the base SD model and loaded two IP-Adapters to the pipeline:

pipeline = AutoPipelineForText2Image.from_pretrained(
    "frankjoshua/realisticVisionV51_v51VAE",
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    feature_extractor=None,
    safety_checker=None
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=["ip-adapter-plus-face_sd15.bin", "ip-adapter-full-face_sd15.bin"])

pipeline.set_ip_adapter_scale([0.7, 0.7])

generator = torch.Generator(device="cpu").manual_seed(33)
num_images=4

images = pipeline(
      prompt="A photo of two girls wearing black dresses, holding red roses in hand, upper body, behind is the Eiffel Tower",
      ip_adapter_image=[[image1], [image2]],
      negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
      num_inference_steps=20, num_images_per_prompt=num_images, width=704, height=512,
      generator=generator, cross_attention_kwargs={"masks": masks}
  ).images

Output
res1
res0
res3

@yiyixuxu

asomoza
asomoza1 year agoπŸ‘ 2

yeah, now it's working OK, nice work.

I am not sure about what you mean here. The image after the mask in the first comment is the result of generation without applying masks, so it is correct to have a combination of the two faces.

I meant the one after that, where there was supposed to be one face for each woman (you can also see it in my results). That happened because there were multiple images for one IP Adapter and you were applying one mask to those.

You don't have that problem now and it doesn't matter anymore since you're using two IP Adapters, but the equivalent would be if you did this:

images = pipeline(
      prompt="A photo of two girls wearing black dresses, holding red roses in hand, upper body, behind is the Eiffel Tower",
      ip_adapter_image=[[image1, image2], [image2]],
      negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
      num_inference_steps=20, num_images_per_prompt=num_images, width=704, height=512,
      generator=generator, cross_attention_kwargs={"masks": masks}
  ).images

The results are like this:

[[image1], [image2]] [[image1, image2], [image2]]
20240205114207 20240205115321

I know they're similar, but I can see the difference instantly since I've done a million tests with IP Adapters.

asomoza
asomoza commented on 2024-02-05
src/diffusers/models/attention_processor.py
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
current_ip_hidden_states = F.scaled_dot_product_attention(
    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)

mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device)

current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
    batch_size, -1, attn.heads * head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

current_ip_hidden_states = current_ip_hidden_states * mask_downsample
asomoza1 year ago

this throws an error if you use a mask with a different width and height than the generated image. For example, if I use your mask with SDXL and generate a 1024x1024 image, I get this error:

The size of tensor a (4096) must match the size of tensor b (4070) at non-singleton dimension 1

fabiorigano1 year ago

I know, I haven't added checks on the mask size yet. I think the ComfyUI implementation also has the same issue, but I haven't tested it:
https://github.com/cubiq/ComfyUI_IPAdapter_plus/blob/90d3451cd970d5aa9cac55224e24a7c7fd98d253/IPAdapterPlus.py#L537

asomoza1 year ago

I think it works with masks that aren't of the same ratio as the generation, it's just not recommended. Maybe @cubiq can provide his insights here; I use the same code and it doesn't use the ratio. I think the "needs checking" note means that he wasn't completely sure of the formula used.

fabiorigano1 year agoπŸ‘ 1

Ok, I will check without ratio as in the other implementation! Thanks

cubiq1 year agoπŸ‘ 2

the attention mask is resized and stretched at each iteration; the aspect ratio doesn't matter, but of course it's better if you provide the right size.

due to rounding errors it might happen that you get the wrong size, but it's not very common and I think I already have a solution for that.

fabiorigano1 year ago

I can confirm the issue is also present in the other implementation

asomoza1 year ago

In that case I really don't know what the best method of doing this would be that's also consistent with diffusers.

In my case I prepare the mask latents outside the attention processor, using the VAE scale factor and the width and height of the generated image, but it could be as simple as throwing an error telling the user that the masks must have the same aspect ratio as the generated image.
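As a rough sketch of that outside-the-processor approach (the sizes and the placeholder mask below are mine; 8 is the usual VAE scale factor for SD models):

import torch
import torch.nn.functional as F

vae_scale_factor = 8
generation_height, generation_width = 1024, 1024
mask = torch.ones(1, 896, 1152)  # placeholder mask of shape [1, H, W]

# Resize the mask once to the latent resolution, outside the attention processor
mask_latents = F.interpolate(
    mask.unsqueeze(0),  # -> [1, 1, H, W], as required by bicubic interpolation
    size=(generation_height // vae_scale_factor, generation_width // vae_scale_factor),
    mode="bicubic",
).squeeze(0)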

fabiorigano Remove redundant code
60336baa
fabiorigano Fix removed line
f6451d3e
yiyixuxu
yiyixuxu1 year agoπŸ‘ 1❀ 1

Great work! Thanks everyone here ❀️ the results look super cool to me!
Can we confirm that it works correctly as long as we only pass one image and one mask for each ip-adapter? @asomoza @fabiorigano

so the remaining item is:

  1. the resizing #6847 (comment)
  2. refactor the code
asomoza
asomoza1 year agoπŸ‘ 1

yes, it works correctly with one or multiple prompt images and one mask per IP Adapter, which IMO is the correct implementation.

There's one other issue that maybe should be addressed, but I don't know if it comes from this PR or from before: if you don't pass the same number of scales, the IP Adapters that don't have scales are completely ignored without showing a message or error.

yiyixuxu
yiyixuxu1 year agoπŸ‘ 1

@asomoza I fixed here #6884

yiyixuxu
yiyixuxu commented on 2024-02-07
yiyixuxu1 year ago❀ 1

super cool!

Conversation is marked as resolved
src/diffusers/models/attention_processor.py
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
masks=None,
yiyixuxu1 year ago
Suggested change
masks=None,
ip_adapter_masks=None,
Conversation is marked as resolved
src/diffusers/models/attention_processor.py
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
masks=None,
yiyixuxu1 year ago
Suggested change
masks=None,
ip_adapter_masks=None,
Conversation is marked as resolved
src/diffusers/models/attention_processor.py
attention_mask=None,
temb=None,
scale=1.0,
masks=None,
yiyixuxu1 year ago
Suggested change
masks=None,
ip_adapter_masks=None,
src/diffusers/models/attention_processor.py
if len(masks) != len(ip_hidden_states):
    raise ValueError(
        f"Number of masks ({len(masks)}) must match number of IP-Adapters ({len(self.scale)})"
    )
yiyixuxu1 year agoπŸ‘ 1

from what I understand, it only works when we pass 1 image / 1 mask / 1 ip-adapter? If so, let's check the number of images here and throw an error if multiple images are passed:

    if ip_hidden_states[0].shape[1] > 1:
        raise ValueError("....")
asomoza1 year ago

Why do you think that? If you perform that check, you will remove all the instant LoRA functionality.

yiyixuxu1 year ago

it's only when the mask is not None though - you can still use multiple images without a mask.
And it's only based on the understanding that we can use just one image/one mask/one ip-adapter when we use a mask, no?

yiyixuxu1 year agoπŸ‘ 1

if it works with multiple images for sure we don't need this!

asomoza1 year ago (edited 1 year ago)πŸ‘ 1

it works with multiple images, I tested it, so the only check should be that the number of masks matches the number of IP Adapters.
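For instance, reusing the pipeline, images, and masks from the earlier snippets, a hedged sketch of that multi-image ("instant LoRA" style) call; the ip_adapter_masks kwarg name follows the rename suggested above:

images = pipeline(
    prompt="A photo of two girls wearing black dresses, holding red roses in hand, upper body, behind is the Eiffel Tower",
    ip_adapter_image=[[image1, image2], [image2]],  # several reference images can feed one adapter
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=20, num_images_per_prompt=num_images, width=704, height=512,
    generator=generator,
    cross_attention_kwargs={"ip_adapter_masks": masks},  # still exactly one mask per adapter
).images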

sayakpaul1 year agoπŸ‘ 1

Are we covering these cases in the tests?

src/diffusers/models/attention_processor.py
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

if mask is not None:
    seq_len = current_ip_hidden_states.shape[1]
    o_h = masks[0].shape[1]
    o_w = masks[0].shape[2]
    ratio = o_w / o_h
    mask_h = int(torch.sqrt(torch.tensor(seq_len / ratio)))
    mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0)
    mask_w = seq_len // mask_h

    if len(mask.shape) == 2:
        mask = mask.unsqueeze(0)
    mask_downsample = F.interpolate(
        torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic"
    ).squeeze(0)

    if mask_downsample.shape[0] < batch_size:
        mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
    if mask_downsample.shape[0] > batch_size:
        mask_downsample = mask_downsample[:batch_size, :, :]

    mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat(
        1, 1, current_ip_hidden_states.shape[-1]
    )

    mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device)
yiyixuxu1 year agoπŸ‘ 1

let's move this code to VaeImageProcessor, https://github.com/huggingface/diffusers/blob/main/src/diffusers/image_processor.py

maybe we can create a IPAdapterMaskProcessor(VaeImageProcessor) and add a downsample method

Suggested change: replace the inline downsampling block above with:

    mask_downsample = IPAdapterMaskProcessor.downsample(mask, seq_length, batch_size)
    mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device)
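A rough sketch of what such a processor could look like (a sketch under the constraints discussed here; the merged class may differ):

import torch
from diffusers.image_processor import VaeImageProcessor


class IPAdapterMaskProcessor(VaeImageProcessor):
    """Image processor for IP-Adapter image masks."""

    def __init__(self):
        # Masks are binarized grayscale images and must not be normalized to [-1, 1]
        super().__init__(do_normalize=False, do_binarize=True, do_convert_grayscale=True)

    @staticmethod
    def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
        # Resize the mask so mask_h * mask_w matches num_queries, flatten it,
        # then expand to [batch_size, num_queries, value_embed_dim]
        ...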
fabiorigano Add padding
bf4eb1dd
fabiorigano
fabiorigano1 year ago (edited 1 year ago)❀ 1

so the remaining item is:

  1. the resizing #6847 (comment)
  2. refactor the code

I have just added padding to fix the resizing bug, and I see the output is still good (the padding/cropping idea is sketched below).
Maybe it is better to recommend using masks with an aspect ratio equal or very close to that of the output image, while avoiding errors when there is a mismatch.
I will finish refactoring as suggested after work :)
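The padding/cropping can be sketched as follows (assuming the flattened mask has shape [batch, mask_h * mask_w]; the helper name is mine):

import torch
import torch.nn.functional as F

def pad_or_crop(mask_downsample: torch.Tensor, mask_h: int, mask_w: int, seq_len: int) -> torch.Tensor:
    if mask_h * mask_w < seq_len:
        # Downsampled mask is slightly too short: zero-pad the flattened mask up to seq_len
        return F.pad(mask_downsample, (0, seq_len - mask_downsample.shape[1]), value=0.0)
    if mask_h * mask_w > seq_len:
        # Downsampled mask is slightly too long: crop it down to seq_len
        return mask_downsample[:, :seq_len]
    return mask_downsample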

yiyixuxu
yiyixuxu commented on 2024-02-07
Conversation is marked as resolved
src/diffusers/models/attention_processor.py
hidden_states = attn.batch_to_head_dim(hidden_states)

if masks is not None:
    if not isinstance(masks, list):
        masks = [masks]
yiyixuxu1 year ago
Suggested change
if not isinstance(masks, list):
    masks = [masks]
if not isinstance(masks, np.ndarray) or masks.ndim != 4:
    raise ValueError("ip_adapter_mask should be a numpy array with shape [num_ip_adapter, 1, height, width]. Please use `IPAdapterMaskProcessor` to preprocess your mask")

let's enforce the masks to be a numpy array with shape [num_ip_adapter, 1, height, width]

In order to be able to do so, let's create an IPAdapterMaskProcessor that inherits from VaeImageProcessor https://github.com/huggingface/diffusers/blob/main/src/diffusers/image_processor.py#L444
so users can get their masks like this:

mask_processor = IPAdapterMaskProcessor()
masks = mask_processor.process([mask1, mask2])
yiyixuxu1 year ago (edited 1 year ago)

As a reference, this is how I use the VaeImageProcessor to process our masks:

import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image
from diffusers.image_processor import VaeImageProcessor

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", 
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    image_encoder=image_encoder,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter(
  "h94/IP-Adapter", 
  subfolder="sdxl_models", 
  weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2
)
pipeline.set_ip_adapter_scale([0.7] * 2)
pipeline.enable_model_cpu_offload()

face_image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png")
face_image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png")
mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png")
mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png")

mask_processor = VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True)
masks = list(mask_processor.preprocess([mask1, mask2]))

generator = torch.Generator(device="cpu").manual_seed(0)

image = pipeline(
    prompt="2 girls",
    ip_adapter_image=[face_image1, face_image2],
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=50, num_images_per_prompt=1,
    generator=generator,
    cross_attention_kwargs={"masks": masks}
).images[0]

yiyi_test_1_out

yiyixuxu1 year ago❀ 1

also, IPAdapterMaskProcessor should only need to handle masks that have one color, e.g. masks like this

ip_mask_mask1

And it does not need to be able to handle this kind of mask

ip_mask

fabiorigano Apply suggestions from code review
37419f13
fabiorigano Add IPAdapterMaskProcessing
a180e258
fabiorigano
fabiorigano1 year ago (edited 1 year ago)

Updated snippet to run inference:

from diffusers import AutoPipelineForText2Image, DDIMScheduler
import torch
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection
from diffusers.image_processor import IPAdapterMaskProcessor

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", 
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    image_encoder=image_encoder,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

face_image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png")
face_image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png")
mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png")
mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png")

processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask1, mask2])

ip_images = [[face_image1], [face_image2]]

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2)
pipeline.set_ip_adapter_scale([0.7, 0.7])
generator = torch.Generator(device="cpu").manual_seed(1)
num_images=1

images = pipeline(
    prompt="2 girls",
    ip_adapter_image=ip_images,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=20, num_images_per_prompt=num_images, 
    generator=generator, cross_attention_kwargs={"ip_adapter_masks": masks}
).images

Output:
p1_0

fabiorigano Fix return types
c6fddaed
fabiorigano Update image_processor
708e0ebb
fabiorigano Add test
bbfeb676
fabiorigano Merge branch 'main' into ipadaptermasks
e11bb7be
fabiorigano
fabiorigano1 year ago

@yiyixuxu can you give a look when you have time please?

I added a test; I didn't touch the documentation because I saw there is a big refactoring going on right now

thanks :)

fabiorigano changed the title from [WIP] IP-Adapter attention masking to IP-Adapter attention masking 1 year ago
yiyixuxu
yiyixuxu commented on 2024-02-09
yiyixuxu1 year ago❀ 1

ohh looking great!
left a few nits

Conversation is marked as resolved
src/diffusers/image_processor.py
Image processor for IP Adapter image masks.
"""

def __init__(self):
yiyixuxu1 year agoπŸ‘ 1

I think the resize related arguments are really useful here too, in case you need to resize your mask!

Suggested change
def __init__(self):
def __init__(
    self,
    do_resize: bool = True,
    vae_scale_factor: int = 8,
    resample: str = "lanczos",
    do_normalize: bool = False,
    do_binarize: bool = True,
    do_convert_grayscale: bool = True,
):
fabiorigano1 year ago (edited 1 year ago)

I don't fully understand why do_resize is needed here.
Isn't it sufficient to perform the resize by interpolation in the downsample method?

Conversation is marked as resolved
src/diffusers/image_processor.py
"""
def __init__(self):
    super().__init__(do_normalize=False, do_binarize=True, do_convert_grayscale=True)
yiyixuxu1 year ago
Suggested change
super().__init__(do_normalize=False, do_binarize=True, do_convert_grayscale=True)
super().__init__()
Conversation is marked as resolved
src/diffusers/image_processor.py
def __init__(self):
    super().__init__(do_normalize=False, do_binarize=True, do_convert_grayscale=True)

def process(self, images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> torch.FloatTensor:
yiyixuxu1 year ago
Suggested change
def process(self, images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> torch.FloatTensor:
def preprocess(self, images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> torch.FloatTensor:
Conversation is marked as resolved
src/diffusers/image_processor.py
"""
Convert a PIL.Image.Image or a list of PIL.Image.Image to a torch.FloatTensor
"""
if not isinstance(images, list):
yiyixuxu1 year agoπŸ‘ 2
Suggested change (remove this check):
if not isinstance(images, list):

do we need this? I think the VaeImageProcessor.preprocess is able to handle lists too

sayakpaul1 year agoπŸ‘ 1

I wonder whether we need to define preprocess() at all here, then.

yiyixuxu
yiyixuxu1 year ago

also, I think if we merge this #6915 (comment), we won't need to add the additional ip_adapter_mask argument to the default attention processors

and this test PR is relevant too; we will wait for it to merge and then update the test #6888

the doc can be added later

yiyixuxu
yiyixuxu1 year ago

cc @asomoza
can you do a final review too?

yiyixuxu requested a review from DN6 1 year ago
yiyixuxu requested a review from sayakpaul 1 year ago
HuggingFaceDocBuilderDev
HuggingFaceDocBuilderDev1 year ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

sayakpaul
sayakpaul commented on 2024-02-09
src/diffusers/image_processor.py
return images

@staticmethod
def downsample(mask: torch.FloatTensor, batch_size: int, seq_length: int, value_embed_dim: int):
sayakpaul1 year agoπŸ‘ 1

It's uncommon to see seq_length as a parameter for downsampling, at least in the vision domain. Could we aim for a better, more widely understood argument name here?
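For instance, a signature along these lines (num_queries is one candidate name; any name tying the value to the attention query count would do):

@staticmethod
def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
    ...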

sayakpaul
sayakpaul commented on 2024-02-09
src/diffusers/image_processor.py
o_h = mask.shape[1]
o_w = mask.shape[2]
ratio = o_w / o_h
mask_h = int(torch.sqrt(torch.tensor(seq_length / ratio)))
mask_h = int(mask_h) + int((seq_length % int(mask_h)) != 0)
mask_w = seq_length // mask_h
sayakpaul1 year agoπŸ‘ 1

Why do we need torch math operations here? We can directly use math.sqrt() without converting seq_length / ratio to a torch tensor, no? Something like the sketch below.
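A sketch of the suggested simplification (the example values are mine):

import math

# example values: 64x64 latent grid, landscape mask
seq_length, ratio = 4096, 1152 / 896

mask_h = int(math.sqrt(seq_length / ratio))  # plain float math, no tensor round-trip
mask_h = mask_h + int((seq_length % mask_h) != 0)
mask_w = seq_length // mask_h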

sayakpaul
sayakpaul commented on 2024-02-09
src/diffusers/image_processor.py
916
917 mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)
918
919
# Repeat mask until batch_size
sayakpaul1 year agoπŸ‘ 1

(nit): "until" suggests there could be a `while` loop in the subsequent operations. A better comment could be "# Repeat batch_size times". WDYT?

sayakpaul
sayakpaul commented on 2024-02-09
src/diffusers/image_processor.py
# If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
# Pad tensor if downsampled_mask.shape[1] is smaller than seq_length
if mask_h * mask_w < seq_length:
sayakpaul1 year agoπŸ‘ 1

Could mask_h * mask_w be assigned to a variable (with a meaningful name) and reused?

sayakpaul
sayakpaul commented on 2024-02-09
src/diffusers/models/attention_processor.py
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
ip_adapter_masks=None,
sayakpaul1 year ago

Let's add a type annotation here.

sayakpaul
sayakpaul commented on 2024-02-09
src/diffusers/models/attention_processor.py
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
    raise ValueError(
        " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
        " Please use `IPAdapterMaskProcessor` to preprocess your mask"
    )
if len(ip_adapter_masks) != len(ip_hidden_states):
    raise ValueError(
        f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
    )
sayakpaul1 year agoπŸ‘ 1

I find it a little weird that the condition uses the length of ip_hidden_states while the error message relies on the length of self.scale. Maybe settle on one of the two here?

sayakpaul
sayakpaul commented on 2024-02-09
src/diffusers/models/attention_processor.py
        f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
    )
else:
    ip_adapter_masks = [None] * len(ip_hidden_states)
sayakpaul1 year ago

Here as well. Let's make sure to not use multiple variables to check for the right length of components.

sayakpaul
sayakpaul commented on 2024-02-09
src/diffusers/models/attention_processor.py
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
sayakpaul1 year agoπŸ‘ 1

Do we need to do the casting here?

sayakpaul
sayakpaul commented on 2024-02-09
Conversation is marked as resolved
src/diffusers/models/attention_processor.py
mask_downsample = IPAdapterMaskProcessor.downsample(
    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
)

mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device)
sayakpaul1 year ago
Suggested change
mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device)
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

Simpler and more readable?

sayakpaul
sayakpaul commented on 2024-02-09
src/diffusers/models/attention_processor.py
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
    raise ValueError(
        " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
        " Please use `IPAdapterMaskProcessor` to preprocess your mask"
    )
if len(ip_adapter_masks) != len(ip_hidden_states):
    raise ValueError(
        f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
    )
else:
    ip_adapter_masks = [None] * len(ip_hidden_states)
sayakpaul1 year ago

Same as the previous comments regarding length.

sayakpaul
sayakpaul commented on 2024-02-09
tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4

def test_masks(self):
sayakpaul1 year ago

Can we also add a test case for checking multiple images with masks, here?

fabiorigano1 year ago

This test case uses two images, two masks, and two IP-Adapters (the same IP-Adapter loaded twice)

sayakpaul1 year ago

I see. Does it then make sense to separate the test cases and have more descriptive naming?

DN6 1 year agoπŸ‘ 1

Could we actually break this up into two cases?

test_ip_adapter_single_mask
test_ip_adapter_multiple_masks

Just easier to tell if a specific functionality is broken.

sayakpaul
sayakpaul approved these changes on 2024-02-09
sayakpaul1 year ago❀ 1

Excellent work here!

@yiyixuxu let's make sure to run the tests on a V100 on our Docker image to get the correct assertion values (as the underlying device can change them).

yiyixuxu
yiyixuxu1 year ago (edited 1 year ago)

@sayakpaul

let's make sure to run the tests on a V100 on our Docker image

will do!

asomoza
asomoza1 year ago

@yiyixuxu @fabiorigano can you please do a test generating a 1024x1024 image with these masks?

mask 1 mask 2
processed_mask_20240202101749_landscape processed_mask_20240202101757_landscape

for me it's not working as expected:

20240209005815

If it's not just me, IMO we should just do an interpolation with the target width and height; masks for IP Adapters don't have to be that precise. We should also specify in the documentation that masks work better when they have the same aspect ratio as the generation.

Still, it would be better to generate an image that matches the areas of the masks than to crop or pad it; for example, my code generates this image with the same masks:

20240209012136

and even if I generate a portrait image with landscape masks, it works:

20240209012803

WDYT?

sayakpaul
sayakpaul1 year ago

Are you passing both masks as inputs?

asomoza
asomoza1 year ago

yes, I used the same code and just replaced the masks; with 1:1 masks it works perfectly.

sayakpaul
sayakpaul1 year ago

with 1:1 masks it works perfectly.

What's meant by 1:1 masks?

and also specify in the documentation that masks work better when they're of the same ratio as the generation.

You mean the input masks are interpolated such that they share the same aspect ratio? How should this aspect ratio be calculated, then? From the input images?

Apologies, in advance, for my naive questions.

asomoza
asomoza1 year ago (edited 1 year ago)❀ 1

No problem, English is not my native language, so I think the problem is more on my side.

Since the generation I used was 1024x1024 in width and height, that's a 1:1 aspect ratio, which is why I wrote 1:1. In this case it means I used square images as masks (1024x1024 to be precise), and in that case it works perfectly.

The interpolation I'm referring to is something like this:

mask = torch.nn.functional.interpolate(
    mask.unsqueeze(1),
    size=(generation_height // self.vae_scale_factor, generation_width // self.vae_scale_factor),
    mode="bicubic",
).squeeze(1)

this will distort the mask to match the generation aspect ratio, but as you can see in my examples, it maintains the "areas" of the mask well.

fabiorigano
fabiorigano1 year ago

@yiyixuxu @sayakpaul thanks for your reviews!

@asomoza I use those masks in the example in #6847 (comment).
Try a different seed and let me know (the default output size for SDXL is 1024×1024).

asomoza
asomoza1 year ago

I did try different seeds and I always get some weird results. Just in case: the masks I provided aren't the same as the old ones, these are 1156x896 images, not square ones. If it works for you, then it's a problem on my side.

fabiorigano
fabiorigano1 year ago

Oh, I just saw the shape is different.
Maybe the aspect ratio of the masks is too different from that of the output.

asomoza
asomoza1 year ago (edited 1 year ago)❀ 1

yeah, I tested it just in case. I don't really see the point of using masks that don't match the aspect ratio of the output; I didn't even bother to test input images that aren't square, since that's even more pointless. But I still think people will do it and open issues about it.

the rest of the code looks good, nice work

sayakpaul
sayakpaul1 year ago

Can we maybe expose these options in the preprocessing methods in a nice way and document them? I imagine people would want to go for the best-practice options for their use cases. So, what I understand is that depending on the aspect ratio, the interpolation behavior might change (sometimes we crop/pad, and sometimes we don't).

Or is my understanding entirely wrong?

fabiorigano
fabiorigano1 year ago

the best practice is to use masks that have the same aspect ratio as the output image; I may add a UserWarning when they don't match (a sketch follows below).
Since the attention processor has no information about the aspect ratio of the output, I used padding and cropping to match the shape of the downsampled mask to the length of the target sequence. The other solution is to simply raise an exception, forcing users to change the mask or the shape of the output.

fabiorigano Apply suggestions from code review
2af04268
fabiorigano Fix names
8f6247d3
fabiorigano Fix style
534b9d91
fabiorigano Update src/diffusers/image_processor.py
cb929ff7
fabiorigano Fix init + docstring
4763c82c
fabiorigano Merge branch 'main' of https://github.com/huggingface/diffusers into …
21efcd1c
fabiorigano Remove unnecessary parameters
4115a862
fabiorigano
fabiorigano1 year ago

also I think if we merge this #6915 (comment) we won't need to add the additional ip_adapter_mask argument to the default attention processors

@yiyixuxu thanks for merging your PR, I just removed the unused arguments :)

fabiorigano Update test
b1b99008
yiyixuxu added the Good Example PR label
fabiorigano
fabiorigano1 year ago

@yiyixuxu thanks for the Good Example PR label!

what can I do to fix the Fast tests checks? I see many OSErrors in the logs, so maybe they should be run again?

sayakpaul
sayakpaul1 year agoπŸ‘ 1

That was because the HF Hub was down. I just rebased your PR with the latest main. Let's see :)

sayakpaul Merge branch 'main' into ipadaptermasks
bfe55a7e
fabiorigano
fabiorigano1 year ago❀ 1

I want to leave a comment on mask preprocessing for future documentation (maybe Sayak was asking here #6847 (comment))
We have several options:

  1. masks and output image have the same aspect ratio: preprocessing can be done with IPAdapterMaskProcessor.preprocess as in this example #6847 (comment), without further changes

  2. masks and output image don't have the same aspect ratio:

    a. (recommended) preprocessing can be done with IPAdapterMaskProcessor.preprocess, but the height and width of the output image must be passed as arguments, like this: processor.preprocess([mask1, mask2], height=output_height, width=output_width). Masks will be stretched to fit the target shape

    b. if the aspect ratios are not very different, preprocessing can be done as in 1. Masks will preserve their original aspect ratio during downsampling, but some extra padding will be added if the downsampling size doesn't match the number of queries in the attention. When the aspect ratios of the masks and the output image are very different, this option is not recommended.

@asomoza for completeness I tested your example in #6847 (comment). I leave here the change to the code and the resulting image:

# masks both have shape (1152, 896) (W, H)
output_height = 1024
output_width = 1024
processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask1, mask2], height=output_height, width=output_width)
# masks now have shape [2, 1, 1024, 1024] (num_images, C, H, W)

p1_0

thanks everyone who contributed here!

sayakpaul
sayakpaul1 year ago

I think we should go with the simplest reasonable alternative as the default setting in our code and document the rest of the gotchas very clearly so that users can make use of all the options. That goes well with our philosophy of being "simple over easy" as well.

What do y'all think?

sayakpaul Merge branch 'main' into ipadaptermasks
850c9091
fabiorigano
fabiorigano1 year ago

that sounds good to me

should we wait for PR #6897 to be merged?

yiyixuxu
yiyixuxu approved these changes on 2024-02-15
yiyixuxu1 year ago❀ 1

looks great to me!

sayakpaul Merge branch 'main' into ipadaptermasks
9197ca1e
yiyixuxu
yiyixuxu1 year ago

@sayakpaul feel free to merge this if you're happy about it!

sayakpaul
sayakpaul1 year ago

WDYT about adding a section on attention masking to the https://huggingface.co/docs/diffusers/main/en/using-diffusers/ip_adapter doc?

sayakpaul
sayakpaul1 year ago

Once that's done and reviewed, let's ship this πŸš€

DN6
DN6 commented on 2024-02-16
DN6 1 year agoπŸ‘ 1

Looks good to me. Just one small request related to testing.

fabiorigano Add test for one mask + bugfix
ec923bd2
fabiorigano Add docs
4bc2621a
fabiorigano
fabiorigano1 year ago

thanks @DN6, I added one more test

@yiyixuxu could you upload this #6847 (comment) output image to your HF testing-images repository as "ip_adapter_masking_output.png" please? thank you

I can also upload it to the documentation-images repository if that is faster

sayakpaul
sayakpaul1 year ago

thanks @DN6, I added one more test

@yiyixuxu could you upload this #6847 (comment) output image to your HF testing-images repository as "ip_adapter_masking_output.png" please? thank you

I can also upload it to the documentation-images repository if that is faster

Could you let me know which images you want uploaded to the Hub? I can do that quickly :)

fabiorigano
fabiorigano1 year ago (edited 1 year ago)

@sayakpaul this one, thank you very much! It is the one obtained with seed = 0 (see docs)

ip_adapter_masking_output

fabiorigano Docs: update link
3315e81a
fabiorigano Update tensor conversion
c407388e
dhealy05
dhealy05 1 year ago

hey folks, tried out the branch, it works for me except when I call pipe.unload_ip_adapter() prior to loading the weights.

if I load them initially -- success

if I load other weights, unload, and reload, I get: RuntimeError: mat1 and mat2 shapes cannot be multiplied (514x1664 and 1280x1280)

not sure if this is in scope here, but just thought I would mention it before it's merged!

sayakpaul
sayakpaul1 year ago

We are going to merge it. I welcome you to open a new issue with a fully reproducible code snippet afterward.

@yiyixuxu feel free to merge if this looks like a go to you.

sayakpaul Merge branch 'main' into ipadaptermasks
70376441
fabiorigano
fabiorigano1 year ago

@dhealy05 this happens because you are using an SDXL pipeline and not reloading the correct image encoder

When using IP-Adapters for SDXL, you must first load the CLIPVisionModelWithProjection image encoder from the "models/image_encoder" folder of "h94/IP-Adapter".

Calling pipeline.unload_ip_adapter() removes both IP-Adapter weights and image encoder from the pipeline.

This leads to the issue: by default, if you don't load an image encoder into the pipeline, it is searched for in the IP-Adapter folder. For the SDXL IP-Adapters, this folder is "sdxl_models/image_encoder" and not "models/image_encoder".

To solve the problem, you need to reload the image encoder as follows:

import torch
from diffusers import AutoPipelineForText2Image
from transformers import CLIPVisionModelWithProjection

# define the image encoder
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
)

# define your pipeline
pipeline = AutoPipelineForText2Image.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    image_encoder=image_encoder
)
pipeline.to("cuda")

# load your IP-Adapters for SDXL (first time)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"]*2)

# do your inference

# unload IP-Adapters
pipeline.unload_ip_adapter()

# **reload the image encoder in the pipeline (very important)**
pipeline.image_encoder = image_encoder

# load your IP-Adapters for SDXL  (second time)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"]*2)

# do your inference
dhealy05
dhealy05 1 year ago

@fabiorigano that was it, thank you !!

yiyixuxu merged eba7e7a6 into main 1 year ago
sayakpaul
sayakpaul1 year ago❀ 1

Excellent work, @fabiorigano. Also, hat-tip to @asomoza for all the helpful suggestions and testing!

cubiq
cubiq1 year ago

I had a quick look at the code, sorry if I'm a bit late

                mask_h = int(torch.sqrt(torch.tensor(seq_len / ratio)))
                mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0)
                mask_w = seq_len // mask_h

not sure why torch.sqrt is used instead of math.sqrt. Feels like a waste. Also, rounding mask_h in the first line might, I believe, introduce rounding errors.

                mask_downsample = F.interpolate(
                    torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic"
                ).squeeze(0)

I would switch to "bilinear", which is faster. I don't think the mask needs bicubic anyway.
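i.e. the same call with the mode swapped:

                mask_downsample = F.interpolate(
                    torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bilinear"
                ).squeeze(0)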

                if mask_downsample.shape[0] < batch_size:
                    mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
                if mask_downsample.shape[0] > batch_size:
                    mask_downsample = mask_downsample[:batch_size, :, :]

use if...elif

from mask_downsample.repeat(batch_size, 1, 1) I assume you allow only one mask? If that is the case, you should trim the tensor before downsampling, otherwise you are wasting resources. And if you allow only one mask, it's also unlikely that the second statement is ever true.

In ComfyUI I allow sending multiple masks that are applied one per latent in the batch, but there's no such logic here.
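A hedged sketch of that per-latent logic (not in this PR; the helper name and indexing are illustrative):

import torch

def assign_masks_per_latent(masks: torch.Tensor, batch_size: int) -> torch.Tensor:
    # masks: [num_masks, H, W]; give latent i mask i % num_masks
    # instead of repeating a single mask across the whole batch
    idx = torch.arange(batch_size) % masks.shape[0]
    return masks[idx]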

                if mask_h * mask_w < seq_len:
                    mask_downsample = F.pad(mask_downsample, (0, seq_len-mask_downsample.shape[1]), value=0.0)
                if mask_h * mask_w > seq_len:
                    mask_downsample = mask_downsample[:, :seq_len]

use if...elif

is elif somewhat discouraged in diffusers?

fabiorigano
fabiorigano1 year ago

hi @cubiq, I think you are looking at an old implementation; here is the merged version:

class IPAdapterMaskProcessor(VaeImageProcessor):

cubiq
cubiq1 year ago

@fabiorigano
oh okay, sorry πŸ˜„ some of the remarks still stand

mask_h = int(math.sqrt(num_queries / ratio))
mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
mask_w = num_queries // mask_h

don't int() the first mask_h (I think it might introduce more rounding errors), or don't int it again in the second line.

I would use bilinear instead of bicubic.

Use if/elif.

sorry it took me so long to reply
