diffusers
[Quantization] Add quantization support for `bitsandbytes`
#9213
Merged


sayakpaul merged 119 commits into main from quantization-config
sayakpaul
sayakpaul269 days ago (edited 246 days ago) ❤ 8

What does this PR do?

Come back later.

  • Quantization config class (base and bitsandbytes)
  • Quantizer class (base and bitsandbytes)
  • Utilities related to bitsandbytes
  • from_pretrained() at the ModelMixin level and related changes
  • save_pretrained()
  • NF4 tests
  • INT8 (llm.int8()) tests
  • Docs

Notes

  • Even though I alluded to having a separate QuantizationLoaderMixin in #9174, I realized that is not an approach we can take because loading and saving a quantized model is very much baked into the arguments of ModelMixin.save_pretrained() and ModelMixin.from_pretrained(). It is deeply entangled.
  • For the initial quantization support, I think it's okay to not allow passing device_map, because for a pipeline, multiple device_maps can get ugly. This will be dealt with in a follow-up PR by @SunMarc and myself.
  • For the point above, for checkpoints that are found to be sharded (Flux, for example), I have decided to merge them on CPU to simplify the implementation (see the sketch after this list). This will be dealt with in a follow-up PR by @SunMarc.
  • The PR has an extensive testing suite covering training, too. However, I have decided not to add it to our CI yet. We should first let this feature flow into the community and then add the tests to our nightly CI.
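For the sharded-checkpoint note above, the CPU merge is conceptually just the following (a minimal sketch assuming safetensors shards; the actual helper used in this PR is _merge_sharded_checkpoints and its signature may differ):

import safetensors.torch

def merge_sharded_checkpoints_on_cpu(shard_paths, out_path):
    merged = {}
    for shard in shard_paths:
        # Load each shard on CPU and fold it into a single state dict.
        merged.update(safetensors.torch.load_file(shard, device="cpu"))
    safetensors.torch.save_file(merged, out_path)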

No-frills code snippets

Serialization
import torch 
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline
from accelerate.utils import compute_module_sizes

model_id = "black-forest-labs/FLUX.1-dev"

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model_nf4 = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.bfloat16
)
assert model_nf4.dtype == torch.uint8, model_nf4.dtype
print(model_nf4.dtype)
print(model_nf4.config.quantization_config)
print(compute_module_sizes(model_nf4)[""] / 1024 / 1024)

push_id = "sayakpaul/flux.1-dev-nf4-with-bnb-integration"
model_nf4.push_to_hub(push_id)

Serialized checkpoint: https://huggingface.co/sayakpaul/flux.1-dev-nf4-with-bnb-integration.

NF4 checkpoints of Flux transformer and T5: https://huggingface.co/sayakpaul/flux.1-dev-nf4-pkg (has Colab Notebooks, too).

Inference
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
nf4_id = "sayakpaul/flux.1-dev-nf4-with-bnb-integration"
model_nf4 = FluxTransformer2DModel.from_pretrained(nf4_id, torch_dtype=torch.bfloat16)
print(model_nf4.dtype)
print(model_nf4.config.quantization_config)

pipe = FluxPipeline.from_pretrained(model_id, transformer=model_nf4, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A mystic cat with a sign that says hello world!"
image = pipe(prompt, guidance_scale=3.5, num_inference_steps=50, generator=torch.manual_seed(0)).images[0]
image.save("flux-nf4-dev-loaded.png")
sayakpaul quantization config.
e634ff24
sayakpaul sayakpaul added quantization
sayakpaul sayakpaul requested a review from DN6 DN6 269 days ago
sayakpaul sayakpaul requested a review from SunMarc SunMarc 269 days ago
sayakpaul fix-copies
02a6dffd
HuggingFaceDocBuilderDev
HuggingFaceDocBuilderDev269 days ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

sayakpaul Merge branch 'main' into quantization-config
c385a2bb
sayakpaul Merge branch 'main' into quantization-config
0355875d
SunMarc
SunMarc commented on 2024-08-20
SunMarc267 days ago

Thanks for adding this ! I see that you used a lot of things from transformers. Do you think it is possible to import these (or inherit) from transformers ? This would help reduce the maintenance burden. I'm also fine with doing it this way, since there aren't too many follow-up PRs after a quantizer has been added. About the HfQuantizer class, there are a lot of methods that were created to fit the transformers structure. I'm not sure we will need every one of these methods in diffusers. Of course, we can still do a follow-up PR to clean up.

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/base.py
requires_calibration = False
required_packages = None

requires_parameters_quantization = False
SunMarc267 days ago

We can safely remove this var; it's redundant with check_quantized_param, which is set to False by default. I will also remove this from transformers soon.

sayakpaul266 days ago

6e86cc0 should have dealt with this.

sayakpaul
sayakpaul267 days ago (edited 266 days ago) 👍 1

@SunMarc I am guilty as charged but we don't have transformers as a hard dependency for loading models in Diffusers. Pinging @DN6 to seek his opinion.

Update: Chatted with @DN6 as well. We think it's better to redefine inside diffusers without the transformers specific bits which we can clean in this PR.

sayakpaul Merge branch 'main' into quantization-config
e41b4949
sayakpaul Merge branch 'main' into quantization-config
dfb33eb2
sayakpaul Merge branch 'main' into quantization-config
e4926555
sayakpaul fix
6e86cc06
sayakpaul
sayakpaul266 days ago

@SunMarc I think this PR is ready for another review.

sayakpaul modules_to_not_convert
58a3d156
sayakpaul Merge branch 'main' into quantization-config
1d477f9a
SunMarc
SunMarc approved these changes on 2024-08-22
SunMarc265 days ago

Thanks for adding this @sayakpaul !

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/base.py
        kwargs (`dict`, *optional*):
            The keyword arguments that are passed along `_process_model_before_weight_loading`.
        """
        model.is_quantized = True
        model.quantization_method = self.quantization_config.quant_method
SunMarc265 days ago

Just to let you know, these were added here for backward compatibility with accelerate when using bnb : https://github.com/huggingface/accelerate/blob/ad3f574a3b091d2fcb469d48ca5e3a646eea120b/src/accelerate/big_modeling.py#L355

sayakpaul265 days ago

Regardless, I think having these two attributes won't be too bad, no?

SunMarc265 days ago 👍 1

Not at all ! I just wanted to provide some context.

yiyixuxu
yiyixuxu commented on 2024-08-22
yiyixuxu265 days ago (edited 265 days ago)

I don't think it makes sense to have a separate PR just to add a base class, because it's hard to understand what methods are needed - we should only introduce a minimal base class and gradually add functionality as needed

can we have a PR with a minimum example working?

sayakpaul
sayakpaul265 days ago (edited 265 days ago)

Okay, so, do you want me to add everything needed for the bitsandbytes integration in this PR? But do note that this won't be very different from what we have in transformers.

yiyixuxu
yiyixuxu265 days ago (edited 265 days ago)

@sayakpaul
I think so because:

  1. it is better to review that way
  2. we don't need this class in diffusers on its own because it cannot be used yet, no?
bghira
bghira265 days ago

Sometimes we can make a feature branch that a bunch of PRs get merged into before one big honkin' PR is pushed to main at the end, and the pieces are all individually reviewed and can be tested. Is this a viable approach for including quantisation?

sayakpaul Merge branch 'main' into quantization-config
bd7f46d5
sayakpaul
sayakpaul265 days ago ❤ 3

Okay I will update this branch. @yiyixuxu

SunMarc
SunMarc264 days ago (edited 264 days ago)

cc @MekkCyber for visibility

DN6
DN6259 days ago 👍 1

Just a few considerations for the quantization design.

I would say the initial design should start loading/inference at just the model level and then proceed to add functionality (pipeline level loading etc).

The feature needs to perform the following functions

  1. Perform on the fly quantization of large models so that they can be loaded in a low-memory dtype
    1. with from_pretrained
    2. with from_single_file
  2. Dynamically upcast to the appropriate compute dtype when running inference
  3. Save/Load already quantized versions of these large models (FP8, NF4)
  4. Allow loading/inference with LoRAs in these quantized models. (This we have to figure out in more detail)

At the moment, the most common ask seems to be the ability to load models into GPU using the FP8 dtype and run inference in a supported dtype by dynamically upcasting the necessary layers. NF4 is another format that's gaining attention.

So perhaps we should focus on this first. This mostly applies to the DiT models, but large models like CogVideo might also benefit from this approach.

Some example quantized versions of models that have been doing the rounds

To cover these initial cases, we can rely on Quanto (FP8) and BitsandBytes (NF4).

Example API:

from diffusers import FluxPipeline, FluxTransformer2DModel, DiffusersQuantoConfig

# Load model in FP8 with Quanto and perform compute in configured dtype. 

quantization_config = DiffusersQuantoConfig(weights="float8", compute_dtype=torch.bfloat16)

FluxTransformer2DModel.from_pretrained("<either diffusers format or quanto format weights>", quantization_config=quantization_config)

pipe = FluxPipeline.from_pretrained("...", transformer=transformer)

The quantization config should probably take the following arguments

DiffusersQuantoConfig(
	weights_dtype="", # dtype to store weights
	compute_dtype="", # dtype to perform inference
	skip_quantize_modules=["ResBlock"]
)

I think we can initially rely on the dynamic upcasting operations performed by Quanto and BnB under the hood, and then expand on them if needed.

Some other considerations

  1. Since we have transformers models in diffusers that can also benefit from quantized loading, we might want to consider adding a Diffusers prefix to the quantization configs, e.g. DiffusersQuantoConfig, so that when we import quantization configs from transformers there aren't any conflicts.
  2. For saving and loading models we can start with models saved in Quanto/BnB format.
  3. One possible challenge with Pipeline level quantized loading is that we have a mix of transformers/diffusers models. So a single config to quantize/load both types might not be possible.
  4. Single file loading has its own set of issues, such as dealing with checkpoints that have been naively quantized. This applies to some of the Flux single file checkpoints, e.g. safetensors.torch.save_file(model.to(torch.float8_e4m3fn), "model-fp8.safetensors"), and loading full pipeline single file checkpoints. But we can address these later.
sayakpaul
sayakpaul259 days ago (edited 241 days ago)

This PR will be at the model-level itself. And we should not add multiple backends in a single PR. This PR aims to add bitsandbytes. We can do other backends taking this PR as a reference. I would like us to mutually agree on this before I start making progress on this PR.

Concretely, I would like to stick to the outline of the changes laid out in #9174 (along with anything related) for this PR.

The feature needs to perform the following functions

I won't advocate doing all of that in a single PR because it makes things very hard to review. We would rather move faster with something more minimal and confirm its effectiveness.

Allow loading/inference with LoRAs in these quantized models. (This we have to figure out in more detail)

Well, note that if the underlying LoRA wasn't trained with the base quantization precision, it might not perform as expected.

So perhaps we should focus on this first. This mostly applies to the DiT models, but large models like CogVideo might also benefit from this approach.

Please note that bitsandbytes-related quantization mostly applies to nn.Linear, whereas quanto is broader in scope (i.e., quanto can be applied to an nn.Conv2d as well).
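To illustrate the scope difference (a quick sketch, not code from this PR):

import torch
import bitsandbytes as bnb

# bitsandbytes ships drop-in replacements for nn.Linear only, e.g.:
layer = bnb.nn.Linear4bit(4096, 4096, compute_dtype=torch.bfloat16)
# There is no bnb counterpart for nn.Conv2d, which is why a bnb replacement helper
# only swaps out nn.Linear modules; quanto, by contrast, can also quantize
# convolutional layers.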

DN6
DN6259 days ago 👍 1

This PR will be at the model-level itself. And we should not add multiple backends in a single PR. This PR aims to add bitsandbytes. We can do other backends taking this PR as a reference. I would like us to mutually agree on this before I start making progress on this PR.

Sounds good to me.

For this PR lets do

  1. from_pretrained only
  2. bnb quantization.
sayakpaul Merge branch 'main' into quantization-config
d5d7bb69
sayakpaul Merge branch 'main' into quantization-config
44c8a751
sayakpaul add bitsandbytes utilities.
6a0fcdc2
sayakpaul make progress.
e4590fa7
sayakpaul Merge branch 'main' into quantization-config
77a14389
sayakpaul fixes
335ab6bd
sayakpaul quality
d44ef851
sayakpaul up
210fa1e5
sayakpaul sayakpaul marked this pull request as draft 258 days ago
sayakpaul up
f4feee1d
sayakpaul sayakpaul force pushed from b77f70f8 to f4feee1d 258 days ago
sayakpaul Merge branch 'main' into quantization-config
e8c17224
sayakpaul Merge branch 'main' into quantization-config
7f86a71a
sayakpaul minor
ba671b62
sayakpaul up
c1a9f13b
sayakpaul Merge branch 'main' into quantization-config
4489c544
sayakpaul up
f2ca5e26
sayakpaul fix
d6b89542
sayakpaul sayakpaul changed the title [Quantization] Add quantization config base class [Quantization] Add quantization support for `bitsandbytes` 258 days ago
sayakpaul
sayakpaul commented on 2024-08-30
src/diffusers/models/modeling_utils.py
    _supports_gradient_checkpointing = False
    _keys_to_ignore_on_load_unexpected = None
    _no_split_modules = None
    _keep_in_fp32_modules = []
sayakpaul258 days ago

We have to introduce this attribute now that we're seriously entering the diffusion territory.

sayakpaul provide credits where due.
45029e26
chuck-ma
chuck-ma257 days ago (edited 257 days ago)

If I load a LoRA after quantization, it throws errors:

ValueError: .to is not supported for 4-bit or 8-bit bitsandbytes models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct dtype.

pipe.load_lora_weights(
    hf_hub_download(repo_name, ckpt_name), adapter_name="ckpt_name"
)
sayakpaul
sayakpaul257 days ago (edited 257 days ago)

@chuck-ma there's a reason why this PR is still in draft :) We will consider these use cases a bit later in the pipeline.

However, you are welcome to try out basic functionalities like loading and saving without LoRAs.

chuck-ma
chuck-ma257 days ago

@chuck-ma there's a reason why this PR is still in draft :) We will consider these use cases a bit later in the pipeline.

However, you are welcome to try out basic functionalities like loading and saving without LoRAs.

Gotcha, thanks.

sayakpaul make configurations work.
4eb468ad
sayakpaul fixes
939965de
sayakpaul
sayakpaul commented on 2024-08-30
Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
193 "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
194 )
195
196
if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
197
param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
198
):
sayakpaul257 days ago

@SunMarc

Do we have to additionally pop these keys out? Because when a quantized checkpoint is loaded, it complains about

Some weights of the model checkpoint were not used when initializing FluxTransformer2DModel: 
 ['context_embedder.weight.absmax, context_embedder.weight.quant_map, context_embedder.weight.quant_state.bitsandbytes__nf4, norm_out.linear.weight.absmax, norm_out.linear.weight.quant_map, norm_out.linear.weight.quant_state.bitsandbytes__nf4, proj_out.weight.absmax, proj_out.weight.quant_map, proj
...
sayakpaul257 days ago

I know transformers handles this in a sophisticated way in:
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py

I imagine we could also do something similar here:

if param_name not in empty_state_dict:

But I don't have a clear idea of how we could leverage hf_quantizer and filter out the stats-related keys in a clean manner. Guidance would be appreciated.
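Purely illustrative, one way to keep the bnb quant-state entries from showing up as "unexpected keys" (the suffix list below is an assumption on my part, not the PR's final code):

_BNB_QUANT_STATE_SUFFIXES = (
    ".absmax",
    ".quant_map",
    ".quant_state.bitsandbytes__nf4",
    ".quant_state.bitsandbytes__fp4",
    ".SCB",
)

def _filter_bnb_quant_state_keys(unexpected_keys):
    # Quantization statistics are consumed by the quantizer, not by the model's state-dict check.
    return [k for k in unexpected_keys if not k.endswith(_BNB_QUANT_STATE_SUFFIXES)]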

sayakpaul
sayakpaul commented on 2024-08-30
Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/quantization_config.py
        else:
            return None

    def to_dict(self) -> Dict[str, Any]:
sayakpaul257 days ago (edited 257 days ago)

Getting when a quantized checkpoint is loaded:

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'diffusers.quantizers.quantization_config.BitsAndBytesConfig'>.

I can, of course, brute-force delete this in to_dict(), but wanted to know a better way. Also, I see this happening in transformers main, too. Is that expected?

SunMarc254 days ago (edited 254 days ago)

The simplest solution would be to just ignore these kwargs in the init. You need to keep quant_method in to_dict() as this will be useful to recover the quantization scheme used in the serialized model. We need to fix this in transformers too.

sayakpaul254 days ago

@SunMarc trying to think out loud here.

The simplest solution would be to just ignore these kwargs in the init.

Does the following sound good?

  • We add _attributes_to_be_ignored_at_init to the base quantization config class.
  • We override it in the bitsandbytes config class and set it to: ["_load_in_4bit", "_load_in_8bit", "quant_method"].
  • And then during __init__() we check the kwargs for these attributes and drop them if present (see the sketch below).
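A minimal sketch of the idea (attribute and class names are taken from this thread; the final implementation may differ):

class QuantizationConfigMixin:
    _attributes_to_be_ignored_at_init = []

class BitsAndBytesConfig(QuantizationConfigMixin):
    _attributes_to_be_ignored_at_init = ["_load_in_4bit", "_load_in_8bit", "quant_method"]

    def __init__(self, load_in_4bit=False, load_in_8bit=False, **kwargs):
        self.load_in_4bit = load_in_4bit
        self.load_in_8bit = load_in_8bit
        # Drop serialization-only attributes instead of warning about them as unused kwargs.
        for attr in self._attributes_to_be_ignored_at_init:
            kwargs.pop(attr, None)
        if kwargs:
            print(f"Unused kwargs: {list(kwargs)}. These kwargs are not used in {self.__class__}.")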

Does that sound good or did you have a separate approach in mind?

SunMarc254 days ago

Sounds good to me !

sayakpaul254 days ago

abc8607 should have done it.

sayakpaul Merge branch 'main' into quantization-config
85571660
sayakpaul fix
d098d073
sayakpaul update_missing_keys
c4a00749
sayakpaul fix
ee45612c
chuck-ma
chuck-ma257 days ago (edited 257 days ago)

I got this error:

TypeError Traceback (most recent call last)
File ~/autodl-tmp/diffusers/src/diffusers/models/model_loading_utils.py:134, in load_state_dict(checkpoint_file, variant)
133 try:
--> 134 file_extension = os.path.basename(checkpoint_file).split(".")[-1]
135 if file_extension == SAFETENSORS_FILE_EXTENSION:

File ~/miniconda3/lib/python3.10/posixpath.py:142, in basename(p)
141 """Returns the final component of a pathname"""
--> 142 p = os.fspath(p)
143 sep = _get_sep(p)

TypeError: expected str, bytes or os.PathLike object, not NoneType

During handling of the above exception, another exception occurred:

TypeError Traceback (most recent call last)
Cell In[7], line 11
4 model_id = "black-forest-labs/FLUX.1-dev"
6 nf4_config = BitsAndBytesConfig(
7 load_in_4bit=True,
8 bnb_4bit_quant_type="nf4",
9 bnb_4bit_compute_dtype=torch.bfloat16,
10 )
---> 11 model_nf4 = FluxTransformer2DModel.from_pretrained(
12 model_id, subfolder="transformer", quantization_config=nf4_config
13 )
14 print(model_nf4.dtype)
15 print(model_nf4.quantization_config)

File ~/miniconda3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:114, in validate_hf_hub_args.._inner_fn(*args, **kwargs)
111 if check_use_auth_token:
112 kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.name, has_token=has_token, kwargs=kwargs)
--> 114 return fn(*args, **kwargs)

File ~/autodl-tmp/diffusers/src/diffusers/models/modeling_utils.py:817, in ModelMixin.from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
815 if device_map is None and not is_sharded or (hf_quantizer is not None):
816 param_device = "cpu"
--> 817 state_dict = load_state_dict(model_file, variant=variant)
818 model._convert_deprecated_attention_blocks(state_dict)
819 # move the params from meta device to cpu

File ~/autodl-tmp/diffusers/src/diffusers/models/model_loading_utils.py:146, in load_state_dict(checkpoint_file, variant)
144 except Exception as e:
145 try:
--> 146 with open(checkpoint_file) as f:
147 if f.read().startswith("version"):
148 raise OSError(
149 "You seem to have cloned a repository without having git-lfs installed. Please install "
150 "git-lfs and run git lfs install followed by git lfs pull in the folder "
151 "you cloned."
152 )

TypeError: expected str, bytes or os.PathLike object, not NoneType

import torch 
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model_nf4 = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=nf4_config
)
print(model_nf4.dtype)
print(model_nf4.quantization_config)

If I don't use this, everything is just fine:

import torch 
from diffusers import FluxTransformer2DModel


model_id = "black-forest-labs/FLUX.1-dev"


model_nf4 = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", 
)
print(model_nf4.dtype)
sayakpaul fix
b24c0a7a
sayakpaul make it work.
473505ca
sayakpaul fix
c795c82d
sayakpaul Merge branch 'main' into quantization-config
c1d5b966
sayakpaul
sayakpaul256 days ago

@chuck-ma do you wanna give this a try now?

sayakpaul provide credits to transformers.
af7cacaf
lonngxiang
lonngxiang256 days ago (edited 256 days ago)

Run error: ImportError: Using bitsandbytes 4-bit quantization requires the latest version of bitsandbytes: pip install -U bitsandbytes



bitsandbytes 0.43.3

sayakpaul
sayakpaul256 days ago (edited 256 days ago)

@lonngxiang

Maybe there's something wrong with your setup but I am able to load things without any issues:
https://colab.research.google.com/gist/sayakpaul/1bfcf2441f73364bb06c801f58303cd5/scratchpad.ipynb

I see you also confirmed it's working here.

sayakpaul empty commit
80967f5e
sayakpaul sayakpaul requested a review from SunMarc SunMarc 256 days ago
sayakpaul handle to() better.
3bdf25a7
sayakpaul tests
27415cc1
sayakpaul change to bnb from bitsandbytes
51cac09a
chuck-ma
chuck-ma254 days ago 👍 1

@chuck-ma do you wanna give this a try now?

Now it works.

sayakpaul fix tests
15f30326
sayakpaul sayakpaul force pushed from 2e42d257 to 15f30326 254 days ago
sayakpaul
sayakpaul commented on 2024-09-02
src/diffusers/configuration_utils.py
        # 4. Give nice warning if unexpected values have been passed
        if len(config_dict) > 0:
        only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict
sayakpaul254 days ago

Because quantization_config isn't a part of any model's __init__().

yiyixuxu243 days ago

I think it is better to not add it to config_dict if it is not going into __init__, i.e. at line 511

 # remove private attributes
 config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
 # remove quantization_config
 config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}
sayakpaul240 days ago

We cannot remove quantization_config from the config of a model as that would prevent loading of the quantized models via from_pretrained().

quantization_config isn't used for initializing a model, it's used to determine what kind of quantization configuration to inject inside the given model. This is why it's only used in from_pretrained() of ModelMixin.

LMK if you have a better idea to handle it.

yiyixuxu226 days ago 👍 1

we do not remove them from the config, we just don't add them to the config_dict inside this extract_init_dict method: basically, the config_dict in this function goes through these steps:

  1. it is used to create init_dict: the quantisation config will not go there, so it is not affected if we do not add it to config_dict
  2. it is used to throw a warning after we created init_dict; if the quantisation config is not there, we do not need to throw a warning for it
  3. it goes into unused_kwargs - so I think this is the only difference it would make: do we need the quantisation config to be in unused_kwargs returned by extract_init_dict? I think unused_kwargs is only used to send additional warnings for unexpected stuff, but since the quantisation config is expected, and we have already decided not to send a warning here inside extract_init_dict - I think it does not need to go into the unused_kwargs here?
    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
         ...
        config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

        # remove private attributes
        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

+      # remove quantization_config
+      config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}


        ## here we use config_dict to create `init_dict` which will be passed to `__init__` method
        init_dict = {}
        for key in expected_keys:
                 ...
                init_dict[key] = config_dict.pop(key)
-      only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict
-      if len(config_dict) > 0 and not only_quant_config_remaining:
+     if len(config_dict) > 0:
            logger.warning(
                f"The config attributes {config_dict} were passed to {cls.__name__}, "
                "but are not expected and will be ignored. Please verify your "
                f"{cls.config_name} configuration file."
            )
       ....
        # 6. Define unused keyword arguments
        unused_kwargs = {**config_dict, **kwargs}

        return init_dict, unused_kwargs, hidden_config_dict
sayakpaul216 days ago

Makes sense. Resolved in 555a5ae.

sayakpaul
sayakpaul commented on 2024-09-02
src/diffusers/models/model_loading_utils.py
    keep_in_fp32_modules=None,
) -> List[str]:
    device = device or torch.device("cpu")
    device = device or torch.device("cpu") if hf_quantizer is None else device
sayakpaul254 days ago

More on this in the later changes.

sayakpaul
sayakpaul commented on 2024-09-02
Conversation is marked as resolved
Show resolved
src/diffusers/models/model_loading_utils.py
    else:
        param = param.to(dtype)

    if not is_quantized and empty_state_dict[param_name].shape != param.shape:
sayakpaul254 days ago (edited 254 days ago)

Because quantized params can have a flattened shape (typical in bnb). Could make it more robust with hf_quantizer.quant_method == "bitsandbytes" and ...

sayakpaul
sayakpaul commented on 2024-09-02
Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
            subfolder=subfolder or "",
        )
        if hf_quantizer is not None:
            logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
            model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
sayakpaul254 days ago
Suggested change
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
sayakpaul better safeguard.
77c9fdb3
sayakpaul
sayakpaul commented on 2024-09-02
src/diffusers/models/model_loading_utils.py
    else:
        param = param.to(dtype)

    is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
    if not is_quantized and not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
sayakpaul254 days ago

Because bnb quantized params are usually flattened.

sayakpaul
sayakpaul commented on 2024-09-02
src/diffusers/pipelines/pipeline_utils.py
from ..models import AutoencoderKL
from ..models.attention_processor import FusedAttnProcessor2_0
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
from ..quantizers.bitsandbytes.utils import _check_bnb_status
sayakpaul254 days ago 👍 1

The reason I preferred not having _check_bnb_status() inline is that I imagine it'd be used across the library. So, it didn't make sense to bury it inside a single function.
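For reference, the helper is roughly of this shape (a sketch based on how it's used elsewhere in this PR; the exact implementation may differ):

def _check_bnb_status(module) -> tuple:
    # Returns (any bnb, loaded in 4-bit bnb, loaded in 8-bit bnb) flags for a module.
    is_loaded_in_4bit_bnb = getattr(module, "is_loaded_in_4bit", False)
    is_loaded_in_8bit_bnb = getattr(module, "is_loaded_in_8bit", False)
    return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb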

BenjaminBossan
BenjaminBossan commented on 2024-09-02
BenjaminBossan254 days ago

Wow, big PR, great feature being added here.

I haven't done an in-depth review, but took a look at the parts related to PEFT and skimmed the rest.

With such a big change, it might be worth it to control the line coverage of the newly added tests to ensure that the new code is reasonably well covered.

Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
        return model

    @wraps(torch.nn.Module.cuda)
BenjaminBossan254 days ago

These methods exist to align with transformers, right? If so, could be worth adding a comment.

sayakpaul254 days ago

I think these are generic w.r.t. bitsandbytes. So, it's okay without one, IMO. I have credited transformers wherever I copied things over or modified them throughout this PR.

sayakpaul254 days ago

Update: done in 44c4109.

src/diffusers/models/model_loading_utils.py
    keep_in_fp32_modules=None,
) -> List[str]:
    device = device or torch.device("cpu")
    device = device or torch.device("cpu") if hf_quantizer is None else device
BenjaminBossan254 days ago

Not specific to this PR but device = device or torch.device("cpu") is a bit dangerous because theoretically, 0 is a valid device but it would be considered falsy. AFAICT it's not problematic for the existing code, but something to keep in mind.

sayakpaul254 days ago

Indeed.

sayakpaul251 days ago

I have added a comment about it too.

Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
            raise ValueError(
                "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
                " model has already been set to the correct devices and casted to the correct `dtype`."
BenjaminBossan254 days ago

I know it's copied from transformers, but "casted" is incorrect and should be "cast".

Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
        total_numel = []
        is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)

        if is_loaded_in_4bit:
BenjaminBossan254 days ago

I'd always try to move all checks to the start of the function, even if this specific one is extremely unlikely to ever fail.

sayakpaul254 days ago

Done in 27666a8.

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/auto.py
        warning_msg = (
            "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
            " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
        )
BenjaminBossan254 days ago

As the warning message is not extended, you could warn here directly.

sayakpaul254 days ago

Not sure. It's under an if/else. I think it's easier to read in the current way.

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/utils.py
"""
BenjaminBossan254 days ago

Missing copyright notice.

sayakpaul254 days ago

Done 3464d83

Conversation is marked as resolved
Show resolved
src/diffusers/utils/loading_utils.py
    return pil_images


def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
BenjaminBossan254 days ago

Add comment that this is copied from transformers.

sayakpaul254 days ago

Done 3464d83

Conversation is marked as resolved
Show resolved
tests/quantization/bnb/test_4bit.py
        self.assertTrue("The module 'SD3Transformer2DModel' has been loaded in `bitsandbytes` 4bit" in cap_logger.out)


class SlowBnb4BitTests(Base4bitTests):
BenjaminBossan254 days ago

Don't all tests that inherit from Base4bitTests also get the @slow marker? What makes these specific tests "slow"?

sayakpaul254 days ago

Combination of different reasons.

1> Big checkpoints
2> IO overhead (saving and loading)
3> Partial inference

BenjaminBossan254 days ago

What I wanted to get at is that for CI purposes, having these tests be separate makes no difference as the BnB4BitBasicTests tests also have the slow marker, right? So this distinction is more for devs who run the tests manually?

sayakpaul254 days ago

Yeah this is more of a logical separation following the structure of https://github.com/huggingface/transformers/blob/main/tests/quantization/bnb/test_4bit.py.

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/utils.py
                    threshold=quantization_config.llm_int8_threshold,
                )
                has_been_replaced = True
            else:
                if (
                    quantization_config.llm_int8_skip_modules is not None
                    and name in quantization_config.llm_int8_skip_modules
                ):
                    pass
                else:
BenjaminBossan254 days ago

Could this not be turned into a single elif not (quantization_config.llm_int8_skip_modules is not None and name in quantization_config.llm_int8_skip_modules):?

sayakpaul254 days ago

I prefer granular conditions like this. The breakpoints become immediately clearer to me than with combined conditions.

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/utils.py
    return model, has_been_replaced


def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
BenjaminBossan254 days ago

Is this function (and _replace_with_bnb_linear) copied from somewhere? If not, might this warrant its own unit tests?

sayakpaul254 days ago 👍 1

We're testing that in the test suite already. For example, if it was not applied correctly, test_linear_are_4bit() would have failed.
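To make that concrete, the assertion in the test is roughly of this shape (my paraphrase of its intent, not the exact test code):

import torch
import torch.nn as nn
import bitsandbytes as bnb

def assert_linears_are_4bit(model, modules_to_not_convert=()):
    for name, module in model.named_modules():
        if isinstance(module, bnb.nn.Linear4bit):
            # Converted layers hold packed uint8 weights.
            assert module.weight.dtype == torch.uint8
        elif isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            raise AssertionError(f"{name} was not converted by replace_with_bnb_linear")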

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/utils.py
def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
    """
    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
BenjaminBossan254 days ago

or Linear4bit.

sayakpaul254 days ago

Done 3464d83

src/diffusers/quantizers/bitsandbytes/utils.py
318 """
319 Converts a quantized model into its dequantized original version. The newly converted model will have some
320 performance drop compared to the original model before quantization - use it only for specific usecases such as
321
QLoRA adapters merging.
BenjaminBossan254 days ago

Note that PEFT supports merging into bnb weights, so that alone would not require dequantizing the weights entirely.

sayakpaul254 days ago

Noted. I guess not immediately relevant for this PR?

SunMarc252 days ago 👍 1

I think it is still interesting to let users have a way to dequantize their models.

sayakpaul change merging status
ddc9f293
sayakpaul courtesy to transformers.
44c41099
sayakpaul move upper.
27666a8d
sayakpaul better
3464d837
sayakpaul Merge branch 'main' into quantization-config
b106124a
chuck-ma
chuck-ma254 days ago (edited 254 days ago)

Thanks for your efforts. I think it's better if we can load the transformer that has been quantized instead of quantizing the transformer every time we load it. @sayakpaul

sayakpaul
sayakpaul254 days ago

I think it's better if we can load the transformer that has been quantized instead of quantizing the transformer every time we load it. @sayakpaul

Possible now:

from diffusers import FluxTransformer2DModel

model_id = "sayakpaul/flux.1-dev-nf4-pkg"
model_nf4 = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")

Not sure what made you think it's not possible.

sayakpaul Merge branch 'main' into quantization-config
330fa0af
chuck-ma
chuck-ma254 days ago

I think it's better if we can load the transformer that has been quantized instead of quantizing the transformer every time we load it. @sayakpaul

Possible now:

from diffusers import FluxTransformer2DModel

model_id = "sayakpaul/flux.1-dev-nf4-pkg"
model_nf4 = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")

Not sure what made you think it's not possible.

import torch 
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16


pipe.transformer.save_pretrained(flux_transformer_id)

model_nf4 = FluxTransformer2DModel.from_pretrained(
    flux_transformer_id,
    # quantization_config=nf4_config,
    torch_dtype=dtype,
)

I just got an error (no matter whether I use quantization_config):

ValueError: Cannot load <class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> from /root/autodl-tmp/flux_transformer because the following keys are missing:
transformer_blocks.11.ff_context.net.0.proj.weight, transformer_blocks.3.attn.add_k_proj.bias, transformer_blocks.6.ff_context.net.2.weight, single_transformer_blocks.27.attn.to_k.bias, transformer_blocks.15.attn.to_out.0.weight, single_transformer_blocks.35.proj_out.bias, time_text_embed.guidance_embedder.linear_1.bias, transformer_blocks.2.attn.to_add_out.bias, transformer_blocks.3.attn.add_v_proj.weight, transformer_blocks.6.ff.net.2.weight,

sayakpaul
sayakpaul254 days ago

@chuck-ma please try to follow the Colab Notebooks provided in https://hf.co/sayakpaul/flux.1-dev-nf4-pkg. All of them show the correct usage and run without any errors. And when you're facing errors, please try to provide Colab Notebooks so I can verify things. Otherwise, it's hard for me to reproduce errors. Could we do that?

chuck-ma
chuck-ma254 days ago (edited 254 days ago)

I think it's better if we can load the transformer that has been quantized instead of quantizing the transformer every time we load it. @sayakpaul

Possible now:

from diffusers import FluxTransformer2DModel

model_id = "sayakpaul/flux.1-dev-nf4-pkg"
model_nf4 = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")

Not sure what made you think it's not possible.

from diffusers import FluxPipeline
import torch
from huggingface_hub import hf_hub_download

repo_name = "ByteDance/Hyper-SD"
ckpt_16steps_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"


create_fuse_checkp = True

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
)

if create_fuse_checkp:
    model_nf4 = FluxTransformer2DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        quantization_config=nf4_config,
        torch_dtype=dtype,
    )
    pipe.transformer = model_nf4

    pipe.load_lora_weights(
        hf_hub_download(repo_name, ckpt_16steps_name), adapter_name="8_steps_lora"
    )
    pipe.fuse_lora(lora_scale=0.125)
    pipe.transformer.save_pretrained(flux_transformer_id)

If I merge lora and then save the transformer, it will get the error. Otherwise, everything is just fine.

sayakpaul
sayakpaul254 days ago 👍 1

If I merge lora and then save the transformer, it will get the error. Otherwise, everything is just fine.

I told you here that LoRA can be tried out later. So, please be aware of the expectations.

chuck-ma
chuck-ma254 days ago

If I merge lora and then save the transformer, it will get the error. Otherwise, everything is just fine.

I told you here, that LoRA can be tried out later. So, please be aware of the expectations.

OK. Because I saw that your latest code can support merging lora after loading nf4, I wanted to try whether it is possible to save and load after merging lora.
Anyway, nice job.

bghira
bghira254 days ago 👍 1

I think it'll be different if you quantise after the merge.

SunMarc
SunMarc commented on 2024-09-02
Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
    # Taken from `transformers`.
    @wraps(torch.nn.Module.cuda)
    def cuda(self, *args, **kwargs):
        # Checks if the model has been loaded in 8-bit
        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
            raise ValueError(
                "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
                " model has already been set to the correct devices and cast to the correct `dtype`."
            )
        else:
            return super().cuda(*args, **kwargs)

    # Taken from `transformers`.
    @wraps(torch.nn.Module.to)
    def to(self, *args, **kwargs):
        # Checks if the model has been loaded in 8-bit
        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
            raise ValueError(
                "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
                " model has already been set to the correct devices and cast to the correct `dtype`."
            )
        return super().to(*args, **kwargs)
SunMarc254 days ago (edited 254 days ago) 👍 1

We recently merged this PR: huggingface/transformers#33122 to remove the to() and cuda() restriction on 4-bit models.

sayakpaul254 days ago

Resolved in 31725aa

sayakpaul make the unused kwargs warning friendlier.
abc86070
sayakpaul harmonize changes with https://github.com/huggingface/transformers/pu…
31725aa2
sayakpaul style
e5938a63
sayakpaul trainin tests
444588f9
sayakpaul Merge branch 'main' into quantization-config
d3360ce8
SunMarc
SunMarc commented on 2024-09-03
SunMarc253 days ago

Reviewed half of the PR ! I will do the rest soon but since it is mainly config, there shouldn't be any big blockers. Thanks for the PR @sayakpaul ! I left a few comments

Conversation is marked as resolved
Show resolved
src/diffusers/models/model_loading_utils.py
        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
        if (
            not is_quantized
            or (not hf_quantizer.requires_parameters_quantization)
SunMarc253 days ago

Not needed anymore since we removed it

Suggested change
or (not hf_quantizer.requires_parameters_quantization)
Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
        )
        if hf_quantizer is None:
            param_device = "cpu"
        elif is_quant_method_bnb:
            param_device = torch.cuda.current_device()
SunMarc253 days ago

After the loading refactor, we can safely remove this. This is something that should be modified in update_device_map.

sayakpaul252 days ago (edited 252 days ago)

I have made a comment noting this.

Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
        )
        if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
        pipeline_has_bnb_quant = any(_check_bnb_status(module)[0] for _, module in self.components.items())
        if (
            not pipeline_has_bnb_quant
SunMarc253 days ago

Why are we considering bnb quant here? With 4-bit models, we can move them to the GPU and then move them back to the CPU if we want to save GPU memory.

sayakpaul252 days ago

Good point. Let me try to instead condition this on 8bit loading.

Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
        if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
            logger.warning(
                f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
                f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision. In most cases, it is recommended to not change the precision."
SunMarc253 days ago

I don't think there is a case where we'd suggest changing the dtype for a quantized model

Suggested change
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision. In most cases, it is recommended to not change the precision."
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision. "
sayakpaul252 days ago

100 percent.

Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
        if is_loaded_in_8bit_bnb and device is not None:
            logger.warning(
                f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
                f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}. In most cases, it is recommended to not change the device."
SunMarc253 days ago

Same comment

Suggested change
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}. In most cases, it is recommended to not change the device."
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
            if not isinstance(model, torch.nn.Module):
                continue

            # This is because the model would already be placed on a CUDA device.
            if is_loaded_in_8bit_bnb:  # is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb:
                logger.info(
                    f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
                )
                continue
SunMarc253 days ago

we should also skip for 4-bit when .to() is not supported

sayakpaul252 days ago

Not sure if I understand. I require bitsandbytes to be installed with at least 0.43.3:

if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"):

So, I guess 4bit model offloading here should be possible?

SunMarc252 days ago

Oh nice then !

Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
        hook = None
        for model_str in self.model_cpu_offload_seq.split("->"):
            model = all_model_components.pop(model_str, None)
            is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = False, False
            if model is not None and isinstance(model, torch.nn.Module):
                _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(model)

            if not isinstance(model, torch.nn.Module):
                continue

            # This is because the model would already be placed on a CUDA device.
            if is_loaded_in_8bit_bnb:  # is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb:
                logger.info(
                    f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
                )
                continue
SunMarc253 days ago

As long as we can move the model (which is the case for most quantization schemes), we can use cpu_offload_with_hook. So it makes sense to consider bnb separately.

sayakpaul252 days ago 👍 1

I guess this is already addressed? is_loaded_in_8bit_bnb is indeed very specific to bitsandbytes or am I missing something?

SunMarc252 days ago ❤ 1

yes, I just wanted to point that out

sayakpaul Merge branch 'main' into quantization-config
d8b35f46
sayakpaul
sayakpaul253 days ago

but since it is mainly config, there shouldn't be any big blockers.

Did you mean your comments are associated with configs? Understood if that is the case, but this PR attempts to add full-fledged BnB loading support, not just configs.

sayakpaul Merge branch 'main' into quantization-config
859f2d76
sayakpaul
sayakpaul commented on 2024-09-04
Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
987 "Calling `cuda()` is not supported for `8-bit` quantized models. "
988 " Please use the model as it is, since the model has already been set to the correct devices."
989 )
990
elif is_bitsandbytes_version("<", "0.43.2"):
991
raise ValueError(
992
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
993
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
994
)
sayakpaul252 days ago

I have taken the liberty to pin bitsandbytes version to 0.43.3 as we're introducing it for the first time. So, if that is the case, I think we can safely remove these?

if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"):

SunMarc252 days ago

Yes !

sayakpaul
sayakpaul commented on 2024-09-04
Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/utils.py
    )

    if is_8bit:
        is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
sayakpaul252 days ago

I guess we can safely remove these version checks if we were to ship the first iteration to always install the latest bnb?

SunMarc252 days ago

yes !

sayakpaul feedback part i.
3b2d6e13
SunMarc
SunMarc commented on 2024-09-04
SunMarc252 days ago

Thanks for this huge work @sayakpaul ! I'll do one final review after you add the 8-bit tests. This looks very good. I left a few minor comments.

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
55 """
56
57 use_keep_in_fp32_modules = True
58
requires_parameters_quantization = True
SunMarc252 days ago

To remove

Suggested change
requires_parameters_quantization = True
Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
        )

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
SunMarc252 days ago

In validate_environment, we are already asking for accelerate >= 0.26, so we don't need this condition

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
    requires_parameters_quantization = True
    requires_calibration = False

    required_packages = ["bitsandbytes", "accelerate"]
SunMarc252 days ago

We don't really make use of this attribute (same in transformers ...), so we can remove it also

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
            torch_dtype = torch.float16
        return torch_dtype

    # (sayakpaul): I think it could be better to disable custom `device_map`s
    # for the first phase of the integration in the interest of simplicity.
    # Commenting this for discussions on the PR.
    # def update_device_map(self, device_map):
    #     if device_map is None:
    #         device_map = {"": torch.cuda.current_device()}
    #         logger.info(
    #             "The device_map was not initialized. "
    #             "Setting device_map to {'':torch.cuda.current_device()}. "
    #             "If you want to use the model for inference, please set device_map ='auto' "
    #         )
    #     return device_map
SunMarc252 days ago ❤ 1

Agreed ! I will add it back when I finish the loading refactor.

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
    @property
    def is_serializable(self):
        _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3")
SunMarc252 days ago

You can just return True in diffusers and leave a comment

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
        return model


class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
    """
    8-bit quantization from bitsandbytes quantization method:
        before loading: converts transformer layers into Linear8bitLt during loading: load 16bit weight and pass to the
        layer object after: quantizes individual weights in Linear8bitLt into 8bit at fitst .cuda() call
        saving:
            from state dict, as usual; saves weights and 'SCB' component
        loading:
            need to locate SCB component and pass to the Linear8bitLt object
    """

    use_keep_in_fp32_modules = True
    requires_parameters_quantization = True
    requires_calibration = False

    required_packages = ["bitsandbytes", "accelerate"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)
SunMarc252 days ago

same comment as above for this class

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/utils.py
logger = logging.get_logger(__name__)


def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None):
    """
    A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
    `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
    function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the
    class `Int8Params` from `bitsandbytes`.

    Args:
        module (`torch.nn.Module`):
            The module in which the tensor we want to move lives.
        tensor_name (`str`):
            The full name of the parameter/buffer.
        device (`int`, `str` or `torch.device`):
            The device on which to set the tensor.
        value (`torch.Tensor`, *optional*):
            The value of the tensor (useful when going from the meta device to any other device).
        quantized_stats (`dict[str, Any]`, *optional*):
            Dict with items for either 4-bit or 8-bit serialization
    """
    # Recurse if needed
    if "." in tensor_name:
        splits = tensor_name.split(".")
        for split in splits[:-1]:
            new_module = getattr(module, split)
SunMarc252 days ago

This function can be removed entirely. This was the old version of create_quantized_params.

SunMarc
SunMarc252 days ago ❤ 1

Did you mean your comments are associated with configs? Understood if that is the case, but this PR attempts to add full-fledged BnB loading support, not just configs.

I was talking about the rest of the PR, but yeah, there were also the quantizer + tests. I went through the entire PR now.

itsyourlad
itsyourlad251 days ago 👍 1

I know LoRAs aren't a part of this PR, but what's the current situation with loading LoRAs on top of an NF4 model without dequantizing?

It looks like converting to PEFT would work? I assume there is a way to load on top of the quantized model without having to dequantize it, because Forge doesn't seem to be doing this.

I figure this is the best place to leave a comment, but sorry if it's slightly off topic for the PR.

sayakpaul
sayakpaul251 days ago

@itsyourlad no worries. After this PR, the next plan is to add a training script showing how to do LoRAs on NF4 models with PEFT. So, stay tuned.
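To set expectations, something along these lines is what I have in mind (a hedged sketch; the target_modules below are my guess for Flux attention blocks and the final script may well differ):

import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
from peft import LoraConfig

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", quantization_config=nf4_config
)
# Attach trainable LoRA layers on top of the frozen NF4 base weights.
lora_config = LoraConfig(r=16, lora_alpha=16, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
transformer.add_adapter(lora_config)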

sayakpaul sayakpaul requested a review from stevhliu stevhliu 250 days ago
sayakpaul sayakpaul marked this pull request as ready for review 250 days ago
sayakpaul
sayakpaul250 days ago 👍 1

@SunMarc ready for another round of review.

@stevhliu could you help review the docs?

Gothos Add Flux inpainting and Flux Img2Img (#9135)
5799954d
sayakpaul sayakpaul force pushed from 758d552f to 5799954d 250 days ago
sayakpaul Revert "Add Flux inpainting and Flux Img2Img (#9135)"
8e4bd089
sayakpaul tests
835d4add
sayakpaul don
27075fee
sayakpaul Merge branch 'main' into quantization-config
5c00c1c1
sayakpaul sayakpaul requested a review from SunMarc SunMarc 250 days ago
SunMarc
SunMarc approved these changes on 2024-09-06
SunMarc250 days ago

Thanks for iterating @sayakpaul ! LGTM ! It's nice to finally have quantization integrated in diffusers !

docs/source/en/quantization/overview.md
## When to use what?

This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
SunMarc250 days ago ❤ 2

Yes I think it will be nice to also have a table directly in this doc in the future

sayakpaul
sayakpaul250 days ago

@yiyixuxu this is ready for your review.

sayakpaul Merge branch 'main' into quantization-config
5d633a03
stevhliu
stevhliu approved these changes on 2024-09-09
stevhliu247 days ago

Thanks, this looks really good! 🔥

Conversation is marked as resolved
Show resolved
docs/source/en/api/quantization.md
# Quantization

Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [`bitsandbytes`](https://github.com/bitsandbytes-foundation/bitsandbytes).
stevhliu247 days ago
Suggested change
Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [`bitsandbytes`](https://github.com/bitsandbytes-foundation/bitsandbytes).
Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index).
Conversation is marked as resolved
Show resolved
docs/source/en/quantization/bitsandbytes.md
# bitsandbytes

[bitsandbytes](https://github.com/TimDettmers/bitsandbytes) is the easiest option for quantizing a model to 8 and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance. 4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs.
stevhliu247 days ago
Suggested change
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes) is the easiest option for quantizing a model to 8 and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance. 4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs.
[bitsandbytes](https://huggingface.co/docs/bitsandbytes/index) is the easiest option for quantizing a model to 8 and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance.
4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs.
Conversation is marked as resolved
Show resolved
docs/source/en/quantization/bitsandbytes.md
pip install diffusers transformers accelerate bitsandbytes -U
```

Now you can quantize a model by passing a `BitsAndBytesConfig` to [`~ModelMixin.from_pretrained`] method. This works for any model in any modality, as long as it supports loading with Accelerate and contains `torch.nn.Linear` layers.
stevhliu247 days ago
Suggested change
Now you can quantize a model by passing a `BitsAndBytesConfig` to [`~ModelMixin.from_pretrained`] method. This works for any model in any modality, as long as it supports loading with Accelerate and contains `torch.nn.Linear` layers.
Now you can quantize a model by passing a [`BitsAndBytesConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
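For context, a minimal sketch of the pattern described here, using SD3's transformer (the checkpoint id is just an example):

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

quant_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)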
Conversation is marked as resolved
Show resolved
docs/source/en/quantization/bitsandbytes.md
model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype
```

Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization config.json file is pushed first, followed by the quantized model weights.
stevhliu247 days ago
Suggested change
Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization config.json file is pushed first, followed by the quantized model weights.
Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights.
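For instance (the repo id below is a placeholder):

# Pushes the quantization-aware config.json first, then the quantized weights.
model_8bit.push_to_hub("your-username/sd3-transformer-8bit")
# Or serialize everything locally instead.
model_8bit.save_pretrained("sd3-transformer-8bit")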
Conversation is marked as resolved
Show resolved
docs/source/en/quantization/bitsandbytes.md
model_4bit.transformer_blocks.layers[-1].norm2.weight.dtype
```

You can simply call `model.push_to_hub()` after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with `model.save_pretrained()` command.
stevhliu247 days ago
Suggested change
You can simply call `model.push_to_hub()` after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with `model.save_pretrained()` command.
Call [`~ModelMixin.push_to_hub`] after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
Conversation is marked as resolved
Show resolved
docs/source/en/quantization/bitsandbytes.md
</Tip>

You can check your memory footprint with the `get_memory_footprint` method:
stevhliu247 days ago
Suggested change
You can check your memory footprint with the `get_memory_footprint` method:
Check your memory footprint with the `get_memory_footprint` method:
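Something like the following, assuming the `model_4bit` from the docs' earlier snippet:

print(model_4bit.get_memory_footprint())  # approximate size in bytes, including the quantized parameters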
Conversation is marked as resolved
Show resolved
docs/source/en/quantization/bitsandbytes.md
</Tip>

This section explores some of the specific features of 8-bit models, such as utlier thresholds and skipping module conversion.
stevhliu247 days ago
Suggested change
This section explores some of the specific features of 8-bit models, such as utlier thresholds and skipping module conversion.
This section explores some of the specific features of 8-bit models, such as outlier thresholds and skipping module conversion.
Conversation is marked as resolved
Show resolved
docs/source/en/quantization/bitsandbytes.md
### Skip module conversion

For some models, you don't need to quantize every module to 8-bit which can actually cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module that could be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]:
stevhliu247 days ago
Suggested change
For some models, you don't need to quantize every module to 8-bit which can actually cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module that could be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]:
For some models, you don't need to quantize every module to 8-bit which can actually cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module can be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]:
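A minimal sketch of that configuration (any other module name would follow the same pattern as `proj_out`):

from diffusers import BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["proj_out"],  # keep this module un-quantized
)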
Conversation is marked as resolved
Show resolved
docs/source/en/quantization/bitsandbytes.md
### Nested quantization

Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an addition 0.4 bits/parameter.
stevhliu247 days ago
Suggested change
Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an addition 0.4 bits/parameter.
Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter.
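Concretely, nested (double) quantization is enabled with a single flag; a minimal sketch:

import torch
from diffusers import BitsAndBytesConfig

double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,  # second quantization pass over the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,
)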
Conversation is marked as resolved
Show resolved
docs/source/en/quantization/overview.md
<Tip>

Interested in adding a new quantization method to Transformers? Read the [`DiffusersQuantizer`](../conceptual/contribution) guide to learn how!
stevhliu247 days ago

Is this supposed to link to the Transformers (https://huggingface.co/docs/transformers/main/en/quantization/contribute) doc or the Diffusers contribution doc?

sayakpaul247 days ago (edited 247 days ago)

Contribution docs (the transformers contribution docs point to the regular transformers contributions docs, too).

stevhliu246 days ago

Ah ok, would it make more sense to link to the quantization contribution doc in Transformers then?

sayakpaul246 days ago

Ah actually yes! Could you maybe provide a suggestion?

stevhliu246 days ago

Refer to the Contribute new quantization method guide to learn more about adding a new quantization method.

sayakpaul246 days ago

Done in acdeb25.

sayakpaul Apply suggestions from code review
c381fe06
sayakpaul sayakpaul requested a review from yiyixuxu yiyixuxu 247 days ago
sayakpaul Merge branch 'main' into quantization-config
3c92878d
sayakpaul contribution guide.
acdeb254
sayakpaul Merge branch 'main' into quantization-config
aa295b72
sayakpaul Merge branch 'main' into quantization-config
7f7c9cec
sayakpaul Merge branch 'main' into quantization-config
55f96d8e
yiyixuxu
yiyixuxu commented on 2024-09-13
Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
logger.warning("`low_cpu_mem_usage` was None, now set to True since model is quantized.")

# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
yiyixuxu243 days ago

here it says cls._keep_in_fp32_modules is not None
when would it be None? It defaults to an empty list - let's change the default to None?

class ModelMixin(torch.nn.Module):
    _keep_in_fp32_modules = []
    ...
sayakpaul240 days ago

Done.

Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
    raise ValueError(
        f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
        " the logger on the traceback to understand the reason why the quantized model is not serializable."
yiyixuxu243 days ago

but we raise a ValueError here, so they are not going to get a traceback, no?

sayakpaul240 days ago

I think it would still show the warnings on the console before the error is raised, hence the message.

Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

_hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False)
yiyixuxu243 days ago

why is PEFT part of the check?

sayakpaul240 days ago

Not needed indeed. Have removed.

Conversation is marked as resolved
Show resolved
src/diffusers/models/model_loading_utils.py
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
if isinstance(checkpoint_file, dict):
yiyixuxu243 days ago

why are we making this change? When will checkpoint_file be passed as a dict?

sayakpaul240 days ago

We merge the sharded checkpoints (as stated in the PR description and mutually agreed upon internally) when we're doing quantization:

model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)

^ model_file becomes a state dict which is loaded by load_state_dict:

state_dict = load_state_dict(model_file, variant=variant)

and hence this change.
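So the guard in load_state_dict amounts to a passthrough; roughly (a simplified sketch; the real function also handles safetensors files and error formatting):

def load_state_dict(checkpoint_file, variant=None):
    """Reads a checkpoint file, returning properly formatted errors if they arise."""
    # If a merged state dict was already produced (e.g. by _merge_sharded_checkpoints),
    # there is nothing left to read from disk.
    if isinstance(checkpoint_file, dict):
        return checkpoint_file
    ...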

yiyixuxu226 days ago๐Ÿ‘ 1

ok then! Can you add a note? Because I think we will want to refactor out the _merge_sharded_checkpoints method later, so we remember to check if we should remove this change too.

sayakpaul216 days ago
Conversation is marked as resolved
Show resolved
src/diffusers/models/model_loading_utils.py
    keep_in_fp32_modules=None,
) -> List[str]:
    device = device or torch.device("cpu")
    device = device or torch.device("cpu") if hf_quantizer is None else device
yiyixuxu243 days ago
Suggested change
device = device or torch.device("cpu") if hf_quantizer is None else device
if hf_quantizer is None:
device = device or torch.device("cpu")
Conversation is marked as resolved
Show resolved
src/diffusers/models/model_loading_utils.py
if empty_state_dict[param_name].shape != param.shape:
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
yiyixuxu243 days ago

we don't yet support param.dtype == torch.float8_e4m3fn, no?
let's not add this for now then

sayakpaul240 days ago

Should we throw an error then?

yiyixuxu226 days ago

I think just removing the related code is fine for now

sayakpaul216 days ago
Conversation is marked as resolved
Show resolved
src/diffusers/models/model_loading_utils.py
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
yiyixuxu243 days ago๐Ÿ‘ 1

dtype is not going to be None here because you did dtype = dtype or torch.float32 earlier

Conversation is marked as resolved
Show resolved
src/diffusers/models/model_loading_utils.py
if param_name not in empty_state_dict:
    unexpected_keys.append(param_name)
    continue

if empty_state_dict[param_name].shape != param.shape:
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
yiyixuxu243 days ago

so for example:

  1. dtype = torch.float16
  2. is_quantized = False
  3. the module is one of the modules that we included in keep_in_fp32_modules

inside this function, as currently written, we would first convert it to torch.float32, then later this line will run and it would be converted back to float16 again because dtype here is still torch.float16 - I don't think that is the expected behavior

            if accepts_dtype:
                set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
            else:
                set_module_tensor_to_device(model, param_name, device, value=param)
sayakpaul240 days ago (edited 240 days ago)

Good point!

Added:

+                dtype = torch.float32
                param = param.to(dtype)

Also added tests (test_keep_modules_in_fp32) to ensure effectiveness.

src/diffusers/pipelines/pipeline_utils.py
)
if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
pipeline_has_8bit_bnb_quant = any(_check_bnb_status(module)[-1] for _, module in self.components.items())
if (
    not pipeline_has_8bit_bnb_quant
yiyixuxu242 days ago

can you explain why we are adding this not pipeline_has_8bit_bnb_quant check here?

src/diffusers/pipelines/pipeline_utils.py
and str(device) in ["cpu"]
and not silence_dtype_warnings
and not is_offloaded
and not is_loaded_in_4bit_bnb
yiyixuxu242 days ago

why do we add this check here?
if the model is in 4-bit, the dtype should not be torch.float16 to begin with, no? Even if the dtype shows up as torch.float16 somehow, I think the warning still holds, i.e. if the weights are in float16, even though we can move them to cpu, we should not

sayakpaul240 days ago

bnb (and many others like torchao) only applies to torch.nn.Linear layers.

module.dtype can be misleading because we only check the first parameter here:

def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:

Now, for a model (like Cog) where the first layer has a Conv (patch embedding layer), bnb won't apply to it and that layer's dtype would be torch.float16, for example, so model.dtype would return torch.float16.

if the weights are in float16, even though we can move it to cpu, we should not

So, in the above scenario, not all the weights are in float16, only a tiny fraction is.

This is why I added this check.

But I think your concern is also valid. So, LMK, if, for this PR,

  • Just remove not is_loaded_in_4bit_bnb and add a note about dtype
  • Modify the logic of how we determine dtype.

I think option 1 should be okay.
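To illustrate the first-parameter point with a toy, hypothetical model (not the actual Cog architecture):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, 8, 3)  # bnb leaves conv layers alone
        self.proj = nn.Linear(8, 8)            # this is what bnb would replace with a quantized Linear

model = ToyModel().half()
# A dtype property based on the first parameter reports fp16,
# even if every Linear layer were swapped for an int8/4-bit bnb module.
print(next(model.parameters()).dtype)  # torch.float16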

yiyixuxu226 days ago๐Ÿ‘ 1

ok, I think the scenario you described where the dtype is float16 but it contains int8 - it does not matter here and we should still send this warning regardless

however, I'm more concerned about the opposite scenario where the model contains both float16 and 4-bit/8-bit and dtype shows up as 4-bit/8-bit; in that case we should still send a warning when the user tries to move it to cpu, but we won't do that here based on the current implementation

the solution should be to update our get_parameter_dtype to return the floating point dtype if it is present, so I think option 2 here?

sayakpaul216 days ago

I decided to not touch dtype() for now. Instead I performed changes at the pipeline level. Reference commits:

Does this work for you?

Conversation is marked as resolved
Show resolved
src/diffusers/models/model_loading_utils.py
else:
    param = param.to(dtype)

is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
yiyixuxu241 days ago๐Ÿ‘ 1

this can go outside of the loop

Conversation is marked as resolved
Show resolved
src/diffusers/models/model_loading_utils.py
param = param.to(dtype)

is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
if not is_quantized and not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
yiyixuxu241 days ago๐Ÿ‘ 1
Suggested change
if not is_quantized and not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:

Based on the comments, we only want to skip the quantized param for now, so keep it that way.

Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
        config["quantization_config"], quantization_config
    )
else:
    if "quantization_config" not in config:
        config["quantization_config"] = quantization_config
yiyixuxu241 days ago๐Ÿ‘ 1
Suggested change
if "quantization_config" not in config:
config["quantization_config"] = quantization_config
config["quantization_config"] = quantization_config

I think quantization_config should not be ignored if "quantization_config" in config and config["quantization_config"] is None, no?

Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
    **kwargs,
)
# no in-place modification of the original config.
config = copy.deepcopy(config)
yiyixuxu241 days ago

do we need this? I don't think it is harmful, just wondering why?
we do not allow passing a config dict to from_pretrained(), so we don't have to worry about objects outside this function, and the variable is immediately reassigned and the original object is not used anymore

sayakpaul240 days ago

Doing this because we modify the config with the quantization config, and in case that fails, we should be able to retrieve the original model config from before the failure happened. This is my reasoning. WDYT?

yiyixuxu226 days ago๐Ÿ‘ 1

I don't think it is needed but I don't mind it here!

Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
    model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
)

# We store the original dtype for quantized models as we cannot easily retrieve it
# once the weights have been quantized
# Note that once you have loaded a quantized model, you can't change its dtype so this will
# remain a single source of truth
config["_pre_quantization_dtype"] = torch_dtype
yiyixuxu241 days ago
Suggested change
# We store the original dtype for quantized models as we cannot easily retrieve it
# once the weights have been quantized
# Note that once you have loaded a quantized model, you can't change its dtype so this will
# remain a single source of truth
config["_pre_quantization_dtype"] = torch_dtype

we register it to the config later, so it is duplicated I think

sayakpaul240 days ago

We did, but keeping it here in case config["_pre_quantization_dtype"] is needed by the intermediate quantization code before we hit register_to_config.

yiyixuxu226 days ago

where is config["_pre_quantization_dtype"] used before it is registered?

sayakpaul216 days ago
Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
quantization_config = kwargs.pop("quantization_config", None)
yiyixuxu241 days ago

we don't support low_cpu_mem_usage=False for quantization, no? Should we throw a warning here?

sayakpaul240 days ago

This is what we're doing currently:

if low_cpu_mem_usage is None:

Maybe we just update this check to:
if low_cpu_mem_usage is None or not low_cpu_mem_usage:

WDYT?

yiyixuxu226 days ago

I don't think you should change low_cpu_mem_usage to True when it is explicitly set to be False - better to throw an error in that case

sayakpaul216 days ago
src/diffusers/models/modeling_utils.py
848960 f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
849961 )
850 elif torch_dtype is not None:
962
elif torch_dtype is not None and hf_quantizer is None:
yiyixuxu241 days ago๐Ÿ‘ 1

ohhh I think we cannot do model.to(torch_dtype) here now that we are supporting _keep_in_fp32_modules - it will just convert all layers to torch_dtype again

I don't think keep_in_fp32_modules is supported yet in the code path when low_cpu_mem_usage=False - so let's maybe add an error message/warning for that too

sayakpaul240 days ago

Done.

if (low_cpu_mem_usage is None or not low_cpu_mem_usage) and cls._keep_in_fp32_modules is not None:
            low_cpu_mem_usage = True
            logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
yiyixuxu226 days ago

even when low_cpu_mem_usage is True, we cannot do model = model.to(torch_dtype) here anymore; I think we just have to make sure the dtype conversion is handled properly (with keep_in_fp32_modules) in each code path under if low_cpu_mem_usage

a dummy example where _keep_in_fp32_modules is ignored:

import torch
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin

class DummyModel(ModelMixin, ConfigMixin):
    _keep_in_fp32_modules = ["layer2"]
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 20)
        self.layer2 = torch.nn.Linear(20, 30)
        self.layer3 = torch.nn.Linear(30, 40)
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

# Create an instance of the model
model = DummyModel()
model.save_pretrained("dummy_model")
model = DummyModel.from_pretrained("dummy_model", torch_dtype=torch.float16)
sayakpaul216 days ago

I think I have addressed it already. But LMK if you think otherwise.

yiyixuxu212 days ago๐Ÿ‘ 1

so this example I provided here https://github.com/huggingface/diffusers/pull/9213/files#r1782037619

import torch
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin

class DummyModel(ModelMixin, ConfigMixin):
    _keep_in_fp32_modules = ["layer2"]
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 20)
        self.layer2 = torch.nn.Linear(20, 30)
        self.layer3 = torch.nn.Linear(30, 40)
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

# Create an instance of the model
model = DummyModel()
model.save_pretrained("dummy_model")
model = DummyModel.from_pretrained("dummy_model", torch_dtype=torch.float16)
print(model.layer2.weight.dtype)

it will print out torch.float16, even though we have _keep_in_fp32_modules = ["layer2"] so the layer2 should be kept in float32, no?

sayakpaul212 days ago

Yes, you're right. 81bb48a works for you?

Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
for module in modules:
    is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
    _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
    precision = None
    precision = "4bit" if is_loaded_in_4bit_bnb else "8bit"
yiyixuxu240 days ago
Suggested change
precision = None
precision = "4bit" if is_loaded_in_4bit_bnb else "8bit"

precision here will be 8bit for models that are neither 8-bit nor 4-bit; it is wrong but not a huge deal because this variable is only used to print a warning when it is either 4-bit or 8-bit - better to just put it inline inside the warning.

Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
    logger.warning(
        f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
        f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision."
yiyixuxu240 days ago
Suggested change
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision."
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
if is_loaded_in_8bit_bnb and device is not None:
    logger.warning(
        f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
        f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
yiyixuxu240 days ago
Suggested change
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
hook = None
for model_str in self.model_cpu_offload_seq.split("->"):
    model = all_model_components.pop(model_str, None)

    is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = False, False
    if model is not None and isinstance(model, torch.nn.Module):
        _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(model)
yiyixuxu240 days ago

move this check after the if not isinstance(model, torch.nn.Module) check so that we don't need to run this check if it is not an nn.Module

Suggested change
is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = False, False
if model is not None and isinstance(model, torch.nn.Module):
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(model)
Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
continue

# This is because the model would already be placed on a CUDA device.
if is_loaded_in_8bit_bnb:
yiyixuxu240 days ago
Suggested change
if is_loaded_in_8bit_bnb:
_,_ , is_loaded_in_8bit_bnb = _check_bnb_status(model)
if is_loaded_in_8bit_bnb:
Conversation is marked as resolved
Show resolved
src/diffusers/configuration_utils.py
value = value.as_posix()
return value

# IFWatermarker, for example, doesn't have a `config`.
if hasattr(self, "config") and "quantization_config" in self.config:
    config_dict["quantization_config"] = (
        self.config.quantization_config.to_dict()
        if not isinstance(self.config.quantization_config, dict)
        else self.config.quantization_config
    )
yiyixuxu240 days ago
Suggested change
# IFWatermarker, for example, doesn't have a `config`.
if hasattr(self, "config") and "quantization_config" in self.config:
config_dict["quantization_config"] = (
self.config.quantization_config.to_dict()
if not isinstance(self.config.quantization_config, dict)
else self.config.quantization_config
)
if "quantization_config" in config_dict:
config_dict["quantization_config"] = (
config_dict.quantization_config.to_dict()
if not isinstance(config_dict.quantization_config, dict)
else config_dict.quantization_config
)
Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
_supports_gradient_checkpointing = False
_keys_to_ignore_on_load_unexpected = None
_no_split_modules = None
_keep_in_fp32_modules = []
yiyixuxu240 days ago
Suggested change
_keep_in_fp32_modules = []
_keep_in_fp32_modules = None
src/diffusers/quantizers/quantization_config.py
    Additional parameters from which to initialize the configuration object.
"""

_exclude_attributes_at_init = ["_load_in_4bit", "_load_in_8bit", "quant_method"]
yiyixuxu240 days ago

what is this?

sayakpaul240 days ago

Cc'ing @SunMarc.

We don't need these attributes when initializing a BnB quantization configuration class, but we need them for subsequent operations.

Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
def validate_environment(self, *args, **kwargs):
    if not torch.cuda.is_available():
        raise RuntimeError("No GPU found. A GPU is needed for quantization.")
    if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"):
yiyixuxu240 days ago๐Ÿ‘ 1
Suggested change
if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"):
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
raise ImportError(
    "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
)
if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"):
yiyixuxu240 days ago
Suggested change
if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"):
if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"):
Conversation is marked as resolved
Show resolved
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization")
return CustomDtype.INT4
else:
    raise ValueError(
yiyixuxu240 days ago

what is this? The error message does not make sense here

sayakpaul240 days ago

Modified.

src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
for k, v in state_dict.items():
    # `startswith` to counter for edge cases where `param_name`
    # substring can be present in multiple places in the `state_dict`
    if param_name + "." in k and k.startswith(param_name):
yiyixuxu240 days ago

k.split('.')[0] == param_name ?

sayakpaul240 days ago

Do you mean if param_name + "." in k and k.split('.')[0] == param_name:?

yiyixuxu226 days ago

I think if param_name + "." in k and k.startswith(param_name) is the same as k.split('.')[0] == param_name
because if k.split('.')[0] == param_name is True -> param_name + "." in k is also True, is that not the case?

sayakpaul216 days ago

I changed the code with your suggestion and the assertions failed. I didn't dig deeper and I think it's okay to keep it as is because it's mostly a nit, really.
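For what it's worth, the two checks are not equivalent for dotted parameter names; a quick illustration (the state-dict key below is a hypothetical bnb-style entry, used only for the comparison):

param_name = "time_text_embed.timestep_embedder.linear_1.weight"
k = param_name + ".quant_state.bitsandbytes__nf4"  # hypothetical serialized bnb key

# the check in the PR matches the nested key
print(param_name + "." in k and k.startswith(param_name))  # True
# the proposed check only compares the first dotted component, so it never matches here
print(k.split(".")[0] == param_name)                       # False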

src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
# Unlike `transformers`, we don't know if we should always keep certain modules in FP32
# in case of diffusion transformer models. For language models and others alike, `lm_head`
# and tied modules are usually kept in FP32.
self.modules_to_not_convert = list(filter(None.__ne__, self.modules_to_not_convert))
yiyixuxu240 days ago

can you provide examples of when this list would contain None?

sayakpaul240 days ago

It is configured via llm_int8_skip_modules within the BitsAndBytesConfig object. It defaults to None in our case because, unlike language models, we don't know if a default is required.

sayakpaul changes
b28cc651
sayakpaul Merge branch 'main' into quantization-config
8328e863
sayakpaul empty
97589423
sayakpaul fix tests
b1a98787
sayakpaul
sayakpaul239 days ago

@yiyixuxu thanks for your reviews. I think they were very nice and helpful. I have gone ahead and re-run the tests on audace and everything is green.

I have addressed your comments and made changes. PTAL.

sayakpaul harmonize with https://github.com/huggingface/transformers/pull/33546.
971305b7
sayakpaul numpy_cosine_distance
f41adf1f
sayakpaul sayakpaul requested a review from yiyixuxu yiyixuxu 238 days ago
sayakpaul Merge branch 'main' into quantization-config
0bcb88b1
sayakpaul Merge branch 'main' into quantization-config
55b3696d
chuck-ma
chuck-ma233 days ago

Hi, looks like everything is great. Don't know why approving review is still processing.

sayakpaul Merge branch 'main' into quantization-config
4cb3a6d7
sayakpaul Merge branch 'main' into quantization-config
8a03eaee
sayakpaul Merge branch 'main' into quantization-config
53f0a920
sayakpaul Merge branch 'main' into quantization-config
6aab47c0
sayakpaul resolved conflicts,
9b9a6107
yiyixuxu
yiyixuxu commented on 2024-09-30
Conversation is marked as resolved
Show resolved
src/diffusers/configuration_utils.py
value = value.as_posix()
return value

# IFWatermarker, for example, doesn't have a `config`.
yiyixuxu226 days ago
Suggested change
# IFWatermarker, for example, doesn't have a `config`.
sayakpaul216 days ago
src/diffusers/models/modeling_utils.py
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

hf_quantizer = getattr(self, "hf_quantizer", None)
quantization_serializable = (
    hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable
)

if hf_quantizer is not None and not quantization_serializable:
    raise ValueError(
        f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
        " the logger on the traceback to understand the reason why the quantized model is not serializable."
    )
yiyixuxu226 days ago
Suggested change
hf_quantizer = getattr(self, "hf_quantizer", None)
quantization_serializable = (
hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable
)
if hf_quantizer is not None and not quantization_serializable:
raise ValueError(
f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
" the logger on the traceback to understand the reason why the quantized model is not serializable."
)
if hf_quantizer is not None:
    quantization_serializable = isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable
if not quantization_serializable:
raise ValueError(
f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
" the logger on the traceback to understand the reason why the quantized model is not serializable."
)
sayakpaul216 days ago (edited 216 days ago)

Not sure if we can remove hf_quantizer = getattr(self, "hf_quantizer", None).

And we need the hf_quantizer is not None check because we access its properties in the error thrown just two lines below.

Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
# The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")

if (low_cpu_mem_usage is None or not low_cpu_mem_usage) and cls._keep_in_fp32_modules is not None:
yiyixuxu226 days ago
Suggested change
if (low_cpu_mem_usage is None or not low_cpu_mem_usage) and cls._keep_in_fp32_modules is not None:
if low_cpu_mem_usage is None and cls._keep_in_fp32_modules is not None:
yiyixuxu226 days ago
  1. not low_cpu_mem_usage will include low_cpu_mem_usage is None so we don't need both (because both not False and not None evaluate as True)

  2. but I think here it is better to handle False and None differently (same comment as here https://github.com/huggingface/diffusers/pull/9213/files#r1781706229) - I think if the user explicitly sets low_cpu_mem_usage to False and it is not compatible with the model, we should throw an error

yiyixuxu226 days ago

also we should do this after we figure out use_keep_in_fp32_modules later in the code, because cls._keep_in_fp32_modules may not come into effect at all

use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
        )
sayakpaul216 days ago
Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
if hf_quantizer is not None:
    # We need to register the _pre_quantization_dtype separately for bookkeeping purposes.
    # directly assigning `config["_pre_quantization_dtype"]` won't reflect `_pre_quantization_dtype`
    # in `model.config`. We also make sure to purge `_pre_quantization_dtype` when we serialize
    # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable.
yiyixuxu226 days ago
Suggested change
# directly assigning `config["_pre_quantization_dtype"]` won't reflect `_pre_quantization_dtype`
# in `model.config`. We also make sure to purge `_pre_quantization_dtype` when we serialize
# the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable.

I found the notes confusing here; adding a note about it not being JSON serializable inside to_json_string is enough

sayakpaul216 days ago

870d74f done. Let's follow this mutually then. I have had a couple of instances where my comments on "adding clearer comments" were simply ignored.

Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
pipeline_is_sequentially_offloaded = any(
    module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
# pipeline_has_8bit_bnb_quant = any(_check_bnb_status(module)[-1] for _, module in self.components.items())
# not pipeline_has_8bit_bnb_quant
yiyixuxu226 days ago
Suggested change
# pipeline_has_8bit_bnb_quant = any(_check_bnb_status(module)[-1] for _, module in self.components.items())
# not pipeline_has_8bit_bnb_quant
sayakpaul216 days ago
Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
# This is because the model would already be placed on a CUDA device.
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
if is_loaded_in_8bit_bnb:
yiyixuxu226 days ago

why do we sometimes check the transformers version for 4-bit device placement (like inside to() earlier) and sometimes we do not?

sayakpaul216 days ago

Only 4-bit models are supported for device placement, and it was introduced in bitsandbytes recently (and the changes were propagated to transformers recently as well). Hence the version guard.

bghira
bghira222 days ago๐Ÿ‘ 2

I've added this into SimpleTuner and it's a bit funky, but it works for training as well with few modifications other than loading the base model and the dtype-casting change.

bghira
bghira221 days ago๐Ÿ‘ 1

for testing elsewhere, nf4-trained LyCORIS: https://huggingface.co/RareConcepts/FluxDev-LoKr-beavisandbutthead-nf4

note that the inference speed with NF4 is noticeably slower once an adapter is thrown on top

sayakpaul Merge branch 'main' into quantization-config
510d57a4
sayakpaul config_dict modification.
555a5ae8
sayakpaul remove if config comment.
da103650
sayakpaul note for load_state_dict changes.
71316a66
sayakpaul float8 check.
12f5c593
sayakpaul quantizer.
5e722cdd
sayakpaul raise an error for non-True low_cpu_mem_usage values when using quant.
c78dd0cc
sayakpaul low_cpu_mem_usage shenanigans when using fp32 modules.
af3ecea8
sayakpaul don't re-assign _pre_quantization_type.
a473d28d
sayakpaul make comments clear.
870d74f1
sayakpaul remove comments.
3e6cfeb5
sayakpaul handle mixed types better when moving to cpu.
673993ce
sayakpaul add tests to check if we're throwing warning rightly.
0d5f2f7c
sayakpaul better check.
3cb20fe4
sayakpaul fix 8bit test_quality.
10940a94
sayakpaul
sayakpaul216 days ago (edited 216 days ago)โค 1

@yiyixuxu ready for another review. Have run the tests, too and they pass.

sayakpaul Merge branch 'main' into quantization-config
c0a88aee
sayakpaul Merge branch 'main' into quantization-config
dcc5bc5e
sayakpaul Merge branch 'main' into quantization-config
5e0b4eb1
sayakpaul Merge branch 'main' into quantization-config
569dd960
yiyixuxu
yiyixuxu commented on 2024-10-15
src/diffusers/models/model_loading_utils.py
    )
    and dtype == torch.float16
):
    dtype = torch.float32
yiyixuxu212 days ago๐Ÿ‘ 1

ohh we should not change dtype here: e.g. if it is float16 and we change it here because we hit a parameter that we need to upcast, then it would change the dtype for all the remaining parameters in the state dict too; it should remain float16 when it goes to the next loop iteration

we need to just pass float32 to set_module_tensor_to_device (if it accepts dtype) without changing this variable

sayakpaul212 days ago

Does ff8ddef work for you?

yiyixuxu
yiyixuxu commented on 2024-10-15
Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
    low_cpu_mem_usage = True
    logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
elif not low_cpu_mem_usage:
    raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
yiyixuxu212 days ago

we already throw an error for low_cpu_mem_usage=False + quantization, this code path is for use_keep_in_fp32_modules, no?

sayakpaul212 days ago

Correct. Should be resolved in de6394a.

sayakpaul Merge branch 'main' into quantization-config
8bdc8465
sayakpaul handle dtype more robustly.
ff8ddef9
sayakpaul better message when keep_in_fp32_modules.
de6394af
sayakpaul handle dtype casting.
81bb48af
sayakpaul sayakpaul requested a review from yiyixuxu yiyixuxu 212 days ago
sayakpaul Merge branch 'main' into quantization-config
c5e62aef
sayakpaul
sayakpaul212 days ago

Very insightful comments, @yiyixuxu! I think I have resolved them all. LMK.

sayakpaul Merge branch 'main' into quantization-config
d023b402
DN6
DN6 commented on 2024-10-16
src/diffusers/quantizers/auto.py
}


class DiffusersAutoQuantizationConfig:
DN6210 days ago (edited 210 days ago)

I see this is similar to transformers, but I think the DiffusersAutoQuantConfig class is probably not needed.

This is just a simple mapping to a specific quantization config object. The from_pretrained method in the AutoQuantizer is just wrapping the AutoConfig from_pretrained.

I think we can just move these methods/logic directly into the AutoQuantizer.

sayakpaul210 days ago๐Ÿ‘ 1

If this is not a must-have, could do this in a follow-up PR.

sayakpaul Merge branch 'main' into quantization-config
a3d26552
yiyixuxu
yiyixuxu commented on 2024-10-18
Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
    module.to(device=device)
elif not is_loaded_in_8bit_bnb:
yiyixuxu209 days ago

If the model is 4-bit with a device specified and transformers version <= 4.44.0, it will fall through to this elif and attempt the device placement - is this intended?

yiyixuxu209 days ago

also, if the model is 4-bit and a dtype is passed, it will also come to this elif for the dtype conversion - is this expected? (This is confusing because we warned earlier about not supporting dtype conversion for 4-bit)

sayakpaul208 days ago

You are right. 0ae70fe should resolve this confusion.

Conversation is marked as resolved
Show resolved
src/diffusers/pipelines/pipeline_utils.py
    and not is_offloaded
):
    logger.warning(
        "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
        " is not recommended to move them to `cpu` as running them will fail. Please make"
        " sure to use an accelerator to run the pipeline in inference, due to the lack of"
        " support for`float16` operations on this device in PyTorch. Please, remove the"
        "Pipelines loaded with `dtype=torch.float16` and containing modules that have int weights"
yiyixuxu209 days ago
  1. in the if statement it is module.dtype == torch.float16 or module_has_int_weights but here it says "and" - which one is it?
  2. can you explain how this check (to see if it contains int8 and uint8) is different from the previous checks is_loaded_in_8bit_bnb etc?
sayakpaul208 days ago
  1. Should be or. Fixed it in ecdf1d0.
  2. is_loaded_in_8bit_bnb simply checks for the quantization_method in the quantization config. module_has_int_weights checks the dtype of the module params.
yiyixuxu208 days ago

is_loaded_in_8bit_bnb simply checks for the quantization_method in the quantization config. module_has_int_weights checks the dtype of the module params.

I can tell that from the code - I guess I'm looking to understand a scenario where is_loaded_in_8bit_bnb and is_loaded_in_4bit_bnb are both False but the weights have these dtypes

sayakpaul208 days ago

Oh okay. module.dtype == torch.float16 can still be True while the module has these types. When?

When the first layer of the concerned model is a convolution layer: since that is not covered by bnb, its dtype won't be affected. So, our dtype() property will return torch.float16, for example.

I think I explained it before already but LMK if it's still not clear.

yiyixuxu208 days ago (edited 208 days ago)

so this warning is meant for any module that contains float16, right? that will include cases:

  1. when all the modules are in float16
  2. the scenarios you described (when it says torch.dtype = float16 but it also contains other dtypes)

since both scenarios will have dtype == torch.float16, why do we have to check for [torch.uint8, torch.int8]?

if we want to make sure we include the cases where dtype does not return torch.float16 but the model actually includes that dtype, we should check for float16 instead of int8, no?

sayakpaul207 days ago

since both scenarios will have dtype == torch.float16, why do we have to check for [torch.uint8, torch.int8]?

if we want to make sure we include the cases where dtype does not return torch.float16 but the model actually includes that dtype, we should check for float16 instead of int8, no?

I don't think so. For models like SD3 where we can have both float16 and int weights, it should be fine. But for models like Flux where it's purely linear, we won't have ANY float16. This is why module.dtype == torch.float16 or module_has_int_weights is a better condition to check:

yiyixuxu207 days ago

For models like SD3 where we can have both float16 and int weights, it should be fine. But for models like Flux where it's purely linear, we won't have ANY float16.

are you talking about the cases where models contain some float16 but model.dtype won't show as torch.float16, i.e. the case I'm talking about here - shouldn't we check whether the dtype contains torch.float16, instead of int8?

if we want to make sure we include the cases where dtype does not return torch.float16 but the model actually includes that dtype, we should check for float16 instead of int8, no?

sayakpaul207 days ago

What is the best way to check:

  • When a model (after being quantized with bnb) has both float16 params and int params (like SD3)
  • When a model (after being quantized with bnb) only has int params

My idea is to use module.dtype == torch.float16 or module_has_int_weights to cover both; if you have a better idea to cover both scenarios, please suggest.

yiyixuxu207 days ago

in other words, you basically need this check (instead of if model.dtype == torch.float16 or ...)?

            module_has_fp16_weights = any(
                module
                for _, module in module.named_modules()
                if module.weight.dtype == torch.float16
            )
sayakpaul207 days ago

In case of a model like Flux (after being quantized), module_has_fp16_weights will always evaluate to False. How do we handle that case?

yiyixuxu207 days ago
def module_has_fp16_weights(model):
    for t in model.parameters():
        if t.is_floating_point() and t.dtype == torch.float16:
            return True
    return False

would this work?

yiyixuxu207 days ago

actually dtype conversion is not supported for 4-bit and 8-bit - why do we need to consider them at all here?

sayakpaul207 days ago

Good point!

Then we could keep the entire block under if not is_loaded_in_8bit or not is_loaded_in_4bit

Does it work for you?

yiyixuxu207 days ago

what would be a situation where not is_loaded_in_8bit or not is_loaded_in_4bit holds but the weights contain float16 and other dtypes?

sayakpaul207 days ago

My bad it won't occur.

yiyixuxu207 days ago๐Ÿ‘ 1

let's remove that then
I will do one more pass after that; I think we can merge after that

sayakpaul207 days ago (edited 207 days ago)

5d8e844 should resolve this. Thanks for explaining.

Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
if not isinstance(keep_in_fp32_modules, list):
    keep_in_fp32_modules = [keep_in_fp32_modules]

if low_cpu_mem_usage is None:
yiyixuxu209 days ago

actually I think low_cpu_mem_usage won't be None in from_pretrained because we have that default value when we pop it from kwargs, so we just have to raise the error to simplify things

sayakpaul208 days ago

Just covering for edge cases here (I know that the probability is quite low):

  • In case someone explicitly specifies it.
  • Under torch 1.9.0
Conversation is marked as resolved
Show resolved
src/diffusers/models/modeling_utils.py
raise ValueError(
    f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
)
elif torch_dtype is not None:
# When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
# completely lose the effectivity of `use_keep_in_fp32_modules`. `transformers` does
# a global dtype setting (see: https://github.com/huggingface/transformers/blob/fa3f2db5c7405a742fcb8f686d3754f70db00977/src/transformers/modeling_utils.py#L4021),
# but this would prevent us from doing things like https://github.com/huggingface/diffusers/pull/9177/.
yiyixuxu209 days ago
Suggested change
# When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
# completely lose the effectivity of `use_keep_in_fp32_modules`. `transformers` does
# a global dtype setting (see: https://github.com/huggingface/transformers/blob/fa3f2db5c7405a742fcb8f686d3754f70db00977/src/transformers/modeling_utils.py#L4021),
# but this would prevent us from doing things like https://github.com/huggingface/diffusers/pull/9177/.
# When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
# completely lose the effectivity of `use_keep_in_fp32_modules`.

we don't need to talk about transformers here, I don't have time to look into that code, but pretty sure their use_keep_in_fp32_modules works in their models because T5 has these layers

sayakpaul Merge branch 'main' into quantization-config
700b0f3a
sayakpaul fix dtype checks in pipeline.
0ae70fe2
sayakpaul fix warning message.
ecdf1d07
sayakpaul Update src/diffusers/models/modeling_utils.py
aea33981
sayakpaul sayakpaul requested a review from yiyixuxu yiyixuxu 208 days ago
sayakpaul Merge branch 'main' into quantization-config
3a919749
sayakpaul Merge branch 'main' into quantization-config
5d8e8449
sayakpaul mitigate the confusing cpu warning
501a6ba2
ariG23498
ariG23498206 days ago๐ŸŽ‰ 3

Hi folks!

Thanks for working on this. I was able to run the following script on this branch and generate images on my laptop with 8 GB of VRAM.

(screenshot attached: Screenshot from 2024-10-20 13-52-20)

from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
import torch
import gc


def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024


flush()

ckpt_id = "black-forest-labs/FLUX.1-dev"
ckpt_4bit_id = "sayakpaul/flux.1-dev-nf4-pkg"
prompt = "a billboard on highway with 'FLUX under 8' written on it"

text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    ckpt_4bit_id,
    subfolder="text_encoder_2",
)

pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder_2=text_encoder_2_4bit,
    transformer=None,
    vae=None,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()


with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=256
    )


pipeline = pipeline.to("cpu")
del pipeline

flush()


transformer_4bit = FluxTransformer2DModel.from_pretrained(ckpt_4bit_id, subfolder="transformer")
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

print("Running denoising.")
height, width = 512, 768
images = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=5.5,
    height=height,
    width=width,
    output_type="pil",
).images
images[0].save("output.png")

(generated image: output.png)

yiyixuxu
yiyixuxu approved these changes on 2024-10-20
yiyixuxu206 days ago

let's merge this!

I asked @DN6 to open a follow-up PR for this #9213 (comment),

sayakpaul Merge branch 'main' into quantization-config
1a931cb5
sayakpaul
sayakpaul206 days ago

PR merge contingent on #9720.

sayakpaul Merge branch 'main' into quantization-config
2fa8fb91
sayakpaul sayakpaul merged b821f006 into main 206 days ago
sayakpaul sayakpaul deleted the quantization-config branch 206 days ago
DN6
DN6 commented on 2024-10-16
src/diffusers/quantizers/quantization_config.py
@dataclass
class BitsAndBytesConfig(QuantizationConfigMixin):
DN6210 days ago

Something to consider. Let's assume you want to use a quantized transformer model in your code. With this naming, you would always need to set up imports in the following way.

from transformers import BitsAndBytesConfig
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

Not a huge issue. Just giving a heads up in case you want to consider renaming the config to something like DiffusersBitsAndBytesConfig.

src/diffusers/models/model_loading_utils.py
set_module_kwargs["dtype"] = dtype

# bnb params are flattened.
if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
DN6210 days ago

In this situation, aren't we skipping parameter shape checks for bnb loaded weights entirely? What happens when one attempts to load bnb weights but the flattened shape is incorrect?

Perhaps we add a check_quantized_param_shape method to the DiffusersQuantizer base class. And in the BnBQuantizer we can check if the shape matches the rule here:
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L816
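For illustration, a rough sketch of what that hook could look like, assuming the 4-bit packing rule from the linked bitsandbytes code (two 4-bit values per uint8 byte, flattened to ((n + 1) // 2, 1)). Class and method names below are placeholders for the idea, not the API added in this PR:

class DiffusersQuantizerSketch:
    """Illustrative base-class hook, not the final API."""

    def check_quantized_param_shape(self, param_name, current_param, loaded_param):
        # Default behavior: shapes must match exactly, as for unquantized weights.
        return current_param.shape == loaded_param.shape


class BnB4BitQuantizerSketch(DiffusersQuantizerSketch):
    def check_quantized_param_shape(self, param_name, current_param, loaded_param):
        # bnb packs two 4-bit values per uint8 byte and flattens the tensor,
        # so an n-element parameter is serialized with shape ((n + 1) // 2, 1).
        n = current_param.numel()
        expected = ((n + 1) // 2, 1)
        if tuple(loaded_param.shape) != expected:
            raise ValueError(
                f"{param_name} has flattened 4-bit shape {tuple(loaded_param.shape)}, "
                f"but {expected} was expected."
            )
        return True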

src/diffusers/models/model_loading_utils.py
156217 f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
157218 )
158219
159 if accepts_dtype:
160 set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
220
if not is_quantized or (
221
not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)
222
):
223
if accepts_dtype:
224
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
225
else:
226
set_module_tensor_to_device(model, param_name, device, value=param)
161227
else:
162
set_module_tensor_to_device(model, param_name, device, value=param)
228
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
229
DN6210 days ago

Small nit. IMO this is a bit more readable

        if is_quantized and hf_quantizer.check_quantized_param(
            model, param, param_name, state_dict, param_device=device
        ):
            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
        else:
            if accepts_dtype:
                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
            else:
                set_module_tensor_to_device(model, param_name, device, value=param)
src/diffusers/quantizers/base.py
134 """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
135 return max_memory
136
137
def check_quantized_param(
DN6210 days ago๐Ÿ‘ 1

IMO check_is_quantized_param or check_if_quantized_param would more explicitly convey what this method does.

tests/quantization/bnb/test_4bit.py
class BnB4BitBasicTests(Base4bitTests):
    def setUp(self):
DN6210 days ago๐Ÿ‘ 1

Would clear the CUDA cache in setUp as well.
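As a rough sketch of that suggestion (not the exact test code in this PR), the setup could clear the cache before each test:

import gc

import torch


class BnB4BitBasicTests(Base4bitTests):
    def setUp(self):
        # Start each test from a clean slate so memory measurements are not
        # skewed by allocations left over from earlier tests.
        gc.collect()
        torch.cuda.empty_cache()
        # ... the existing model/config setup would follow here ...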

Ednaordinary
Ednaordinary193 days ago

It would be useful to rename llm_int8_skip_modules, or otherwise make it clearer that it is respected in both 4-bit and 8-bit mode. Currently the docs make it sound as though skipped modules are only respected in 8-bit mode, while the actual implementation suggests otherwise:

if self.quantization_config.llm_int8_skip_modules is not None:

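For context, this is the kind of usage being discussed: passing llm_int8_skip_modules alongside a 4-bit config so the listed modules stay un-quantized (the module name below is just an example):

import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    # Despite the "int8" in its name, this list is also honored in 4-bit mode.
    llm_int8_skip_modules=["proj_out"],
)
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)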

sayakpaul
sayakpaul193 days ago๐Ÿ‘ 1

Yeah, I think the documentation should reflect this. I guess this is safe to do, @SunMarc?

SunMarc
SunMarc191 days ago

Yeah, we should do that. Would you like to update this, @Ednaordinary? We should also do it in transformers when this gets merged.

Ednaordinary
Ednaordinary190 days ago

Sure, @SunMarc. I'll make a PR when I'm able. Should I refactor the parameter name and include a deprecation notice, or just include a note in the docs?

MartinRad52
MartinRad5269 days ago

@sayakpaul Hi Sayak, can you please help with this? It's related to your post. I'm trying to use the quantized Flux model (https://colab.research.google.com/gist/sayakpaul/8fb27a653934c1bc6b013913c346e456/scratchpad.ipynb#scrollTo=YxbQEPd1_Tqf) with the FLUX outpainting model (https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev). Here is my script:

import torch
from diffusers import FluxFillPipeline, FluxTransformer2DModel
from diffusers.utils import load_image
from transformers import T5EncoderModel
import gc

image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png")

nf4_model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"

prompt = "a cute dog in paris photoshoot"


def flush():
    """Wipes off memory."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


flush()

text_encoder_2 = T5EncoderModel.from_pretrained(
    nf4_model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
transformer = FluxTransformer2DModel.from_pretrained(
    nf4_model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    text_encoder_2=text_encoder_2,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="a white paper cup",
    image=image,
    mask_image=mask,
    height=1632,
    width=1232,
    guidance_scale=30,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]

torch.cuda.empty_cache()
memory = bytes_to_giga_bytes(torch.cuda.memory_allocated())
print(f"{memory=} GB.")
image.save("flux-fill-dev.png")


But I get this error:

462 output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
464 # 3. Save state
465 ctx.state = quant_state

RuntimeError: mat1 and mat2 shapes cannot be multiplied (7854x384 and 64x3072)

sayakpaul
sayakpaul69 days ago

You are using the wrong checkpoint for fill. It should be https://huggingface.co/diffusers/FLUX.1-Fill-dev-nf4.
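Loading from that repo could look roughly like the snippet below, assuming it follows the same transformer/text_encoder_2 subfolder layout as the NF4 package used in the script above (a sketch, not verified against the repo contents):

import torch
from diffusers import FluxFillPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel

nf4_fill_id = "diffusers/FLUX.1-Fill-dev-nf4"

# Assumed subfolder names; check the repository layout if loading fails.
text_encoder_2 = T5EncoderModel.from_pretrained(
    nf4_fill_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
transformer = FluxTransformer2DModel.from_pretrained(
    nf4_fill_id, subfolder="transformer", torch_dtype=torch.bfloat16
)

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    text_encoder_2=text_encoder_2,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()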
