diffusers
Comprehensive type checking for `from_pretrained` kwargs
#10758
Merged

Comprehensive type checking for `from_pretrained` kwargs #10758

guiyrt
guiyrt172 days ago (edited 161 days ago)

What does this PR do?

Changes

  • Moved type checking to just before pipeline instantiation, so all kwargs are checked
  • Full type-check for collections (list, dicts, ...), every element checked
  • More detailed warning for unexpected arguments type (List[ControlNetModel] instead of list)

To-do

  • Where should the new functions is_valid_type and get_detailed_type be placed?
  • According to new functions location, add simple tests for type checking.

These changes are proposed based on testing for #10747.

Example warning when providing controlnetas List[ControlNetUnionModel] for StableDiffusionXLControlNetPipeline, where List[ControlNetModel] is expected:

Expected types for controlnet: (<class 'diffusers.models.controlnets.controlnet.ControlNetModel'>,
typing.List[diffusers.models.controlnets.controlnet.ControlNetModel], 
typing.Tuple[diffusers.models.controlnets.controlnet.ControlNetModel],
<class 'diffusers.models.controlnets.multicontrolnet.MultiControlNetModel'>),
got typing.List[diffusers.models.controlnets.controlnet_union.ControlNetUnionModel].
Code for warning replication
import torch

from diffusers import StableDiffusionXLControlNetPipeline
from diffusers.models import ControlNetUnionModel, AutoencoderKL
from diffusers.utils import load_image


controlnet = ControlNetUnionModel.from_pretrained(
    "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
)

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)

pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=[controlnet, controlnet],
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)

room_seg_img = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/room_seg.png")
pose_img = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png")


pipe.enable_model_cpu_offload()

image = pipe(
    prompt="an astronaut in a space station",
    width=1024,
    height=1024,
    negative_prompt="lowres, low quality, worst quality",
    generator=torch.manual_seed(42),
    guidance_scale=5,
    num_inference_steps=50,
    image=[pose_img, room_seg_img],
).images[0]

image.save("result.jpg")

Before submitting

Who can review?

@hlky

guiyrt More robust from_pretrained init_kwargs type checking
ce75466a
guiyrt Corrected for Python 3.10
b1f26c53
HuggingFaceDocBuilderDev
HuggingFaceDocBuilderDev172 days ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

hlky
hlky approved these changes on 2025-02-10
hlky172 days ago👍 1

Thanks @guiyrt, nice work. Could you take a look through pipeline test output searching for Expected types for (GitHub's built in search works best)

There are some easy cases that we could fix like

Expected types for feature_extractor: (<class 'transformers.models.clip.image_processing_clip.CLIPImageProcessor'>,), got <class 'transformers.models.clip.feature_extraction_clip.CLIPFeatureExtractor'>.

For that we could do global find+replace on feature_extractor: CLIPImageProcessor -> feature_extractor: CLIPFeatureExtractor.

and some that need investigating

Expected types for unet: (<class 'inspect._empty'>,), got <class 'diffusers.models.unets.unet_2d.UNet2DModel'>.

Type correctness is not strictly enforced so some warnings are expected but we should make a best effort to reduce the number of new warnings that we're introducing. If we find a particular component to be a problem we can skip it like scheduler.

hlky
hlky172 days ago

Failing tests appear unrelated, will re-run later.

hlky hlky requested a review from yiyixuxu yiyixuxu 172 days ago
guiyrt Type checks subclasses and fixed type warnings
5ca27aaf
guiyrt
guiyrt170 days ago (edited 170 days ago)

@hlky Findings from looking through the test logs

TL;DR
tokenizer is the one with most warnings, for example, when T5Tokenizer is annotated but T5TokenizerFast is used. Most of the warnings are smaller things and most are corrected/addressed in 5ca27aa. Doing find+replace for Union[BaseTokenizer, FastTokenizer] deals with this problem, but will change many files, is this ok?

1. Using XYZFast tokenizer when only XYZ is annotated (and vice-versa)

We can make a quick search and replace and update all tokenizer annotations to be Union[XYZBase, XYZFast], but as this is a big change, let me know if you agree.

18 occurrences

Expected types for tokenizer: (<class 'transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer'>,), 
got <class 'transformers.models.xlm_roberta.tokenization_xlm_roberta_fast.XLMRobertaTokenizerFast'>.

68 occurrences

Expected types for tokenizer: (<class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>,),
got <class 'transformers.models.t5.tokenization_t5_fast.T5TokenizerFast'>.

9 occurrences

Expected types for tokenizer: (<class 'transformers.models.bert.tokenization_bert.BertTokenizer'>,),
got <class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>.

4 occurences

Expected types for tokenizer: (<class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>,), got <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>.

2. CLIPFeatureExtractor as feature_extractor

This comes from tests that use a hf-internal-testing repo with legacy CLIPFeatureExtractor instead of CLIPImageProcessor. A warning from transformers is also thrown FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.

15 occurrences

Expected types for feature_extractor: (<class 'transformers.models.clip.image_processing_clip.CLIPImageProcessor'>,),
got <class 'transformers.models.clip.feature_extraction_clip.CLIPFeatureExtractor'>.

def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds"

3. PipelineFastTests::test_optional_components

This test purposefully sets requires_safety_checker as [True, True] and safety_checker as an unet and feature_extractor as a function, so the warnings here are expected.

2+2+2 occurrences

Expected types for safety_checker: (<class 'diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker'>,),
got <class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>.
Expected types for requires_safety_checker: (<class 'bool'>,), got typing.List[bool].
Expected types for feature_extractor: (<class 'transformers.models.clip.image_processing_clip.CLIPImageProcessor'>,),
got <class 'function'>.

# Test that partially loading works
sd = StableDiffusionPipeline.from_pretrained(
tmpdirname,
feature_extractor=self.dummy_extractor,
safety_checker=unet,
requires_safety_checker=[True, True],
)

4. Missing type hinting

Added the intended types.

6+6 occurrences from HunyuanDiTPipelines.

Expected types for text_encoder_2: (<class 'inspect._empty'>,),
got <class 'transformers.models.t5.modeling_t5.T5EncoderModel'>.
Expected types for tokenizer_2: (<class 'inspect._empty'>,),
got <class 'transformers.models.t5.tokenization_t5_fast.T5TokenizerFast'>.

text_encoder_2=T5EncoderModel,
tokenizer_2=MT5Tokenizer,

text_encoder_2=T5EncoderModel,
tokenizer_2=MT5Tokenizer,

7+7 occurrences from CustomPipeline tests. Only showed for unet because scheduler is not checked.

Expected types for unet: (<class 'inspect._empty'>,),
got <class 'diffusers.models.unets.unet_2d.UNet2DModel'>.
Expected types for unet: (<class 'inspect._empty'>,),
got <class 'diffusers.models.unets.unet_1d.UNet1DModel'>.

def __init__(self, unet, scheduler):

def __init__(self, unet, scheduler):

def __init__(self, unet, scheduler):

5. CustomPipelineTests

Not sure what to make of this.

2+2+2 occurrences

Expected types for unet: (<class 'diffusers_modules.local.unet.my_unet_model.MyUNetModel'>,),
got <class 'diffusers_modules.local.my_unet_model.MyUNetModel'>.
Expected types for unet: (<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,),
got <class 'diffusers_modules.local.my_unet_model.MyUNetModel'>.
Expected types for unet: (<class 'diffusers_modules.local.unet.my_unet_model.MyUNetModel'>,),
got <class 'diffusers_modules.local.my_unet_model.MyUNetModel'>.

6. Incomplete unet type hints in AnimateDiffVideoToVideoPipelines

Changed to be unet: Union[UNet2DConditionModel, UNetMotionModel], as in AnimateDiffSDXLPipeline.

8 occurrences

Expected types for unet: (<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,),
got <class 'diffusers.models.unets.unet_motion_model.UNetMotionModel'>.

7. Subclasses not checked in is_valid_type

This should be correct, so changed is_valid_type to use isinstance to also allow subsclasses. This way, you can annotate with a parent class.

4 occurrences

Expected types for bert: (<class 'transformers.modeling_utils.PreTrainedModel'>,),
got <class 'transformers.models.clip.modeling_clip.CLIPTextModel'>.

8. Type hinting with AutoTokenizer and AutoModel on SanaPipeliness

The proper base classes are PreTrainedModel and PreTrainedTokenizerBase.

7 occurrences

Expected types for tokenizer: (<class 'transformers.models.auto.tokenization_auto.AutoTokenizer'>,),
got <class 'transformers.models.gemma.tokenization_gemma.GemmaTokenizer'>.

3 occurences

Expected types for tokenizer: (<class 'transformers.models.auto.tokenization_auto.AutoTokenizer'>,),
got <class 'transformers.models.gemma.tokenization_gemma_fast.GemmaTokenizerFast'>.

3 occurrences

Expected types for text_encoder: (<class 'transformers.models.auto.modeling_auto.AutoModel'>,),
got <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'>.

3 occurrences

Expected types for text_encoder: (<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>,),
got <class 'transformers.models.gemma2.modeling_gemma2.Gemma2ForCausalLM'>.

4 occurrences

Expected types for text_encoder: (<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>,),
got <class 'transformers.models.gemma2.modeling_gemma2.Gemma2Model'>.

tokenizer: AutoTokenizer,
text_encoder: AutoModelForCausalLM,

tokenizer: AutoTokenizer,
text_encoder: AutoModelForCausalLM,

9. Interchanged use of CLIPTextModel and CLIPTextModelWithProjection

Just swapped with the intended type.

4 occurrences

Expected types for text_encoder: (<class 'transformers.models.clip.modeling_clip.CLIPTextModelWithProjection'>,),
got <class 'transformers.models.clip.modeling_clip.CLIPTextModel'>.

text_encoder: CLIPTextModelWithProjection,

8 occurrences

Expected types for text_encoder: (<class 'transformers.models.clip.modeling_clip.CLIPTextModel'>,),
got <class 'transformers.models.clip.modeling_clip.CLIPTextModelWithProjection'>.

4 occurrences

Expected types for prior_text_encoder: (<class 'transformers.models.clip.modeling_clip.CLIPTextModel'>,),
got <class 'transformers.models.clip.modeling_clip.CLIPTextModelWithProjection'>.

guiyrt Merge branch 'main' into pipe_from_pretrained_typecheck
70efb66f
guiyrt
guiyrt170 days ago

Found another warning related to custom pipelines, this time on "hf-internal-testing/diffusers-dummy-pipeline". The fix is having the correct type hinting there.

4 occurrences

Expected types for unet: (<class 'inspect._empty'>,),
got <class 'diffusers.models.unets.unet_2d.UNet2DModel'>

def test_run_custom_pipeline(self):
pipeline = DiffusionPipeline.from_pretrained(
"google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
)
pipeline = pipeline.to(torch_device)
images, output_str = pipeline(num_inference_steps=2, output_type="np")
assert images[0].shape == (1, 32, 32, 3)
# compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
assert output_str == "This is a test"

guiyrt
guiyrt170 days ago

I opened PRs on the hf-internal-testing repos with warnings:
Replacing deprecated CLIPFeatureExtractor for CLIPImageProcessor 1
Replacing deprecated CLIPFeatureExtractor for CLIPImageProcessor 2
Added type annotations for pipeline init args

The last one is regarding arguments with no annotations, which might be common on custom pipelines? To keep warnings relevant, it might be a good idea to skip type checking if there is no type annotated for a given argument.

hlky
hlky170 days ago
  1. Maybe we just don't check tokenizer
  2. If it's an internal testing checkpoint the warning is not important
  3. If it's expected it's ok
  4. Thanks
  5. Custom pipeline so not important
  6. Thanks
  7. Annotations of parent class PreTrainedModel are not that useful, changing to correct type be left as a todo
  8. As above
  9. Thanks
guiyrt
guiyrt169 days ago (edited 169 days ago)
  1. Maybe we just don't check tokenizer

tokenizer is now skipped

  1. Annotations of parent class PreTrainedModel are not that useful, changing to correct type be left as a todo

This to-do includes:

  • LDMTextToImagePipeline: Supposedly the types in the docs are LDMBertModel and BertTokenizer, but text_encoder and tokenizer used in the tests are CLIPTextModel and CLIPTokenizer.
  • LuminaText2ImgPipeline and Lumina2Text2ImgPipeline: Docs mention T5, but tests use Gemma.
  • SanaPAGPipelineand SanaPipeline: From what I understood and the tests, these use GemmaTokenizer[Fast] and Gemma2Model for SanaPipeline tests and Gemma2CausalLM for SanaPAGPipeline, so I annotated text_encoder with Gemma2PreTrainedModel type, as it works for both.
guiyrt More type corrections and skip tokenizer type checking
12eb38fa
guiyrt
guiyrt169 days ago

@hlky can we move the functions get_detailed_type and is_valid_type to some utils file? Maybe a new typing_utils.py or something? I don't think they should stay in the middle of DiffusionPipeline::from_pretrained

guiyrt Merge branch 'main' into pipe_from_pretrained_typecheck
5e717536
hlky
hlky169 days ago

LDMTextToImagePipeline I'm not sure, could be that it supports both types or incorrect type hint.

LuminaText2ImgPipeline Yeah Lumina is Gemma, docstring/type hint would have been copied from some other pipeline.

SanaPAGPipeline and SanaPipeline is Gemma, not sure if it should be Gemma2Model or Gemma2ForCausalLM though and I'm assuming it's supposed to be the same for both. Using Gemma2PreTrainedModel should be ok.

@yiyixuxu WDYT about typing_utils.py? Might be some other code that could be moved there, these functions probably won't be used elsewhere though. IMO I don't mind the functions being in from_pretrained, if I'm working on from_pretrained I've got all the context of those functions immediately available.

guiyrt make style && make quality
03a8fcf1
guiyrt
guiyrt169 days ago (edited 169 days ago)

LuminaText2ImgPipeline Yeah Lumina is Gemma, docstring/type hint would have been copied from some other pipeline.

I updated the docs. But if I got it right, Lumina v1 uses GemmaModel and v2 uses Gemma2Model, however the FastTests of both used GemmaForCausalLM. For Lumina v1 we can annotate as GemmaPreTrainedModel, but if we annotated Lumina v2 with Gemma2PreTrainedModel it would produce a warning for the tests. So on top of updating the type annotations, I also updated the tests for Lumina v2 to use Gemma2ForCausalLM, it was easy because there is no comparison with expected output hardcoded. Ran it locally and passed :)

guiyrt Updated docs and types for Lumina pipelines
0afbe6c0
guiyrt
guiyrt169 days ago

Tests failed due to network issues I think. I noticed yesterday very slow download speeds from the hub, anything you are aware?

hlky
hlky169 days ago❤ 1

Just temporary issues, happens sometimes. Thanks for all the iterations on this @guiyrt, should be good to go after @yiyixuxu's comments on whether to add typing_utils.py.

guiyrt Fixed check for empty signature
e367fd32
shethaadit
shethaadit approved these changes on 2025-02-19
guiyrt Merge branch 'main' into pipe_from_pretrained_typecheck
48dbd541
hlky
hlky162 days ago

Gentle ping @yiyixuxu

yiyixuxu
yiyixuxu162 days ago👍 1
guiyrt changed location of helper functions
b17fc6ea
guiyrt Merge branch 'main' into pipe_from_pretrained_typecheck
95560751
guiyrt Merge branch 'main' into pipe_from_pretrained_typecheck
bb8d8628
hlky make style
46d46297
hlky
hlky160 days ago👍 1

Failing tests are unrelated.

Thanks @guiyrt

hlky hlky merged 9c7e2051 into main 160 days ago
guiyrt guiyrt deleted the pipe_from_pretrained_typecheck branch 160 days ago

Login to write a write a comment.

Login via GitHub

Assignees
No one assigned
Labels
Milestone