diffusers
Stable-Diffusion-Inpainting: Training Pipeline V1.5, V2
#6922
Open


cryptexis wants to merge 20 commits into huggingface:main from cryptexis:sd_15_inpainting
cryptexis 1 year ago (edited)

What does this PR do?

This functionality allows training/fine-tuning of 9-channel inpainting models, such as runwayml/stable-diffusion-inpainting and stabilityai/stable-diffusion-2-inpainting.

The motivation is that many inpainting models shared with the community, e.g. on https://civitai.com/, have UNets with only 4 input channels. Such 4-channel models may lack capacity, and ultimately quality, on inpainting tasks. To help the community develop fully fledged inpainting models, I have modified the text_to_image training pipeline to do inpainting.

Additions:

  • Added a random masking strategy (squares) during training and a center-crop mask during validation (see the sketch below)
  • Take the first 3 images of the pokemon dataset as the validation set
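For illustration, a minimal sketch of what such a random square mask could look like (the function name and size bounds are my assumptions, not necessarily the PR's exact code):

```python
import torch


def make_random_square_mask(height, width, min_size=32, max_size=128, generator=None):
    """Return a (1, H, W) float mask where a random square region is 1 (the area to inpaint)."""
    mask = torch.zeros(1, height, width)
    size = int(torch.randint(min_size, max_size + 1, (1,), generator=generator))
    top = int(torch.randint(0, height - size + 1, (1,), generator=generator))
    left = int(torch.randint(0, width - size + 1, (1,), generator=generator))
    mask[:, top : top + size, left : left + size] = 1.0
    return mask
```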

Before submitting

Who can review?

@sayakpaul and @patrickvonplaten

Examples: Out-of-Training-Distribution Scenery

Prompt: a drawing of a green pokemon with red eyes

Pre-trained

pretrained_0

Fine-tuned

finetuned_0

Prompt: a green and yellow toy with a red nose

Pre-trained

pretrained_1

Fine-tuned

finetuned_1

Prompt: a red and white ball with an angry look on its face

Pre-trained

pretrained_2

Fine-tuned

finetuned_2

wip: training script
2116de29
wip: update documentation
882cb67b
fix: README
89854ee9
fix: README title
969605f1
sayakpaul requested a review from patil-suraj 1 year ago
cryptexis 1 year ago

Hi @patil-suraj @sayakpaul, I was wondering whether this is something you'd be interested in looking into? Feedback is appreciated.

yiyixuxu 1 year ago

Cool!
Gentle ping @patil-suraj

drhead 1 year ago

I've experimented with finetuning proper inpainting models before. I strongly urge you to read the LAMA paper (https://arxiv.org/pdf/2109.07161.pdf) and implement their masking strategy (which is what is used by the stable-diffusion-inpainting checkpoint). I used a very simple masking strategy like what you had for a long time and never got satisfactory results with my model until switching to the LAMA masking strategy. Training on simple white square masks will severely degrade the performance of the pretrained SD inpainting model.
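For reference, a minimal sketch of that kind of stroke-based irregular masking, adapted from the LaMa repo's approach (parameter defaults here are assumptions):

```python
import math

import cv2
import numpy as np


def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=1, max_times=10):
    """Draw random thick polyline strokes, following the LaMa masking recipe."""
    height, width = shape
    mask = np.zeros((height, width), dtype=np.float32)
    times = np.random.randint(min_times, max_times + 1)
    for _ in range(times):
        start_x = np.random.randint(width)
        start_y = np.random.randint(height)
        for _ in range(1 + np.random.randint(5)):
            angle = 0.01 + np.random.randint(max_angle)
            if np.random.rand() < 0.5:
                angle = 2 * math.pi - angle
            length = 10 + np.random.randint(max_len)
            brush_w = 5 + np.random.randint(max_width)
            end_x = int(np.clip(start_x + length * math.sin(angle), 0, width - 1))
            end_y = int(np.clip(start_y + length * math.cos(angle), 0, height - 1))
            # Rasterize the stroke segment into the mask with the chosen brush width.
            cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
            start_x, start_y = end_x, end_y
    return mask[None, ...]  # (1, H, W)
```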

sayakpaul commented on 2024-02-19
examples/inpainting/train_inpainting.py
}


def save_model_card(
sayakpaul 1 year ago (edited)

Could you follow the structure of how the model cards are being created, from here?

cryptexis 1 year ago

sure, thank you

cryptexis 1 year ago

done

sayakpaul commented on 2024-02-19
examples/inpainting/train_inpainting.py
    repo_folder=None,
):
    img_str = ""
    if len(images) > 0:
sayakpaul 1 year ago

`if images is not None` could be better here.

cryptexis 1 year ago

done

sayakpaul commented on 2024-02-19
examples/inpainting/train_inpainting.py
    if args.push_to_hub:
        repo_id = create_repo(
            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
sayakpaul 1 year ago

Let's make sure to follow:

if args.report_to == "wandb" and args.hub_token is not None:

Otherwise, the hub_token will be exposed on the wandb run page.
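For reference, the guard used across the diffusers example scripts looks roughly like this, placed right after `parse_args()` (the exact error message may differ):

```python
args = parse_args()

# Refuse the combination so the Hub token is never written into the wandb run config.
if args.report_to == "wandb" and args.hub_token is not None:
    raise ValueError(
        "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
        " Please use `huggingface-cli login` to authenticate with the Hub."
    )
```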

cryptexis 1 year ago

done

sayakpaul 1 year ago

Seems like this comment wasn't addressed?

sayakpaul commented on 2024-02-19
examples/inpainting/train_inpainting.py
    return [deepspeed_plugin.zero3_init_context_manager(enable=False)]


# Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
# For this to work properly all models must be run through `accelerate.prepare`. But accelerate
# will try to assign the same optimizer with the same weights to all models during
# `deepspeed.initialize`, which of course doesn't work.
#
# For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
# frozen models from being partitioned during `zero.Init` which gets called during
# `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
# across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
sayakpaul 1 year ago

Do we need this? We only need it when fine-tuning multiple models jointly. I don't think that is the case here, no?

cryptexis 1 year ago

removed

HuggingFaceDocBuilderDev 1 year ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

sayakpaul commented on 2024-02-19
examples/inpainting/README.md
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) 768x768 model.___**
<!-- accelerate_snippet_start -->
```bash
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
```
sayakpaul 1 year ago

Hmm, what if one wants to start with the SD v1.5 checkpoint? In that case, we will have to add the extra channels to the unet from the script, no? And show how they should be initialized?

I think that might be a good addition!

We do this in the InstructPix2Pix training script as well:
https://github.com/huggingface/diffusers/blob/instruct-pix2pix/emu/examples/instruct_pix2pix/train_instruct_pix2pix.py

cryptexis 1 year ago

that's a very good point, thank you - I will get to it

cryptexis 1 year ago

added

sayakpaul commented on 2024-02-19

Left some initial comments. Looking quite nice.

I do think having an option to enable LAMA-like masking might be a very good reference point, as our training scripts are quite widely referenced.

And I apologize for the delay.

sayakpaul Merge branch 'main' into sd_15_inpainting
18191cc5
cryptexis 1 year ago

I've experimented with finetuning proper inpainting models before. I strongly urge you to read the LAMA paper (https://arxiv.org/pdf/2109.07161.pdf) and implement their masking strategy (which is what is used by the stable-diffusion-inpainting checkpoint). I used a very simple masking strategy like what you had for a long time and never got satisfactory results with my model until switching to the LAMA masking strategy. Training on simple white square masks will severely degrade the performance of the pretrained SD inpainting model.

@sayakpaul

I thought the simplest implementation would do, and the user could then decide which masking strategy to use. Sure, I will add that if it's a deal breaker.

cryptexis 1 year ago

@sayakpaul I have adapted the masking strategy from the LAMA paper on my local branch. One question: is it in line with the guidelines to keep the masking properties in a separate config file, like here:
https://github.com/advimman/lama/blob/main/configs/training/data/abl-04-256-mh-dist-celeba.yaml#L10 ?

I feel it is a bit extensive and confusing to expose all of those property values as CLI arguments; it might clutter the interface and obscure which arguments are model-specific and which are data-specific.
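As one possible middle ground (purely a sketch; the flag name and keys below are hypothetical), the masking knobs could live in a single JSON file behind one CLI argument:

```python
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument(
    "--mask_config",
    type=str,
    default=None,
    help="Path to a JSON file with masking-strategy parameters (hypothetical flag).",
)
args = parser.parse_args()

# Defaults mirroring LAMA-style generator parameters; overridden by the file if given.
mask_kwargs = {"max_angle": 4, "max_len": 60, "max_width": 20, "min_times": 1, "max_times": 10}
if args.mask_config is not None:
    with open(args.mask_config) as f:
        mask_kwargs.update(json.load(f))
```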

wip: integrating LAMA masking
69d4494f
wip: merged commits
272dc875
sayakpaul 1 year ago

I feel it is a bit extensive and confusing to expose all of those property values as CLI arguments; it might clutter the interface and obscure which arguments are model-specific and which are data-specific.

You are absolutely correct. What we can do is include a note about the masking strategy in the README and link to your implementation. Does that sound good?

yiyixuxu added the training label
wip: final fixes
07c8fd1a
wip: updating README
c1c3a0e3
sayakpaul Merge branch 'main' into sd_15_inpainting
cd619ffe
sayakpaul approved these changes on 2024-03-02
sayakpaul 1 year ago

Looking really nice now. I will let @patil-suraj review this too.

examples/inpainting/train_inpainting.py
    prompt = batch["prompts"][0]

    with torch.autocast("cuda"):
        #### UPDATE PIPELINE HERE
sayakpaul 1 year ago

Does this comment need to be removed?

cryptexis 1 year ago

Which one?

sayakpaul 1 year ago

"#### UPDATE PIPELINE HERE"

examples/inpainting/train_inpainting.py
def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
    )
sayakpaul 1 year ago

(nit): Seems like a data-related argument, which we could place with the other data-related arguments, no? Also, an expanded help string would be nice; I don't know what the perturbation corresponds to.

cryptexis 1 year ago (edited)

I took it from the text_to_image training script as-is. If it's not super important, can we keep it the way it is?
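For context, in the text_to_image script this argument perturbs the noise fed to the forward diffusion process, roughly like this:

```python
import torch

# Inside the training loop: `noise` is the Gaussian noise added to the latents.
noise = torch.randn(1, 4, 64, 64)  # placeholder latent-shaped noise
input_perturbation = 0.1

if input_perturbation:
    # The noised latents are built from a slightly perturbed noise sample,
    # while the original `noise` remains the regression target.
    new_noise = noise + input_perturbation * torch.randn_like(noise)
```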

examples/inpainting/train_inpainting.py
    unet.register_to_config(in_channels=in_channels)

    with torch.no_grad():
        new_conv_in = torch.nn.Conv2d(
            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
        )
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
        unet.conv_in = new_conv_in
sayakpaul 1 year ago

Is this how it's usually initialized for inpainting? I know this is the case for InstructPix2Pix.

cryptexis 1 year ago

Yup, it worked. I also tried non-zero initialization and got a lot of burned pixels after some iterations. Will post results shortly.

examples/inpainting/train_inpainting.py
    init_image = image_transform(batch["pixel_values"][0])
    prompt = batch["prompts"][0]

    with torch.autocast("cuda"):
sayakpaul 1 year ago

Let's make use of the log_validation() function here and log the results to wandb as well. You can refer to https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py for implementing this. Let me know if you need more clarification.

cryptexis 1 year ago

done

sayakpaul 1 year ago

I think we also need to add a test case here.

cryptexis 1 year ago

Screenshot 2024-03-02 at 13 20 53

@sayakpaul I think it's a GitHub glitch :) to the extent that I cannot reply to you there.

https://github.com/cryptexis/diffusers/blob/sd_15_inpainting/examples/inpainting/train_inpainting.py#L771 - in my repo I do not have anything similar under those lines. The piece of code you're referring to is here.

cryptexis 1 year ago (edited)

I think we also need to add a test case here.

I see https://huggingface.co/hf-internal-testing used a lot in the tests. Are mere mortals able to add unit tests?
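For reference, those tiny checkpoints are public, so anyone can use them. The examples tests typically launch the script against them via accelerate, roughly like this (the class name, helper imports, and dataset ID below are assumptions modeled on the existing examples tests):

```python
import tempfile

from test_examples_utils import ExamplesTestsAccelerate, run_command  # assumed helper module


class InpaintingTests(ExamplesTestsAccelerate):
    def test_train_inpainting(self):
        # Run a couple of training steps against tiny public test assets.
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/inpainting/train_inpainting.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --train_batch_size 1
                --max_train_steps 2
                --output_dir {tmpdir}
                """.split()
            # run_command prepends the `accelerate launch` invocation.
            run_command(self._launch_args + test_args)
```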

cryptexis 1 year ago

Examples: Training with Random Masking

Inference with Square Mask (as before)

Prompt: a drawing of a green pokemon with red eyes

pre-trained stable-diffusion-inpainting

pretrained_inpainting_0

fine-tuned stable-diffusion-inpainting

finetuned_inpainting_0

pre-trained stable-diffusion-v1-5

pretrained_text2img_0

fine-tuned stable-diffusion-v1-5 (no inpainting)

finetuned_text2img_0

fine-tuned stable-diffusion-v1-5 (inpainting)

finetuned_text2img_to_inpainting_0

Inference with Random Mask

pre-trained stable-diffusion-inpainting

pretrained_inpainting_2

fine-tuned stable-diffusion-inpainting

finetuned_inpainting_2

pre-trained stable-diffusion-v1-5

pretrained_text2img_2

fine-tuned stable-diffusion-v1-5 (no inpainting)

finetuned_text2img_2

fine-tuned stable-diffusion-v1-5 (inpainting)

finetuned_text2img_to_inpainting_2

wip: last inference step with log_validation
94d877cb
Sanster 1 year ago

@cryptexis Thank you for providing the scripts and test cases. I want to train an inpainting model specifically for object removal based on the sd1.5-inpainting model. The goal is to remove objects without using a prompt, just like the ldm-inpainting model. Although the sd1.5-inpainting model can achieve decent results with appropriate prompts (#973), it is often not easy to find them, and it's easy to end up adding extra objects.

Here's my plan right now:

  • I will not modify the StableDiffusionInpaintPipeline code; all prompts used during training are blank strings
  • The mask generation strategy will use methods from CM-GAN-Inpainting, which is better than LaMa for inpainting: first use a segmentation model to obtain object masks, then ensure randomly generated masks never completely cover an object (for example, using 50% IoU as a threshold; see the sketch below)

The generated mask looks like this:

image

I have not trained diffusion models before, any suggestions would be very helpful to me, thank you.
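A minimal sketch of that object-aware check (the exact overlap handling beyond the 50% IoU mentioned above is an assumption):

```python
import numpy as np


def mask_is_valid(random_mask, object_mask, iou_threshold=0.5):
    """Reject a random mask whose IoU with an object mask exceeds the threshold,
    so objects are never (almost) fully covered during training."""
    inter = np.logical_and(random_mask > 0, object_mask > 0).sum()
    union = np.logical_or(random_mask > 0, object_mask > 0).sum()
    if union == 0:
        return True  # nothing masked, nothing covered
    return inter / union <= iou_threshold
```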

sayakpaul commented on 2024-03-03
src/diffusers/loaders/single_file.py
    torch_dtype=None,
    **kwargs,
):
sayakpaul 1 year ago

Unrelated change?

cryptexis 1 year ago

Somehow it came in after ruff formatting... hmm, I did not intend to commit it.

sayakpaul commented on 2024-03-03
examples/inpainting/train_inpainting.py
    # Run a final round of inference.
    if args.validation_size > 0:
        logger.info("Running inference for collecting generated images...")
        images, prompts, masks = log_validation(
            val_dataloader,
            vae,
            text_encoder,
            tokenizer,
            unet,
            args,
            accelerator,
            weight_dtype,
            global_step,
sayakpaul 1 year ago

When calling log_validation here, we should separate the logging key. For example, to log the validation results, we use the "validation" key in the log_validation function. However, these results come from the final step, so there should be a distinction here.

This is how we do it:

phase_name = "test" if is_final_validation else "validation"

cryptexis 1 year ago

you are totally right! thanks for pointing that out
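For reference, a sketch of what the separated logging could look like when reporting to wandb (the helper name and signature are assumptions):

```python
import wandb


def log_images(accelerator, images, prompts, step, is_final_validation=False):
    # Use a distinct key so final-round results don't mix with intermediate validation panels.
    phase_name = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            tracker.log(
                {phase_name: [wandb.Image(img, caption=p) for img, p in zip(images, prompts)]},
                step=step,
            )
```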

sayakpaul commented on 2024-03-03

Looking good. I think the only thing that is pending now is the testing suite.

sayakpaul Merge branch 'main' into sd_15_inpainting
5532dea7
cryptexis 1 year ago

Looking good. I think the only thing that is pending now is the testing suite.

@sayakpaul I worked on the tests yesterday and hit a wall. Then I tried to run the tests for text_to_image and hit the same wall.

attaching the screenshot:
Screenshot 2024-03-03 at 06 56 07

Was wondering if it is a systematic issue across all tests...

sayakpaul 1 year ago

@sayakpaul I worked on the tests yesterday and hit a wall. Then I tried to run the tests for text_to_image and hit the same wall.

Had it been the case, it would have been caught in the CI. The CI doesn't indicate so. Feel free to push the tests and then we can work towards fixing them. WDYT?

BTW, for fixing the code quality issues, we need to run make style && make quality from the root of diffusers.

wip: fixing log_validation, tests
5dd28bd4
Merge branch 'sd_15_inpainting' of github.com:cryptexis/diffusers int…
235655fb
cryptexis 1 year ago

@sayakpaul I worked on the tests yesterday and hit a wall. Then I tried to run the tests for text_to_image and hit the same wall.

Had it been the case, it would have been caught in the CI. The CI doesn't indicate so. Feel free to push the tests and then we can work towards fixing them. WDYT?

BTW, for fixing the code quality issues, we need to run make style && make quality from the root of diffusers.

Done @sayakpaul, I think everything is addressed and the tests are pushed. Thanks a lot for the patience, support, and all the help!

sayakpaul run quality
5179539e
crapthings 1 year ago

How does one prepare the dataset? (See the sketch below.)

image
mask
prompt
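Since this script generates masks on the fly during training, only image/caption pairs are needed; a mask column isn't required. One standard option is the HF datasets imagefolder layout (a sketch; file names are examples):

```python
# Layout on disk:
#   train/metadata.jsonl  -> one JSON object per line: {"file_name": "0001.png", "text": "a red ball"}
#   train/0001.png
from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="train")
sample = dataset["train"][0]
print(sample["image"], sample["text"])
```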

sayakpaul Merge branch 'main' into sd_15_inpainting
2d075742
sayakpaul 1 year ago

@cryptexis let's fix the example tests that are failing now.

Srinivasa-N707 1 year ago

Can anyone share a script for SDXL inpainting fine-tuning?

patil-suraj approved these changes on 2024-03-11
patil-suraj 1 year ago

Thanks a lot for working on this, the script looks great! Just left some nits.

For the runwayml inpainting model, they mask the whole image 25% of the time during training. Have you experimented with that?
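For reference, that trick is a one-liner on top of any masking strategy (a sketch):

```python
import torch


def maybe_mask_everything(mask, p=0.25, generator=None):
    """With probability p, mask the whole image, as reportedly done for the runwayml checkpoint."""
    if torch.rand(1, generator=generator).item() < p:
        return torch.ones_like(mask)
    return mask
```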

examples/inpainting/README.md
# Stable Diffusion Inpainting fine-tuning

The `train_inpainting.py` script shows how to fine-tune stable diffusion model on your own dataset.

patil-suraj 1 year ago
Suggested change
The `train_inpainting.py` script shows how to fine-tune stable diffusion model on your own dataset.
The `train_inpainting.py` script shows how to train/fine-tune stable diffusion model for inpainting on your own dataset.
examples/inpainting/requirements.txt
ftfy
tensorboard
Jinja2
peft==0.7.0
patil-suraj 1 year ago

Do we need peft for this example?

examples/inpainting/train_inpainting.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
patil-suraj 1 year ago
Suggested change
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
examples/inpainting/train_inpainting.py
    return torch.from_numpy(mask[None, ...]).squeeze(0).byte()


class RandomIrregularMaskGenerator:
    """
    Initializes the RandomIrregularMaskGenerator with the provided parameters.

    Parameters:
        max_angle (int): The maximum angle for the line segments, influencing the irregularity of the shapes.
        max_len (int): The maximum length for each line segment, affecting the size of the irregular shapes.
        max_width (int): The maximum width for each line segment, determining the thickness of the irregular shapes.
        min_times (int): The minimum number of irregular shapes to be generated on the mask.
        max_times (int): The maximum number of irregular shapes to be generated on the mask.
    """

    def __init__(self, max_angle, max_len, max_width, min_times, max_times):
        self.max_angle = max_angle
        self.max_len = max_len
        self.max_width = max_width
        self.min_times = min_times
        self.max_times = max_times

    def __call__(self, img_shape):
        """
        Generates a mask with random irregular shapes when called with an image.

        Parameters:
            img (tuple): Tuple of image dimensions, excluding channels.

        Returns:
            np.array: A mask array with the same height and width as the input image, containing random irregular shapes.
        """
        cur_max_len = int(max(1, self.max_len))
        cur_max_width = int(max(1, self.max_width))
        cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times))
        return make_random_irregular_mask(
            img_shape,
            max_angle=self.max_angle,
            max_len=cur_max_len,
            max_width=cur_max_width,
            min_times=self.min_times,
            max_times=cur_max_times,
        )


class RandomRectangleMaskGenerator:
    """
    A generator class for creating masks with random rectangular shapes on images.
    The rectangles are defined within specified constraints for margins, size, and the number of times they appear.

    Attributes:
        margin (int): The minimum distance between the rectangle edges and the image boundaries.
        bbox_min_size (int): The minimum size for the width and height of the rectangles.
        bbox_max_size (int): The maximum size for the width and height of the rectangles.
        min_times (int): The minimum number of rectangles to be generated on the mask.
142
143
144
class RandomRectangleMaskGenerator:
145
"""
146
A generator class for creating masks with random rectangular shapes on images.
147
The rectangles are defined within specified constraints for margins, size, and the number of times they appear.
148
149
Attributes:
150
margin (int): The minimum distance between the rectangle edges and the image boundaries.
151
bbox_min_size (int): The minimum size for the width and height of the rectangles.
152
bbox_max_size (int): The maximum size for the width and height of the rectangles.
153
min_times (int): The minimum number of rectangles to be generated on the mask.
patil-suraj 1 year ago

Very cool!

examples/inpainting/train_inpainting.py
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
    )

    # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
    # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
    # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
    # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
    # initialized to zero.

patil-suraj 1 year ago
Suggested change
# InstructPix2Pix uses an additional image for conditioning. To accommodate that,
# it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
# then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
# from the pre-trained checkpoints. For the extra channels added to the first layer, they are
# initialized to zero.
# For inpainting an additional image is used for conditioning. To accommodate that,
# it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
# then fine-tuned on the custom inpainting dataset. This modified UNet is initialized
# from the pre-trained checkpoints. For the extra channels added to the first layer, they are
# initialized to zero.
examples/inpainting/train_inpainting.py
    # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
    # initialized to zero.

    # when most likely a text2img pretrained model is used

patil-suraj 1 year ago
Suggested change (delete the line)
# when most likely a text2img pretrained model is used
github-actions 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions added the stale label
cs-mshah 1 year ago

When is this getting merged?

yiyixuxu removed the stale label
yiyixuxu Update examples/inpainting/README.md
8f33ed19
yiyixuxu Update examples/inpainting/train_inpainting.py
f2b04e32
yiyixuxu Update examples/inpainting/train_inpainting.py
7dc6bfb3
yiyixuxu Update examples/inpainting/train_inpainting.py
d11619a2
yiyixuxu 1 year ago (edited)

@cryptexis can you:

  1. address the final comment here: #6922 (comment)? If peft is not used, we can remove it; otherwise we are all good.
  2. make sure the tests pass.

Will merge once the tests pass!

zijinY 1 year ago

@Sanster Thanks for your plan. I also want to fine-tune a stable diffusion inpainting model for object removal. Have you tried this? How is the performance?

github-actions 233 days ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions added the stale label
fire2323 84 days ago

Hi @patil-suraj, much appreciated for the convenient script! Is there any code example and dataset example for running the script: https://github.com/huggingface/diffusers/blob/inpainting-script/examples/inpainting/train_inpainting_sdxl.py ?

github-actions removed the stale label
github-actions 58 days ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions added the stale label
