Thanks for adding this! I see that you used a lot of things from transformers. Do you think it is possible to import these (or inherit) from transformers? This would help reduce the maintenance burden. I'm also fine doing it this way, since there are not too many follow-up PRs after a quantizer has been added. About the `HfQuantizer` class: there are a lot of methods that were created to fit the transformers structure. I'm not sure we will need every one of these methods in diffusers. Of course, we can still do a follow-up PR to clean up.
@SunMarc I am guilty as charged but we don't have transformers as a hard dependency for loading models in Diffusers. Pinging @DN6 to seek his opinion.
Update: Chatted with @DN6 as well. We think it's better to redefine it inside diffusers without the transformers-specific bits, which we can clean up in this PR.
@SunMarc I think this PR is ready for another review.
Thanks for adding this, @sayakpaul!
I don't think it makes sense to have this as a separate PR just to add a base class, because it's hard to understand what methods are needed - we should only introduce a minimal base class and gradually add functionality as needed.
Can we have a PR with a minimal working example?
Okay, so, do you want me to add everything needed for the bitsandbytes integration in this PR? But do note that this won't be very different from what we have in transformers.
@sayakpaul
I think so because:
Sometimes we can make a feature branch where a bunch of PRs can be merged before one big honkin' PR is pushed to main at the end, and the pieces are all individually reviewed and can be tested. Is this a viable approach for including quantization?
Okay I will update this branch. @yiyixuxu
cc @MekkCyber for visibility
Just a few considerations for the quantization design.
I would say the initial design should start with loading/inference at just the model level and then proceed to add functionality (pipeline-level loading etc.).

The feature needs to perform the following functions:
- `from_pretrained`
- `from_single_file`
At the moment, the most common ask seems to be the ability to load models into GPU using the FP8 dtype and run inference in a supported dtype by dynamically upcasting the necessary layers. NF4 is another format that's gaining attention.
So perhaps we should focus on this first. This mostly applies to the DiT models, but large models like CogVideo might also benefit from this approach.
Some example quantized versions of models that have been doing the rounds
To cover these initial cases, we can rely on Quanto (FP8) and BitsandBytes (NF4).
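For context, a minimal, hedged sketch (not part of this PR; the toy model and sizes are made up) of what that dynamic upcasting looks like with optimum-quanto today:

import torch
from optimum.quanto import freeze, qfloat8, quantize

# Toy model standing in for a DiT block; quanto stores the weights in float8
# and upcasts them during compute.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 8))
quantize(model, weights=qfloat8)
freeze(model)

out = model(torch.randn(1, 64))
print(out.dtype)  # activations stay in the compute dtype (float32 here)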
Example API:
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, DiffusersQuantoConfig

# Load the model in FP8 with Quanto and perform compute in the configured dtype.
quantization_config = DiffusersQuantoConfig(weights="float8", compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    "<either diffusers format or quanto format weights>", quantization_config=quantization_config
)
pipe = FluxPipeline.from_pretrained("...", transformer=transformer)
The quantization config should probably take the following arguments
DiffusersQuantoConfig(
    weights_dtype="",  # dtype to store weights
    compute_dtype="",  # dtype to perform inference
    skip_quantize_modules=["ResBlock"],
)
I think initially we can rely on the dynamic upcasting operations performed by Quanto and BnB under the hood to start and then expand on them if needed.
Some other considerations:

- Add a `Diffusers` prefix to the quantization configs, e.g. `DiffusersQuantoConfig`, so that when we import quantization configs from transformers there aren't any conflicts.
- Supporting checkpoints saved directly with `safetensors.torch.save_file(model.to(torch.float8_e4m3fn), "model-fp8.safetensors")` and loading full pipeline single file checkpoints. But we can address these later.

This PR will be at the model-level itself. And we should not add multiple backends in a single PR. This PR aims to add `bitsandbytes`. We can do other backends taking this PR as a reference. I would like us to mutually agree on this before I start making progress on this PR.
Concretely, I would like to stick to the outline of the changes laid out in #9174 (along with anything related) for this PR.
The feature needs to perform the following functions
I won't advocate doing all of that in a single PR because it makes things very hard to review. We would rather move faster with something more minimal and confirm its effectiveness.
Allow loading/inference with LoRAs in these quantized models. (This we have to figure out in more detail)
Well, note that if the underlying LoRA wasn't trained with the base quantization precision, it might not perform as expected.
So perhaps we should focus on this first. This mostly applies to the DiT models, but large models like CogVideo might also benefit from this approach.
Please note that `bitsandbytes`-related quantization mostly applies to `nn.Linear`, whereas `quanto` is broader in scope (i.e., `quanto` can be applied to an `nn.Conv2d` as well).
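To make that scope difference concrete, a small, hedged sketch (the helper name is made up) showing that a bnb-style conversion would only pick up `nn.Linear` layers:

import torch.nn as nn

def modules_bnb_would_quantize(model: nn.Module):
    # bnb-style conversion only replaces nn.Linear layers
    return [name for name, module in model.named_modules() if isinstance(module, nn.Linear)]

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.Flatten(), nn.Linear(8 * 30 * 30, 16))
print(modules_bnb_would_quantize(model))  # ['2'] - only the Linear; the Conv2d is skipped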
This PR will be at the model-level itself. And we should not add multiple backends in a single PR. This PR aims to add bitsandbytes. We can do other backends taking this PR as a reference. I would like us to mutually agree on this before I start making progress on this PR.
Sounds good to me.
For this PR let's do:
  _supports_gradient_checkpointing = False
  _keys_to_ignore_on_load_unexpected = None
  _no_split_modules = None
+ _keep_in_fp32_modules = []
We have to introduce this attribute now that we're seriously entering the diffusion territory.
If I load a LoRA after quantization, it throws an error:

ValueError: `.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
pipe.load_lora_weights(
    hf_hub_download(repo_name, ckpt_name), adapter_name="ckpt_name"
)
@chuck-ma there's a reason why this PR is still in draft :) We will consider these use cases a bit later in the pipeline.
However, you are welcome to try out basic functionalities like loading and saving without LoRAs.
Gotcha, thanks.
I got this error:
TypeError Traceback (most recent call last)
File ~/autodl-tmp/diffusers/src/diffusers/models/model_loading_utils.py:134, in load_state_dict(checkpoint_file, variant)
133 try:
--> 134 file_extension = os.path.basename(checkpoint_file).split(".")[-1]
135 if file_extension == SAFETENSORS_FILE_EXTENSION:
File ~/miniconda3/lib/python3.10/posixpath.py:142, in basename(p)
141 """Returns the final component of a pathname"""
--> 142 p = os.fspath(p)
143 sep = _get_sep(p)
TypeError: expected str, bytes or os.PathLike object, not NoneType
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
Cell In[7], line 11
4 model_id = "black-forest-labs/FLUX.1-dev"
6 nf4_config = BitsAndBytesConfig(
7 load_in_4bit=True,
8 bnb_4bit_quant_type="nf4",
9 bnb_4bit_compute_dtype=torch.bfloat16,
10 )
---> 11 model_nf4 = FluxTransformer2DModel.from_pretrained(
12 model_id, subfolder="transformer", quantization_config=nf4_config
13 )
14 print(model_nf4.dtype)
15 print(model_nf4.quantization_config)
File ~/miniconda3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:114, in validate_hf_hub_args.._inner_fn(*args, **kwargs)
111 if check_use_auth_token:
112 kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.name, has_token=has_token, kwargs=kwargs)
--> 114 return fn(*args, **kwargs)
File ~/autodl-tmp/diffusers/src/diffusers/models/modeling_utils.py:817, in ModelMixin.from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
815 if device_map is None and not is_sharded or (hf_quantizer is not None):
816 param_device = "cpu"
--> 817 state_dict = load_state_dict(model_file, variant=variant)
818 model._convert_deprecated_attention_blocks(state_dict)
819 # move the params from meta device to cpu
File ~/autodl-tmp/diffusers/src/diffusers/models/model_loading_utils.py:146, in load_state_dict(checkpoint_file, variant)
144 except Exception as e:
145 try:
--> 146 with open(checkpoint_file) as f:
147 if f.read().startswith("version"):
148 raise OSError(
149 "You seem to have cloned a repository without having git-lfs installed. Please install "
150 "git-lfs and run git lfs install
followed by git lfs pull
in the folder "
151 "you cloned."
152 )
TypeError: expected str, bytes or os.PathLike object, not NoneType
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
model_id = "black-forest-labs/FLUX.1-dev"
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model_nf4 = FluxTransformer2DModel.from_pretrained(
model_id, subfolder="transformer", quantization_config=nf4_config
)
print(model_nf4.dtype)
print(model_nf4.quantization_config)
If I don't use the quantization config, everything is just fine:
import torch
from diffusers import FluxTransformer2DModel
model_id = "black-forest-labs/FLUX.1-dev"
model_nf4 = FluxTransformer2DModel.from_pretrained(
model_id, subfolder="transformer",
)
print(model_nf4.dtype)
@chuck-ma do you wanna give this a try now?
Maybe there's something wrong with your setup but I am able to load things without any issues:
https://colab.research.google.com/gist/sayakpaul/1bfcf2441f73364bb06c801f58303cd5/scratchpad.ipynb
I see you also confirmed it's working here.
@chuck-ma do you wanna give this a try now?
Now it works.
  # 4. Give nice warning if unexpected values have been passed
- if len(config_dict) > 0:
+ only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict
Because `quantization_config` isn't a part of any model's `__init__()`.
I think it is better to not add it to `config_dict` if it is not going into `__init__`, i.e. at line 511:

# remove private attributes
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
# remove quantization_config
config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}
We cannot remove `quantization_config` from the config of a model, as that would prevent loading of the quantized models via `from_pretrained()`.

`quantization_config` isn't used for initializing a model; it's used to determine what kind of quantization configuration to inject inside the given model. This is why it's only used in `from_pretrained()` of `ModelMixin`.

LMK if you have a better idea to handle it.
We do not remove them from the config, we just don't add them to the `config_dict` inside this `extract_init_dict` method. Basically, the `config_dict` in this function goes through these steps:

1. It is used to build `init_dict`: the quantization config will not go there, so it is not affected if we do not add it to `config_dict`.
2. It is used to warn about attributes that did not make it into `init_dict`: if the quantization config is not there, we do not need to throw a warning for it.
3. It is used to build `unused_kwargs` - so I think this is the only difference it would make. Do we need the quantization config to be in the `unused_kwargs` returned by `extract_init_dict`? I think `unused_kwargs` is only used to send additional warnings for unexpected stuff, but since the quantization config is expected, and we have already decided not to send a warning here inside `extract_init_dict`, I think it does not need to go to the `unused_kwargs` here?

@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
    ...
    config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

    # remove private attributes
    config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

+   # remove quantization_config
+   config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}

    ## here we use config_dict to create `init_dict` which will be passed to `__init__` method
    init_dict = {}
    for key in expected_keys:
        ...
        init_dict[key] = config_dict.pop(key)

-   only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict
-   if len(config_dict) > 0 and not only_quant_config_remaining:
+   if len(config_dict) > 0:
        logger.warning(
            f"The config attributes {config_dict} were passed to {cls.__name__}, "
            "but are not expected and will be ignored. Please verify your "
            f"{cls.config_name} configuration file."
        )
    ....

    # 6. Define unused keyword arguments
    unused_kwargs = {**config_dict, **kwargs}

    return init_dict, unused_kwargs, hidden_config_dict
Makes sense. Resolved in 555a5ae.
+     keep_in_fp32_modules=None,
  ) -> List[str]:
-     device = device or torch.device("cpu")
+     device = device or torch.device("cpu") if hf_quantizer is None else device
More on this in the later changes.
else:
    param = param.to(dtype)

is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
if not is_quantized and not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
Because bnb quantized params are usually flattened.
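A quick, hedged illustration of the "flattened" point (the shapes are hypothetical; this is not the PR's code): a 4-bit bnb weight with logical shape (64, 64) is stored packed, two values per byte, so its stored shape can never match the expected one and the usual shape check has to be skipped:

import torch

expected_shape = torch.Size([64, 64])
packed_4bit = torch.empty((64 * 64) // 2, 1, dtype=torch.uint8)  # two 4-bit values per byte

is_quant_method_bnb = True  # what the added check detects via `quantization_method`
if not is_quant_method_bnb and expected_shape != packed_4bit.shape:
    raise ValueError("shape mismatch")  # never reached for bnb-quantized params
print(packed_4bit.shape)  # torch.Size([2048, 1]) vs. the expected torch.Size([64, 64])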
  from ..models import AutoencoderKL
  from ..models.attention_processor import FusedAttnProcessor2_0
  from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
+ from ..quantizers.bitsandbytes.utils import _check_bnb_status
The reason I preferred not having `_check_bnb_status()` inline is that I imagine it'd be used across the library, so it didn't make sense to include it inside a function.
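As a rough sketch only (the actual helper's return signature may differ), such a shared utility could look along these lines, reporting whether a module is bnb-quantized and in which mode:

def check_bnb_status_sketch(module):
    # Illustrative only: look at attributes a bnb-quantized model would carry.
    is_loaded_in_4bit = getattr(module, "is_loaded_in_4bit", False)
    is_loaded_in_8bit = getattr(module, "is_loaded_in_8bit", False)
    return is_loaded_in_4bit or is_loaded_in_8bit, is_loaded_in_4bit, is_loaded_in_8bit

class _DummyQuantizedModel:
    is_loaded_in_4bit = True

print(check_bnb_status_sketch(_DummyQuantizedModel()))  # (True, True, False)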
Wow, big PR, great feature being added here.
I haven't done an in-depth review, but took a look at the parts related to PEFT and skimmed the rest.
With such a big change, it might be worth it to control the line coverage of the newly added tests to ensure that the new code is reasonably well covered.
+     keep_in_fp32_modules=None,
  ) -> List[str]:
-     device = device or torch.device("cpu")
+     device = device or torch.device("cpu") if hf_quantizer is None else device
Not specific to this PR, but `device = device or torch.device("cpu")` is a bit dangerous because, theoretically, `0` is a valid device but it would be considered falsy. AFAICT it's not problematic for the existing code, but something to keep in mind.
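A minimal sketch of the pitfall:

import torch

device = 0  # a perfectly valid CUDA device index
resolved = device or torch.device("cpu")
print(resolved)  # cpu - the 0 is treated as falsy and silently dropped

resolved = torch.device("cpu") if device is None else device  # explicit None check instead
print(resolved)  # 0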
Indeed.
I have added a comment about it too.
318 | """ | ||
319 | Converts a quantized model into its dequantized original version. The newly converted model will have some | ||
320 | performance drop compared to the original model before quantization - use it only for specific usecases such as | ||
321 | QLoRA adapters merging. |
Note that PEFT supports merging into bnb weights, so that alone would not require dequantizing the weights entirely.
Noted. I guess not immediately relevant for this PR?
I think it is still interesting to let users have a way to dequantize their models.
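A hedged, CPU-runnable illustration of that point with a toy module (not a bnb model): PEFT's `merge_and_unload()` folds the adapter back into the base weights, and PEFT supports the same call on bnb layers, so a full manual dequantization isn't strictly required just for merging:

import torch
from peft import LoraConfig, get_peft_model

base = torch.nn.Sequential(torch.nn.Linear(16, 16))
peft_model = get_peft_model(base, LoraConfig(target_modules=["0"], r=4))
merged = peft_model.merge_and_unload()

# No LoRA parameters remain; the adapter has been folded into the Linear weight.
print(any("lora" in name for name, _ in merged.named_parameters()))  # False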
Thanks for your efforts. I think it's better if we can load the transformer that has been quantized instead of quantizing the transformer every time we load it. @sayakpaul
Possible now:
from diffusers import FluxTransformer2DModel
model_id = "sayakpaul/flux.1-dev-nf4-pkg"
model_nf4 = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
Not sure what made you think it's not possible.
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
pipe.transformer.save_pretrained(flux_transformer_id)
model_nf4 = FluxTransformer2DModel.from_pretrained(
flux_transformer_id,
# quantization_config=nf4_config,
torch_dtype=dtype,
)
I just got an error (no matter whether I use quantization_config):
ValueError: Cannot load <class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> from /root/autodl-tmp/flux_transformer because the following keys are missing:
transformer_blocks.11.ff_context.net.0.proj.weight, transformer_blocks.3.attn.add_k_proj.bias, transformer_blocks.6.ff_context.net.2.weight, single_transformer_blocks.27.attn.to_k.bias, transformer_blocks.15.attn.to_out.0.weight, single_transformer_blocks.35.proj_out.bias, time_text_embed.guidance_embedder.linear_1.bias, transformer_blocks.2.attn.to_add_out.bias, transformer_blocks.3.attn.add_v_proj.weight, transformer_blocks.6.ff.net.2.weight,
@chuck-ma please try to follow the Colab Notebooks provided in https://hf.co/sayakpaul/flux.1-dev-nf4-pkg. All of them show the correct usage and run without any errors. And when you're facing errors, please try to provide Colab Notebooks so I can verify things. Otherwise, it's hard for me to reproduce errors. Could we do that?
from diffusers import FluxPipeline
import torch
from huggingface_hub import hf_hub_download
repo_name = "ByteDance/Hyper-SD"
ckpt_16steps_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"
create_fuse_checkp = True
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=dtype,
)
if create_fuse_checkp:
model_nf4 = FluxTransformer2DModel.from_pretrained(
# flux_transformer_id,
# flux_transformer_id,
model_id,
subfolder="transformer",
quantization_config=nf4_config,
torch_dtype=dtype,
)
pipe.transformer = model_nf4
pipe.load_lora_weights(
hf_hub_download(repo_name, ckpt_16steps_name), adapter_name="8_steps_lora"
)
pipe.fuse_lora(lora_scale=0.125)
pipe.transformer.save_pretrained(flux_transformer_id)
If I merge the LoRA and then save the transformer, I get the error. Otherwise, everything is just fine.
I told you here that LoRAs can be tried out later. So, please be aware of the expectations.
OK. Because I saw that your latest code supports merging a LoRA after loading NF4, I wanted to try whether it is possible to save and load after merging the LoRA.
Anyway, nice job.
I think it'll be different if you quantize after merging.
Reviewed half of the PR! I will do the rest soon, but since it is mainly config, there shouldn't be any big blockers. Thanks for the PR @sayakpaul! I left a few comments.
but since it is mainly config, there shouldn't be any big blockers.
Did you mean your comments are associated with configs? Understood if that is the case, but this PR attempts to add full-fledged BnB loading support, not just configs.
Thanks for this huge work @sayakpaul! I'll do one final review after you add the 8-bit tests. This looks very good. I left a few minor comments.
Did you mean your comments are associated with configs? Understood if that is the case, but this PR attempts to add full-fledged BnB loading support, not just configs.
I was talking about the rest of the PR, but yeah, there were also the quantizer + tests. I went through the entire PR now.
I know LoRAs aren't a part of this PR, but what's the current situation with loading LoRAs on top of an NF4 model without dequantizing?
It looks like converting to PEFT would work? I assume there is a way to load on top of the quantized model without having to dequantize it, because Forge doesn't seem to be doing this.
I figure this is the best place to leave a comment, but sorry if it's slightly off topic for the PR.
@itsyourlad no worries. After this PR, the next plan is to add a training script showing how to do LoRAs on NF4 models with PEFT. So, stay tuned.
Thanks for iterating @sayakpaul! LGTM! It's nice to finally have quantization integrated in diffusers!
## When to use what?

This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
Yes I think it will be nice to also have a table directly in this doc in the future
@yiyixuxu this is ready for your review.
Thanks, this looks really good!
  )
- if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
+ pipeline_has_8bit_bnb_quant = any(_check_bnb_status(module)[-1] for _, module in self.components.items())
+ if (
+     not pipeline_has_8bit_bnb_quant
Can you explain why we are adding this `not pipeline_has_8bit_bnb_quant` check here?
  and str(device) in ["cpu"]
  and not silence_dtype_warnings
  and not is_offloaded
+ and not is_loaded_in_4bit_bnb
Why do we add this check here?

If the model is in 4-bit, the dtype should not be `torch.float16` to begin with, no? Even if the dtype shows up as `torch.float16` somehow, I think the warning still holds, i.e. if the weights are in float16, even though we can move it to CPU, we should not.
`bnb` (and many others like `torchao`) only applies to `torch.nn.Linear` layers.

`module.dtype` only reflects the first parameter, because we only check for the first parameter here:

Now, for a model (like Cog) where the first layer has a Conv (patch embedding layer), bnb won't be applicable there, the layer dtype would be `torch.float16`, for example, and `model.dtype` would return `torch.float16`.

if the weights are in float16, even though we can move it to CPU, we should not

So, in the above scenario, not all the weights are in float16; only a tiny fraction is.

This is why I added this check.
But I think your concern is also valid. So, LMK if, for this PR, we should:

1. keep the `not is_loaded_in_4bit_bnb` check and add a note about `dtype`, or
2. update how the model's `dtype` is determined.

I think option 1 should be okay.
OK, I think the scenario you described, where the dtype is float16 but the model contains int8, does not matter here, and we should still send this warning regardless.

However, I'm more concerned about the opposite scenario, where the model contains both float16 and 4-bit/8-bit weights and the dtype shows up as 4-bit/8-bit; in that case we should still send a warning when the user tries to move it to CPU, but we won't do that here based on the current implementation.

The solution should be to update our `get_parameter_dtype` to return a floating-point dtype if one is present, so I think option 2 here?
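A small sketch of what option 2 could look like (hypothetical helper, not the final code): prefer the first floating-point dtype found instead of blindly returning the first parameter's dtype:

import torch

def get_parameter_dtype_sketch(model: torch.nn.Module) -> torch.dtype:
    last_dtype = None
    for param in model.parameters():
        last_dtype = param.dtype
        if param.is_floating_point():
            return param.dtype  # prefer a floating-point dtype if one exists
    return last_dtype  # e.g. a fully-quantized, int-only model

class MixedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # a packed int8 "weight" first (like a bnb-quantized layer), then a regular fp16 layer
        self.packed = torch.nn.Parameter(torch.zeros(8, dtype=torch.int8), requires_grad=False)
        self.proj = torch.nn.Linear(4, 4).to(torch.float16)

print(get_parameter_dtype_sketch(MixedModel()))  # torch.float16, not torch.int8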
  f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
  )
- elif torch_dtype is not None:
+ elif torch_dtype is not None and hf_quantizer is None:
Ohhh, I think we cannot do `model.to(torch_dtype)` here now that we are supporting `_keep_in_fp32_modules` - it will just convert all layers to `torch_dtype` again.

I don't think `keep_in_fp32_modules` is supported yet in the code path when `low_cpu_mem_usage=False` - so let's maybe add an error message/warning for that too.
Done.
if (low_cpu_mem_usage is None or not low_cpu_mem_usage) and cls._keep_in_fp32_modules is not None:
low_cpu_mem_usage = True
logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
Even when `low_cpu_mem_usage` is `True`, we cannot do `model = model.to(torch_dtype)` here anymore. I think we just have to make sure the dtype conversion is handled properly (with `keep_in_fp32_modules`) in each code path under `if low_cpu_mem_usage`.

A dummy example where `_keep_in_fp32_modules` is ignored:
import torch
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin
class DummyModel(ModelMixin, ConfigMixin):
_keep_in_fp32_modules = ["layer2"]
def __init__(self):
super().__init__()
self.layer1 = torch.nn.Linear(10, 20)
self.layer2 = torch.nn.Linear(20, 30)
self.layer3 = torch.nn.Linear(30, 40)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x
# Create an instance of the model
model = DummyModel()
model.save_pretrained("dummy_model")
model = DummyModel.from_pretrained("dummy_model", torch_dtype=torch.float16)
I think I have addressed it already. But LMK if you think otherwise.
so this example I provided here https://github.com/huggingface/diffusers/pull/9213/files#r1782037619
import torch
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin
class DummyModel(ModelMixin, ConfigMixin):
_keep_in_fp32_modules = ["layer2"]
def __init__(self):
super().__init__()
self.layer1 = torch.nn.Linear(10, 20)
self.layer2 = torch.nn.Linear(20, 30)
self.layer3 = torch.nn.Linear(30, 40)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x
# Create an instance of the model
model = DummyModel()
model.save_pretrained("dummy_model")
model = DummyModel.from_pretrained("dummy_model", torch_dtype=torch.float16)
print(model.layer2.weight.dtype)
It will print out `torch.float16`, even though we have `_keep_in_fp32_modules = ["layer2"]`, so layer2 should be kept in float32, no?
Yes, you're right. 81bb48a works for you?
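For completeness, a minimal sketch (a hypothetical helper, not the commit's actual fix) of casting with `_keep_in_fp32_modules` respected, as opposed to a blanket `model.to(torch_dtype)`:

import torch

def cast_keeping_fp32_modules(model, torch_dtype, keep_in_fp32_modules):
    for name, module in model.named_modules():
        if any(fp32_name in name for fp32_name in keep_in_fp32_modules):
            module.to(torch.float32)
        elif not list(module.children()):  # cast only leaf modules
            module.to(torch_dtype)
    return model

model = torch.nn.ModuleDict({"layer1": torch.nn.Linear(10, 20), "layer2": torch.nn.Linear(20, 30)})
model = cast_keeping_fp32_modules(model, torch.float16, ["layer2"])
print(model["layer1"].weight.dtype, model["layer2"].weight.dtype)  # torch.float16 torch.float32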
        Additional parameters from which to initialize the configuration object.
    """

    _exclude_attributes_at_init = ["_load_in_4bit", "_load_in_8bit", "quant_method"]
What is this?
Cc'ing @SunMarc.
We don't need these attributes when initializing a BnB quantization configuration class, but we need them for subsequent operations.
for k, v in state_dict.items():
    # `startswith` to counter for edge cases where `param_name`
    # substring can be present in multiple places in the `state_dict`
    if param_name + "." in k and k.startswith(param_name):
`k.split('.')[0] == param_name`?
Do you mean `if param_name + "." in k and k.split('.')[0] == param_name:`?

I think `if param_name + "." in k and k.startswith(param_name)` is the same as `k.split('.')[0] == param_name`, because if `k.split('.')[0] == param_name` is True -> `param_name + "." in k` is also True. Is that not the case?
I changed the code with your suggestion and the assertions failed. I didn't dig deeper and I think it's okay to keep it as is because it's mostly a nit, really.
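For the record, one case where the two conditions genuinely differ (and which likely explains the failing assertions) is when `param_name` is itself a dotted path:

param_name = "transformer_blocks.0.attn.to_q"
k = "transformer_blocks.0.attn.to_q.weight"

print(param_name + "." in k and k.startswith(param_name))  # True
print(k.split(".")[0] == param_name)  # False: only "transformer_blocks" is compared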
# Unlike `transformers`, we don't know if we should always keep certain modules in FP32
# in case of diffusion transformer models. For language models and others alike, `lm_head`
# and tied modules are usually kept in FP32.
self.modules_to_not_convert = list(filter(None.__ne__, self.modules_to_not_convert))
Can you provide examples of when this list would contain `None`?
It is configured via `llm_int8_skip_modules` within the `BitsAndBytesConfig` object. It defaults to `None` in our case because, unlike with language models, we don't know if there's a requirement for a default.
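Roughly how the `None` ends up in the list (a sketch of the flow, not the exact quantizer code; the module name is illustrative): the default `llm_int8_skip_modules=None` gets wrapped into a list and later extended, so it has to be filtered out afterwards:

llm_int8_skip_modules = None  # the default in our case
modules_to_not_convert = llm_int8_skip_modules
if not isinstance(modules_to_not_convert, list):
    modules_to_not_convert = [modules_to_not_convert]  # -> [None]
modules_to_not_convert.extend(["proj_out"])  # e.g. modules added later

modules_to_not_convert = list(filter(None.__ne__, modules_to_not_convert))
print(modules_to_not_convert)  # ['proj_out'] - the None is gone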
@yiyixuxu thanks for your reviews. I think they were very nice and helpful. I have gone ahead and re-run the tests on `audace` and everything is green.
I have addressed your comments and made changes. PTAL.
Hi, looks like everything is great. Don't know why approving review is still processing.
  logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
  return

+ hf_quantizer = getattr(self, "hf_quantizer", None)
+ quantization_serializable = (
+     hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable
+ )
+
+ if hf_quantizer is not None and not quantization_serializable:
+     raise ValueError(
+         f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+         " the logger on the traceback to understand the reason why the quantized model is not serializable."
+     )
- hf_quantizer = getattr(self, "hf_quantizer", None)
- quantization_serializable = (
-     hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable
- )
-
- if hf_quantizer is not None and not quantization_serializable:
-     raise ValueError(
-         f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
-         " the logger on the traceback to understand the reason why the quantized model is not serializable."
-     )
+ if hf_quantizer is not None:
+     quantization_serializable = isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable
+     if not quantization_serializable:
+         raise ValueError(
+             f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+             " the logger on the traceback to understand the reason why the quantized model is not serializable."
+         )
Not sure if we can remove `hf_quantizer = getattr(self, "hf_quantizer", None)`. And we need the `hf_quantizer is not None` check because we access properties from it in the error thrown just two lines below.
I've added this into SimpleTuner and it's a bit funky, but it works for training as well, with few modifications other than the loading of the base model and the casting-to-dtype change.
For testing elsewhere, an NF4-trained LyCORIS: https://huggingface.co/RareConcepts/FluxDev-LoKr-beavisandbutthead-nf4
Note that the inference speed with NF4 is noticeably slower once an adapter is thrown on top.
@yiyixuxu ready for another review. Have run the tests too, and they pass.
    )
    and dtype == torch.float16
):
    dtype = torch.float32
Ohh, we should not change `dtype` here, e.g. if it is float16: we change it here because we hit a parameter that we need to upcast, but then it changes the `dtype` for all the remaining parameters in the state dict too; it should remain float16 when it goes to the next loop iteration. We need to just pass float32 to `set_module_tensor_to_device` (if it accepts `dtype`) without changing this variable.
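A small, runnable sketch (simplified, not the PR's exact code) of that suggestion: compute the upcast dtype per parameter and pass it to `set_module_tensor_to_device`, leaving the shared `dtype` variable untouched for the remaining iterations:

import torch
from accelerate.utils import set_module_tensor_to_device

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
dtype = torch.float16
keep_in_fp32_modules = ["1"]  # keep the second Linear in fp32

for param_name, param in state_dict.items():
    # local override so `dtype` stays float16 for the remaining parameters
    needs_fp32 = any(m in param_name.split(".") for m in keep_in_fp32_modules)
    param_dtype = torch.float32 if needs_fp32 else dtype
    set_module_tensor_to_device(model, param_name, "cpu", value=param, dtype=param_dtype)

print(model[0].weight.dtype, model[1].weight.dtype)  # torch.float16 torch.float32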
Does ff8ddef work for you?
Very insightful comments, @yiyixuxu! I think I have resolved them all. LMK.
}


class DiffusersAutoQuantizationConfig:
I see this is similar to transformers, but I think the `DiffusersAutoQuantConfig` class is probably not needed.

This is just a simple mapping to a specific quantization config object. The `from_pretrained` method in the AutoQuantizer is just wrapping the AutoConfig `from_pretrained`.

I think we can just move these methods/logic directly into the AutoQuantizer.
I think we can just move these methods/logic directly into the AutoQuantizer.
If this is not a must-have, could do this in a follow-up PR.
Hi folks!
Thanks for working on this. I was able to run the following script on this branch and generate images on my laptop with 8 GB of VRAM:
from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
import torch
import gc
def flush():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
def bytes_to_giga_bytes(bytes):
return bytes / 1024 / 1024 / 1024
flush()
ckpt_id = "black-forest-labs/FLUX.1-dev"
ckpt_4bit_id = "sayakpaul/flux.1-dev-nf4-pkg"
prompt = "a billboard on highway with 'FLUX under 8' written on it"
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
ckpt_4bit_id,
subfolder="text_encoder_2",
)
pipeline = FluxPipeline.from_pretrained(
ckpt_id,
text_encoder_2=text_encoder_2_4bit,
transformer=None,
vae=None,
torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()
with torch.no_grad():
print("Encoding prompts.")
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
prompt=prompt, prompt_2=None, max_sequence_length=256
)
pipeline = pipeline.to("cpu")
del pipeline
flush()
transformer_4bit = FluxTransformer2DModel.from_pretrained(ckpt_4bit_id, subfolder="transformer")
pipeline = FluxPipeline.from_pretrained(
ckpt_id,
text_encoder=None,
text_encoder_2=None,
tokenizer=None,
tokenizer_2=None,
transformer=transformer_4bit,
torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()
print("Running denoising.")
height, width = 512, 768
images = pipeline(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
num_inference_steps=50,
guidance_scale=5.5,
height=height,
width=width,
output_type="pil",
).images
images[0].save("output.png")
let's merge this!
I asked @DN6 to open a follow-up PR for this #9213 (comment),
PR merge contingent on #9720.
@dataclass
class BitsAndBytesConfig(QuantizationConfigMixin):
Something to consider. Let's assume you want to use a quantized transformer model in your code. With this naming, you would always need to set up imports in the following way.
from transformers import BitsAndBytesConfig
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
Not a huge issue. Just giving a heads up in case you want to consider renaming the config to something like `DiffusersBitsAndBytesConfig`.
set_module_kwargs["dtype"] = dtype

# bnb params are flattened.
if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
In this situation, aren't we skipping parameter shape checks for bnb-loaded weights entirely? What happens when one attempts to load bnb weights but the flattened shape is incorrect?

Perhaps we add a `check_quantized_param_shape` method to the `DiffusersQuantizer` base class. And in the BnBQuantizer we can check if the shape matches the rule here:
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L816
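A rough sketch of what such a method could check for 4-bit bnb weights (the method name follows the suggestion; the packing rule assumed here is two 4-bit values per byte stored as a column vector):

import math
import torch

def check_quantized_param_shape_sketch(expected_shape, loaded_shape) -> bool:
    n_elements = math.prod(expected_shape)
    return tuple(loaded_shape) == (math.ceil(n_elements / 2), 1)

print(check_quantized_param_shape_sketch((64, 64), torch.Size([2048, 1])))  # True
print(check_quantized_param_shape_sketch((64, 64), torch.Size([64, 64])))  # False: not packed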
      f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
  )

- if accepts_dtype:
-     set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+ if not is_quantized or (
+     not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)
+ ):
+     if accepts_dtype:
+         set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
+     else:
+         set_module_tensor_to_device(model, param_name, device, value=param)
  else:
-     set_module_tensor_to_device(model, param_name, device, value=param)
+     hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
Small nit. IMO this is a bit more readable
if is_quantized and hf_quantizer.check_quantized_param(
    model, param, param_name, state_dict, param_device=device
):
    hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
else:
    if accepts_dtype:
        set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
    else:
        set_module_tensor_to_device(model, param_name, device, value=param)
134 | """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization""" | ||
135 | return max_memory | ||
136 | |||
137 | def check_quantized_param( |
IMO `check_is_quantized_param` or `check_if_quantized_param` more explicitly conveys what this method does.
class BnB4BitBasicTests(Base4bitTests):
    def setUp(self):
Would clear cache on setup as well.
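A tiny sketch of that suggestion (the class name is illustrative):

import gc
import unittest

import torch

class BnB4BitBasicTestsSketch(unittest.TestCase):
    def setUp(self):
        # free any lingering GPU memory before each test, mirroring tearDown
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()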
It would be useful to rename `llm_int8_skip_modules` or otherwise make it clearer that it is respected in both 4-bit and 8-bit mode, as currently the docs sound like skipped modules are only respected in 8-bit mode while the actual implementation suggests otherwise.
Yeah I think the documentation should reflect this. I guess this is safe to do @SunMarc?
Yeah, we should do that. Would you like to update this @Ednaordinary? We should also do it in transformers when it gets merged.
Sure, @SunMarc. I'll make a PR when I'm able. Should I refactor the parameter name and include a deprecation notice, or just include a note in the docs?
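Until the docs are updated, a hedged example of the behaviour being discussed: `llm_int8_skip_modules` is also honoured when loading in 4-bit, despite the "int8" in its name (the module name below is illustrative):

import torch
from diffusers import BitsAndBytesConfig

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_skip_modules=["proj_out"],  # kept un-quantized even in 4-bit mode
)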
import torch
from diffusers import FluxFillPipeline,FluxTransformer2DModel
from diffusers.utils import load_image
from transformers import T5EncoderModel
import gc
image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png")
nf4_model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
prompt = "a cute dog in paris photoshoot"
def flush():
"""Wipes off memory."""
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
def bytes_to_giga_bytes(bytes):
return f"{(bytes / 1024 / 1024 / 1024):.3f}"
flush()
text_encoder_2 = T5EncoderModel.from_pretrained(
nf4_model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
transformer = FluxTransformer2DModel.from_pretrained(
nf4_model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev",
text_encoder_2=text_encoder_2,
transformer=transformer,
torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = pipe(
prompt="a white paper cup",
image=image,
mask_image=mask,
height=1632,
width=1232,
guidance_scale=30,
num_inference_steps=50,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0)
).images[0]
torch.cuda.empty_cache()
memory = bytes_to_giga_bytes(torch.cuda.memory_allocated())
print(f"{memory=} GB.")
image.save(f"flux-fill-dev.png")
But I get this error:
462 output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
464 # 3. Save state
465 ctx.state = quant_state
RuntimeError: mat1 and mat2 shapes cannot be multiplied (7854x384 and 64x3072)
You are using the wrong checkpoint for fill. It should be https://huggingface.co/diffusers/FLUX.1-Fill-dev-nf4.
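For reference, a hedged sketch of the swap (assuming that repo follows the same layout as the NF4 packages above, i.e. a `transformer` subfolder):

import torch
from diffusers import FluxFillPipeline, FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "diffusers/FLUX.1-Fill-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()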
What does this PR do?

Come back later.

- … (`bitsandbytes`)
- … (`bitsandbytes`)
- … (`bitsandbytes`)
- `from_pretrained()` at the `ModelMixin` level and related changes
- `save_pretrained()`

Notes

- I proposed a `QuantizationLoaderMixin` in #9174, but I realized that is not an approach we can take, because loading and saving a quantized model is very much baked into the arguments of `ModelMixin.save_pretrained()` and `ModelMixin.from_pretrained()`. It is deeply entangled.
- `device_map` support is left for later, because for a pipeline, multiple device_maps can get ugly. This will be dealt with in a follow-up PR by @SunMarc and myself.

No-frills code snippets
Serialization
Serialized checkpoint: https://huggingface.co/sayakpaul/flux.1-dev-nf4-with-bnb-integration.
NF4 checkpoints of Flux transformer and T5: https://huggingface.co/sayakpaul/flux.1-dev-nf4-pkg (has Colab Notebooks, too).
Inference