transformers
fa3c3f9c - Break weight tying when quantizing input embedding (#37905)

Summary:
Currently, when the input embedding (`model.embed_tokens`) of some models is quantized, the output embedding (`lm_head`) gets quantized the same way because the two weights are tied, and this may not be what we want. To break the tie, this adds an option so that loading proceeds as:

1. load the unquantized weights
2. tie the weights
3. quantize

so the tie is broken by the quantization step.

Test Plan:
```python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    TorchAoConfig,
)
from torchao.quantization.quant_api import (
    IntxWeightOnlyConfig,
    Int8DynamicActivationIntxWeightConfig,
    AOPerModuleConfig,
)
from torchao.quantization.granularity import PerGroup, PerAxis

model_id = "microsoft/Phi-4-mini-instruct"

# int8 weight-only quantization for the input embedding
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
)
# int4 weight / int8 dynamic activation quantization for the linear layers
linear_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
    weight_scale_dtype=torch.bfloat16,
)
quant_config = AOPerModuleConfig(
    {"_default": linear_config, "model.embed_tokens": embedding_config}
)
# untie_embedding_weights=True breaks the embed_tokens/lm_head tie before quantizing
quantization_config = TorchAoConfig(
    quant_type=quant_config,
    include_embedding=True,
    untie_embedding_weights=True,
)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    device_map="auto",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(quantized_model)
print("embed_tokens.weight:", quantized_model.model.embed_tokens.weight)
print("lm head weight:", quantized_model.lm_head.weight)

# Confirm that embed_tokens and lm_head are no longer reported as tied
from transformers.modeling_utils import find_tied_parameters
print(find_tied_parameters(quantized_model))
```

Reviewers:

Subscribers:

Tasks:

Tags:

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
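For background on why the load → tie → quantize order breaks the tie: tied weights mean `lm_head.weight` and `model.embed_tokens.weight` refer to the same Parameter object, so quantizing one would transform the other as well. The sketch below is a minimal illustration in plain PyTorch with toy modules (not the transformers loading internals), showing how replacing the embedding's Parameter after tying leaves the output head pointing at the original, unquantized tensor:

```python
import torch
import torch.nn as nn

# Toy stand-ins for embed_tokens and lm_head (hypothetical sizes for illustration).
embed = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)

# Tie the weights: both modules now hold the very same Parameter object.
lm_head.weight = embed.weight
assert lm_head.weight is embed.weight

# "Quantize" the embedding by swapping in a new Parameter (here just a dtype
# cast as a stand-in for a quantized tensor); only the embedding is replaced.
embed.weight = nn.Parameter(embed.weight.detach().to(torch.float16))

# The tie is broken: lm_head still holds the original float32 Parameter.
assert lm_head.weight is not embed.weight
assert lm_head.weight.dtype == torch.float32
```

Per the summary, the real code path follows the same idea: the quantization step replaces `model.embed_tokens.weight` after the weights have been loaded and tied, so `lm_head` keeps its unquantized weights.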