sd-scripts
5a822a40 - Add inpainting training and sampling support for SD1.5 and SDXL (#2309)

Commit
22 days ago
* Add inpainting support based off of original Fannovel16 push that didn't appear to get merged

* Add inpainting training and sampling support for SD1.5 and SDXL

  - 9-channel UNet input (noisy_latents + mask + masked_image_latents) wired through all training scripts (train_db, train_network, fine_tune, train_textual_inversion, sdxl_train)
  - Auto-detect in_channels from checkpoint conv_in weight shape in model_util.py and sdxl_model_util.py; UNet constructors accept explicit in_channels parameter
  - Inpainting inference added to lpw_stable_diffusion.py and sdxl_lpw_stable_diffusion.py: encodes masked image before denoising loop, prepends 9-ch input each step; latent init uses vae.config.latent_channels (4) not unet.in_channels (9)
  - --train_inpainting CLI flag; cache_latents incompatibility assertion; --img prompt directive for sampling source image; missing image gracefully skips sample; resolution rounded to multiples of 64
  - library/mask_generator.py: procedural cloud (fBm), polygon, shape, and combined random mask generation using numpy/cv2/PIL
  - tests/: synthetic data generator, mask visualizer, HuggingFace training data downloader, SD1.5 and SDXL smoke test scripts and TOML

* Have tests/visualize_masks.py use downloaded training data

* Add documentation for the inpainting feature

* Added inpainting_minimal_inference.py for inpainting SD1.5/SDXL testing. Added wobbly ellipse mask for better sampling

* Support standard (4-ch) checkpoints for inpainting training; add SD1.5 smoke test

  Add expand_unet_to_inpainting() to model_util.py, which expands conv_in from 4 to 9 channels when --train_inpainting is set on a standard (non-inpainting) checkpoint. Original weights are preserved in channels 0-3; channels 4-8 are zero-initialised. Called automatically in both train_util.load_target_model (SD1.5) and sdxl_train_util.load_target_model (SDXL) when in_channels==4.

  Also fix a FutureWarning from diffusers by setting steps_offset=1 in get_my_scheduler(), matching the expected SD1.5 scheduler configuration.

  Add tests/sd15_inpainting_test.toml and tests/run_sd15_inpainting_test.sh, a smoke test equivalent to the SDXL one using train_db.py at 512x512/fp16. Accepts both standard and inpainting SD1.5 checkpoints.

  Update docs/inpainting_training.md to reflect that standard checkpoints now work automatically with --train_inpainting.

* Fix mask/masked_image batch shapes and simplify mask interpolation

  In prepare_mask_and_masked_image (train_util.py), image was shaped [1,C,H,W] and mask [1,1,H,W] due to image[None] and mask[None,None]. After torch.stack these became [B,1,C,H,W] and [B,1,1,H,W]. Fixed to image.transpose(2,0,1) → [C,H,W] and mask[None] → [1,H,W], so stacked batches are the correct [B,C,H,W] and [B,1,H,W].

  Removed the .reshape(batch["images"].shape) workaround from fine_tune.py, train_db.py, and train_network.py that was compensating for the extra dim.

  Replaced the per-item interpolate loop + stack + reshape in fine_tune.py, train_db.py, train_network.py, and sdxl_train.py with a single F.interpolate call on the full [B,1,H,W] batch tensor.
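The conv_in expansion described for expand_unet_to_inpainting() can be sketched as follows. This is a minimal illustration, not the code from this commit: the helper name expand_conv_in, its signature, and the 320 output channels are assumptions made for the example. It shows only the idea stated in the message: copy the original weights into input channels 0-3 and zero-initialise channels 4-8.

```python
import torch
import torch.nn as nn


def expand_conv_in(conv: nn.Conv2d, new_in_channels: int = 9) -> nn.Conv2d:
    """Hypothetical helper: return a copy of `conv` with more input channels.

    Original weights land in the first `conv.in_channels` input channels;
    the added input channels start at zero.
    """
    old_in = conv.in_channels
    if new_in_channels < old_in:
        raise ValueError("new_in_channels must be >= the current in_channels")

    expanded = nn.Conv2d(
        new_in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    )
    expanded = expanded.to(device=conv.weight.device, dtype=conv.weight.dtype)

    with torch.no_grad():
        expanded.weight.zero_()                     # channels 4-8 stay zero
        expanded.weight[:, :old_in] = conv.weight   # channels 0-3 keep old weights
        if conv.bias is not None:
            expanded.bias.copy_(conv.bias)
    return expanded


# Illustrative stand-in for a 4-channel UNet conv_in layer (sizes are assumptions).
torch.manual_seed(0)
conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)
conv_in_9 = expand_conv_in(conv_in)

latents = torch.randn(2, 4, 8, 8)   # stand-in for noisy latents
extra = torch.randn(2, 5, 8, 8)     # stand-in for mask + masked-image latents
with torch.no_grad():
    out_4 = conv_in(latents)
    out_9 = conv_in_9(torch.cat([latents, extra], dim=1))
```

Because the new input channels carry zero weights, the expanded layer initially produces the same output as the 4-channel layer regardless of what the extra channels contain; training then learns how to use the mask and masked-image channels.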
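The batch-shape fix and the single F.interpolate call can be illustrated with a small sketch. It is written under assumptions and is not the function from train_util.py: the normalisation to [-1, 1], the threshold of 128, the convention that mask value 1 marks the region to inpaint, the 8x latent downscale, and the random tensors standing in for VAE outputs are all chosen only for the example. The channel order follows the message: noisy_latents + mask + masked_image_latents.

```python
import numpy as np
import torch
import torch.nn.functional as F


def prepare_mask_and_masked_image(image: np.ndarray, mask: np.ndarray):
    """Sketch only. image: [H, W, C] uint8; mask: [H, W] with values 0..255."""
    # [H, W, C] -> [C, H, W], with no leading batch dim (that was the bug).
    image_t = torch.from_numpy(image.transpose(2, 0, 1).copy()).float() / 127.5 - 1.0
    # [H, W] -> [1, H, W]; assumed convention: 1 = region to inpaint.
    mask_t = torch.from_numpy((mask[None] >= 128).astype(np.float32))
    masked_image = image_t * (1.0 - mask_t)
    return mask_t, masked_image


rng = np.random.default_rng(0)
B, H, W = 2, 64, 64
pairs = [
    prepare_mask_and_masked_image(
        rng.integers(0, 256, size=(H, W, 3), dtype=np.uint8),
        rng.integers(0, 256, size=(H, W), dtype=np.uint8),
    )
    for _ in range(B)
]

# Stacking per-item [1, H, W] and [C, H, W] tensors gives proper 4-D batches.
masks = torch.stack([m for m, _ in pairs])           # [B, 1, H, W]
masked_images = torch.stack([im for _, im in pairs])  # [B, 3, H, W]

# One interpolate call on the whole batch instead of a per-item loop.
latent_masks = F.interpolate(masks, size=(H // 8, W // 8))  # [B, 1, H/8, W/8]

# Random stand-ins for the noisy latents and the VAE-encoded masked image.
noisy_latents = torch.randn(B, 4, H // 8, W // 8)
masked_image_latents = torch.randn(B, 4, H // 8, W // 8)

# 9-channel UNet input: 4 (noisy latents) + 1 (mask) + 4 (masked-image latents).
unet_input = torch.cat([noisy_latents, latent_masks, masked_image_latents], dim=1)
```

With the earlier image[None] and mask[None, None] shapes, the same torch.stack calls would have produced 5-D tensors, which is why the callers previously needed a reshape before interpolating.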