Add inpainting training and sampling support for SD1.5 and SDXL (#2309)
* Add inpainting support based on Fannovel16's original push, which does not appear to have been merged
* Add inpainting training and sampling support for SD1.5 and SDXL
- 9-channel UNet input (noisy_latents + mask + masked_image_latents)
wired through all training scripts (train_db, train_network,
fine_tune, train_textual_inversion, sdxl_train)
- Auto-detect in_channels from checkpoint conv_in weight shape in
model_util.py and sdxl_model_util.py; UNet constructors accept
explicit in_channels parameter
- Inpainting inference added to lpw_stable_diffusion.py and
sdxl_lpw_stable_diffusion.py: encodes the masked image before the
denoising loop and assembles the 9-ch input at each step; latent init
uses vae.config.latent_channels (4), not unet.in_channels (9)
- --train_inpainting CLI flag; assertion that it is incompatible with
cache_latents; --img prompt directive for the sampling source image; a
missing image skips the sample gracefully; resolution rounded to
multiples of 64
- library/mask_generator.py: procedural cloud (fBm), polygon, shape,
and combined random mask generation using numpy/cv2/PIL
- tests/: synthetic data generator, mask visualizer, HuggingFace
training data downloader, SD1.5 and SDXL smoke test scripts and TOML
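
The 9-channel input described in the list above can be sketched as follows.
This is a minimal illustration, not the code from the training scripts: NumPy
arrays stand in for torch tensors, and the batch size and latent resolution
are assumed values chosen for the example. Only the channel order
(noisy_latents, mask, masked_image_latents) comes from this commit.

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, W = 2, 64, 64  # assumed batch size and latent resolution

# Stand-ins for the three tensors named in this commit.
noisy_latents = rng.standard_normal((B, 4, H, W)).astype(np.float32)
mask = (rng.random((B, 1, H, W)) > 0.5).astype(np.float32)
masked_image_latents = rng.standard_normal((B, 4, H, W)).astype(np.float32)

# Concatenate along the channel axis: 4 + 1 + 4 = 9 channels.
unet_input = np.concatenate([noisy_latents, mask, masked_image_latents], axis=1)
print(unet_input.shape)
```

With torch tensors the equivalent call would be torch.cat([...], dim=1).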
* Have tests/visualize_masks.py use downloaded training data
* Add documentation for the inpainting feature
* Add inpainting_minimal_inference.py for SD1.5/SDXL inpainting testing.
Add wobbly ellipse mask for better sampling
* Support standard (4-ch) checkpoints for inpainting training; add SD1.5 smoke test
Add expand_unet_to_inpainting() to model_util.py, which expands conv_in from
4 to 9 channels when --train_inpainting is set on a standard (non-inpainting)
checkpoint. Original weights are preserved in channels 0-3; channels 4-8 are
zero-initialised. Called automatically in both train_util.load_target_model
(SD1.5) and sdxl_train_util.load_target_model (SDXL) when in_channels==4.
Also fix a FutureWarning from diffusers by setting steps_offset=1 in
get_my_scheduler(), matching the expected SD1.5 scheduler configuration.
Add tests/sd15_inpainting_test.toml and tests/run_sd15_inpainting_test.sh,
a smoke test equivalent to the SDXL one using train_db.py at 512x512/fp16.
Accepts both standard and inpainting SD1.5 checkpoints.
Update docs/inpainting_training.md to reflect that standard checkpoints now
work automatically with --train_inpainting.
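
The conv_in expansion described above can be sketched as below. This is an
illustration under stated assumptions, not the body of
expand_unet_to_inpainting(): the helper name expand_conv_in_weight is
hypothetical, NumPy stands in for torch, and the [320, 4, 3, 3] weight shape
is used only as an example. It shows why zero-initialising channels 4-8
leaves the layer's response unchanged at the start of training.

```python
import numpy as np

def expand_conv_in_weight(weight, new_in_channels=9):
    """Hypothetical helper: zero-pad a conv weight [out, in, kH, kW] on the input-channel axis."""
    out_ch, in_ch, kh, kw = weight.shape
    if in_ch >= new_in_channels:
        return weight  # already wide enough, nothing to do
    expanded = np.zeros((out_ch, new_in_channels, kh, kw), dtype=weight.dtype)
    expanded[:, :in_ch] = weight  # original weights stay in channels 0..in_ch-1
    return expanded

rng = np.random.default_rng(0)
old_weight = rng.standard_normal((320, 4, 3, 3)).astype(np.float32)  # example shape
new_weight = expand_conv_in_weight(old_weight)

# Response at one 3x3 patch: the extra channels contribute nothing
# because their weights are zero.
patch9 = rng.standard_normal((9, 3, 3)).astype(np.float32)
out_old = np.einsum("oikl,ikl->o", old_weight, patch9[:4])
out_new = np.einsum("oikl,ikl->o", new_weight, patch9)
print(new_weight.shape, np.allclose(out_old, out_new, atol=1e-4))
```

Because the new channels start at zero, the expanded model initially ignores
the mask and masked-image latents and learns to use them during training.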
* Fix mask/masked_image batch shapes and simplify mask interpolation
In prepare_mask_and_masked_image (train_util.py), image was shaped
[1,C,H,W] and mask [1,1,H,W] due to image[None] and mask[None,None].
After torch.stack these became [B,1,C,H,W] and [B,1,1,H,W]. Fixed to
image.transpose(2,0,1) → [C,H,W] and mask[None] → [1,H,W], so stacked
batches are the correct [B,C,H,W] and [B,1,H,W].
Removed the .reshape(batch["images"].shape) workaround from fine_tune.py,
train_db.py, and train_network.py that was compensating for the extra dim.
Replaced the per-item interpolate loop + stack + reshape in fine_tune.py,
train_db.py, train_network.py, and sdxl_train.py with a single
F.interpolate call on the full [B,1,H,W] batch tensor.
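
The shape fix above can be reproduced with a small sketch. NumPy stands in
for torch here, and the image size and batch size are assumed values. The
strided slice at the end is only a stand-in for the batched resize: the
actual code calls F.interpolate, its mode is not stated in this message, and
the factor of 8 is assumed from the usual image-to-latent downscale.

```python
import numpy as np

H, W, C, B = 64, 64, 3, 2  # assumed sizes for the example
image = np.zeros((H, W, C), dtype=np.float32)  # HWC image as loaded
mask = np.zeros((H, W), dtype=np.float32)

# Before the fix: each item carried an extra leading axis.
old_image = image.transpose(2, 0, 1)[None]       # [1,C,H,W]
old_mask = mask[None, None]                      # [1,1,H,W]
old_images = np.stack([old_image] * B)           # [B,1,C,H,W]
old_masks = np.stack([old_mask] * B)             # [B,1,1,H,W]

# After the fix: per-item shapes are [C,H,W] and [1,H,W].
new_image = image.transpose(2, 0, 1)
new_mask = mask[None]
images = np.stack([new_image] * B)               # [B,C,H,W]
masks = np.stack([new_mask] * B)                 # [B,1,H,W]

# One resize over the whole batch instead of a per-item loop.
latent_masks = masks[:, :, ::8, ::8]             # [B,1,H/8,W/8]
print(old_images.shape, images.shape, latent_masks.shape)
```

With the batch already shaped [B,1,H,W], the resize needs no loop, stack, or
reshape, which is what made the workaround removable.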