Combine multiple (IA)^3 Adapters and delete (IA)^3 adapters #980
Closed

alexrs wants to merge 8 commits into huggingface:main from alexrs:multi-ia3
alexrs 1 year ago (edited)

Problem

$(IA)^3$ models support multiple adapters; however, a few features are still missing compared to LoRA:

  • Delete adapters: There is no method to delete an adapter from an IA3Model.
  • Add Weighted Adapter: There is no method to add a weighted combination of multiple adapters.

Solution

  • Add a delete_adapter method based on the one in LoraModel.
  • Add an add_weighted_adapter method based on the one in LoraModel. This method is, however, a simplified version of the LoRA one: as $(IA)^3$ injects trainable vectors, I have removed both the cat and svd options (see the sketch below).
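
As a rough illustration of what the simplified combination does (a minimal sketch with a hypothetical helper name, not the actual method added in this PR), the learned vectors of the adapters are combined element-wise per target module:

import torch

# Hypothetical sketch (not the actual PEFT code): combine the learned (IA)^3
# vectors of several adapters for one target module into a single new vector
# as a weighted element-wise sum.
def combine_ia3_vectors(vectors: list[torch.Tensor], weights: list[float]) -> torch.Tensor:
    combined = torch.zeros_like(vectors[0])
    for vector, weight in zip(vectors, weights):
        combined += weight * vector
    return combined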

Other minor modifications

  • I have removed self.scaling from IA3Layer as it was unused. I believe it was a leftover from the LoRA implementation this layer seems to be based on.

Discussion

I have tested this locally using a simple script (see below) but I have not run any automated tests yet. What is the best way to test these changes?

Test script
import os

import datasets
import torch
from peft import PeftModel, PeftConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

def run_model(
    output_dir: str,
    cache_dir: str,
    dataset_dir: str,
    peft_strategy: str,
):
    # Load dataset from the hub
    datasets.config.DOWNLOADED_DATASETS_PATH = dataset_dir

    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-2-7b-hf", cache_dir=cache_dir
    )

    quantization_config = BitsAndBytesConfig(load_in_8bit=True)

    config = PeftConfig.from_pretrained(os.path.join(output_dir, peft_strategy, "general"))

    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.float,
        quantization_config=quantization_config,
        # device_map="auto",
        cache_dir=cache_dir,
    )

    model = PeftModel.from_pretrained(model, os.path.join(output_dir, peft_strategy, "general"))
    model.load_adapter(os.path.join(output_dir, peft_strategy, "code"), "code")
    model.load_adapter(os.path.join(output_dir, peft_strategy, "creative"), "creative")

    model.add_weighted_adapter(["code", "creative"], [0.5, 0.5], "code_creative")
    model.set_adapter("code_creative")

    prompt = f"""
    ### Input:
    What is the capital of Spain?

    ### Response:
    """

    input_ids = tokenizer(
        prompt, return_tensors="pt", truncation=True
    ).input_ids.cuda()

    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=500,
        do_sample=True,
        top_p=0.9,
        temperature=0.9,
    )

    print(
        f"Response: \n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]}"
    )


if __name__ == "__main__":
    run_model(
        output_dir="/path/to/output",
        cache_dir="/path/to/cache",
        dataset_dir="/path/to/dataset",
        peft_strategy="ia3",
    )
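
Not shown in the script: the new delete_adapter path. Removing one of the loaded adapters would look roughly like this (hypothetical snippet, assuming PeftModel forwards the call to the underlying IA3Model):

# Hypothetical snippet: drop an adapter that is no longer needed
# (would go inside run_model, after the adapters are loaded).
model.delete_adapter("creative")
print(model.peft_config.keys())  # "creative" should no longer be listed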
alexrs changed the title from "Combine multiple (IA)^3 Adapters add delete (IA)^3 adapters" to "Combine multiple (IA)^3 Adapters and delete (IA)^3 adapters" 1 year ago
alexrs force pushed from 583dcb77 to 096f412f 1 year ago
BenjaminBossan requested changes on 2023-10-10
BenjaminBossan 1 year ago

Thanks a lot for adding this feature. Sorry that it took longer to review. I have encountered a couple of issues, could you please take a look? Thanks.

src/peft/tuners/ia3/model.py
        Args:
            adapter_name (str): Name of the adapter to be deleted.
        """
        if adapter_name not in list(self.peft_config.keys()):

BenjaminBossan 1 year ago

Suggested change
-        if adapter_name not in list(self.peft_config.keys()):
+        if adapter_name not in self.peft_config:

I know this is just a 1:1 copy from LoRA, but let's do it better here and avoid the costly list.
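
For context on the suggestion: a dict supports O(1) key membership directly, so materializing the key list first only adds work. A quick standalone sketch:

peft_config = {"default": object(), "code": object()}  # stand-in for self.peft_config

# Builds a temporary list of all keys just to check membership:
missing = "creative" not in list(peft_config.keys())

# Same result, but uses the dict's O(1) key lookup directly:
missing = "creative" not in peft_config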

src/peft/tuners/ia3/layer.py
BenjaminBossan 1 year ago

Good catch.

src/peft/tuners/ia3/model.py
    if adapter not in list(self.peft_config.keys()):
        raise ValueError(f"Adapter {adapter} does not exist")

target_modules_type = type(self.peft_config[adapters[0]].target_modules)
new_target_modules = set() if target_modules_type == list else ""
feedforward_modules_type = type(self.peft_config[adapters[0]].feedforward_modules)
new_feedforward_modules = set() if feedforward_modules_type == list else ""
for adapter in adapters:
    if type(self.peft_config[adapter].target_modules) != target_modules_type:
        raise ValueError(
            "all adapter configs should follow the same target modules type. "
            "Combining adapters with `target_modules` type being a mix of list and string is not supported."
        )
    if target_modules_type == list:
        new_target_modules |= set(self.peft_config[adapter].target_modules)
    else:
        new_target_modules += f"({self.peft_config[adapter].target_modules})|"

    if type(self.peft_config[adapter].feedforward_modules) != feedforward_modules_type:
        raise ValueError(
            "all adapter configs should follow the same feedforward modules type. "
            "Combining adapters with `feedforward_modules` type being a mix of list and string is not supported."
        )
    if feedforward_modules_type == list:
        new_feedforward_modules |= set(self.peft_config[adapter].feedforward_modules)
    else:
        new_feedforward_modules += f"({self.peft_config[adapter].feedforward_modules})|"

new_target_modules = list(new_target_modules) if target_modules_type == list else new_target_modules[:-1]
new_feedforward_modules = (
    list(new_feedforward_modules) if target_modules_type == list else new_feedforward_modules[:-1]
)
BenjaminBossan 1 year ago

The whole logic has been refactored a bit for LoRA in #993 since you created this PR:

target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters]
if not target_module_types:
    raise ValueError(f"Found no adapter matching the names in {adapters}")
if len(set(target_module_types)) > 1:
    raise ValueError(
        "all adapter configs should follow the same target modules type. "
        "Combining adapters with `target_modules` type being a mix of list/set and string is not supported."
    )
if target_module_types[0] == str:
    new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
elif target_module_types[0] == set:
    new_target_modules = reduce(
        operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters)
    )
else:
    raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules")

Could you please adopt those changes here for consistency? Note that the type of target_modules and feedforward_modules has been changed from list to set (str is still valid).

alexrs 1 year ago

Let me know if my changes are correct!

src/peft/tuners/ia3/model.py
BenjaminBossan 1 year ago

I think this is not correct: When using IA³, the IA³ weights have to be multiplied, not added, right? I.e. they should be initialized as 1.0 and then each IA³ weight is multiplied on top, not added. See how it's accomplished in the forward method of IA³:

ia3_scaling = 1
for active_adapter in self.active_adapters:
    if active_adapter not in self.ia3_l.keys():
        continue
    dtype = self.ia3_l[active_adapter].dtype
    ia3_scaling *= self.ia3_l[active_adapter].flatten()

If this is correct, we encounter a second problem, namely that the weights argument makes little sense: Since we just multiply each IA³ weight and each weight from weights, due to commutativity, the order in weights doesn't matter. Whether a user passes weights=[2, 3] or weights=[3, 2] makes no difference.

We could still leave it as is for consistency, but I would be afraid that this would confuse many users. Instead, we could also 1) remove the weights argument entirely for IA³ or 2) only pass a single scalar to weights, which is applied once to all weights (could be set as the initial value). WDYT?
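
A tiny numeric sketch of the commutativity point, assuming the merged vector is formed by multiplying each weighted (IA)³ vector (toy tensors only):

import torch

l1 = torch.tensor([0.9, 1.1])  # toy (IA)^3 vectors for two adapters
l2 = torch.tensor([1.2, 0.8])

# With a purely multiplicative combination, swapping the weights changes
# nothing, since all factors commute:
merged_a = (2 * l1) * (3 * l2)
merged_b = (3 * l1) * (2 * l2)
assert torch.allclose(merged_a, merged_b)  # both equal 6 * l1 * l2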

alexrs 1 year ago (edited)

Thanks for the feedback and review!

When using IA³, the IA³ weights have to be multiplied, not added, right?

This is true in the forward pass. The learned vectors $l$ are multiplied with (in the case of attention) $K^T$ and $V$. However, here we are not considering the Key and Value matrices themselves, only the learned vectors $l$ (as far as I understand), so my approach here was to compute a linear combination of the vectors (which is what we do in LoRA?).

Let's assume we have two adapters that target $K$ and $V$ with associated vectors $l_K$ and $l_V$, and weights [0.6, 0.4]. The way I wanted to combine these adapters into a new adapter was:

$l_K^{\text{new}} = l_K^1 * w_1 + l_K^2 * w_2$
$l_V^{\text{new}} = l_V^1 * w_1 + l_V^2 * w_2$

If we also target the FF layers, we would compute the resulting vector using the same procedure.

the weights argument makes little sense

If we multiply vectors, yes. However, that would not result in a linear combination of vectors, which was my goal.

Let me know if this makes sense!
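
In code, the element-wise linear combination described above would be something like the following (toy tensors, not the actual ia3_l parameters):

import torch

w1, w2 = 0.6, 0.4
l_K_1 = torch.tensor([1.0, 2.0])   # learned vector for K from adapter 1
l_K_2 = torch.tensor([0.5, 1.5])   # learned vector for K from adapter 2

# Weighted element-wise sum, analogous to LoRA's "linear" combination:
l_K_new = w1 * l_K_1 + w2 * l_K_2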

BenjaminBossan 1 year ago

Hmm, not sure. Let's work with scalars for a second. Let's say we have one IA³ weight with value 2 and one with value 3. As they are multiplied consecutively on the input, I would expect that we should multiply by 6, not by their sum 5. Am I missing something?

Anyway, I thought why not just test if the results are right or not. For this, I changed the test you added to do this instead:

        elif isinstance(config, (IA3Config)):
            model = get_peft_model(model, config, adapter_list[0])
            model = model.to(self.torch_device)
            dummy_input = self.prepare_inputs_for_testing()
            output0 = model(**dummy_input)[0]

            model.add_adapter(adapter_list[1], config)
            model.add_adapter(adapter_list[2], config)

            model.set_adapter(adapter_list)
            output1 = model(**dummy_input)[0]

            model.merge_adapter()
            output2 = model(**dummy_input)[0]

            model.unmerge_adapter()
            output3 = model(**dummy_input)[0]

            # using addition
            model.add_weighted_adapter(adapter_list, torch.ones(3) / 3, "merged-add")
            model.set_adapter("merged-add")
            output4 = model(**dummy_input)[0]

            # using multiplication
            model.add_weighted_adapter_mul(adapter_list, torch.ones(3), "merged-mul")
            model.set_adapter("merged-mul")
            output5 = model(**dummy_input)[0]

            assert not torch.allclose(output0, output1)
            torch.testing.assert_allclose(output1, output2)
            torch.testing.assert_allclose(output1, output3)
            torch.testing.assert_allclose(output1, output5)  # passes
            torch.testing.assert_allclose(output1, output4)  # fails

As you can see, we test the outputs from an IA³ model with the 3 adapters active but unmerged vs merged vs merged using add_weighted_adapter (your implementation) vs merged using add_weighted_adapter_mul (my implementation using multiply). When I run the tests, the multiply version passes but the addition version fails, which makes me think that multiplying is the way to go.

If you want to replicate this result, it will require a few steps because our code isn't really set up to work with multiple active adapters yet, so I had to make a few ad hoc changes to even get this far. I created a PR on top of your branch containing those changes:

https://github.com/alexrs/peft/pull/1/files

Obviously, it should not be merged, it's just to show you what steps I took. WDYT, is this plausible?

alexrs 1 year ago (edited)

I see your point! However, I'm not sure this is consistent with the LoRA implementation. As far as I understand, there are two different scenarios here:
1. Stacking Adapters: When using set_adapter on multiple adapters, what we are doing is stacking adapters. That's how it works right now, and how it works in LoRA (I think!). This is equivalent to using combination_type="cat" in LoRA's add_weighted_adapter:

        if tuner_method == "lora":
            # create a weighted adapter combining both adapters and check that
            # its output is same as setting multiple active adapters
            peft_model.add_weighted_adapter(
                ["adapter_1", "adapter_2"], [1.0, 1.0], "new_combined_adapter", combination_type="cat"
            )
            peft_model.set_adapter("new_combined_adapter")
            new_combined_output = peft_model(**X)
            self.assertTrue(torch.allclose(new_combined_output, combined_output, atol=1e-5))
2. Linear combination of Adapters: In this case, we are not stacking adapters but combining them to create a new adapter that is a linear combination of the input adapters and the input weights. This is equivalent to combination_type=linear in LoRA's add_weighted_adapter. If we change the code linked above to use linear, the test fails:

        if tuner_method == "lora":
            # create a weighted adapter combining both adapters and check that
            # its output is same as setting multiple active adapters
            peft_model.add_weighted_adapter(
                ["adapter_1", "adapter_2"], [1.0, 1.0], "new_combined_adapter", combination_type="linear"
            )
            peft_model.set_adapter("new_combined_adapter")
            new_combined_output = peft_model(**X)
            self.assertTrue(torch.allclose(new_combined_output, combined_output, atol=1e-5))

And the same happens if we give both adapters equal weights that sum to 1:

            peft_model.add_weighted_adapter(
                ["adapter_1", "adapter_2"], [0.5, 0.5], "new_combined_adapter", combination_type="linear"
            )

I guess a solution is to add the different combination_types to $(IA)^3$'s add_weighted_adapter. Does this sound reasonable? Or do I have the wrong understanding of how this works?
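
To make the trade-off concrete, a scalar sketch (using the 2-and-3 toy values from above) of why a multiplicative merge reproduces the stacked behaviour while a weighted sum does not:

x = 10.0  # some activation the (IA)^3 scalings are applied to

stacked = x * 2.0 * 3.0                    # two active adapters applied in sequence -> 60.0
merged_mul = x * (2.0 * 3.0)               # multiplicative merge -> 60.0
merged_lin = x * (0.5 * 2.0 + 0.5 * 3.0)   # "linear" merge -> 25.0

assert stacked == merged_mul
assert stacked != merged_lin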

BenjaminBossan 1 year ago

Yes, you're right in the sense that for IA³, it is not quite clear how to interpret the combination of results. Unfortunately, I don't think that there is any existing evidence for IA³ for what the best way for combining adapters is. I agree that we could offer multiple methods and that hopefully, with time, the best method will emerge. When it comes to which default to choose, I'd argue it's a nice property to have the same output for combining the adapters as if they were all active at once, WDYT?

Another possibility that comes to mind would be the geometric mean, which seems appropriate for a multiplicative operation, but it wouldn't work for negative numbers, so it has to be ruled out.

When it comes to naming the combination types, the analogy to LoRA is a bit difficult, because the mathematical operation is different. I think for IA³ it is necessary to think from first principles.
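
For reference, the weighted geometric mean alluded to here would be roughly the following, and the failure mode is taking a fractional power of a negative entry (illustrative sketch only):

import torch

vectors = torch.stack([torch.tensor([0.9, 1.1]), torch.tensor([1.2, 0.8])])
weights = torch.tensor([[0.5], [0.5]])

# Weighted geometric mean, element-wise: prod_i l_i ** w_i
geo = torch.prod(vectors ** weights, dim=0)

# With a negative entry (e.g. -0.9 instead of 0.9) the fractional power
# yields NaN, which is why this option has to be ruled out.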

alexrs 1 year ago (edited)

Unfortunately, I don't think that there is any existing evidence for IA³ for what the best way for combining adapters is

Agreed.

I'd argue it's a nice property to have the same output for combining the adapters as if they were all active at once, WDYT?

That makes sense! But as discussed above, it is not how it works in LoRA by default, is it?

I guess the way to proceed is to allow both the multiplication and the linear combination methods via different combination_types, and to set the default to multiplication?

All in all, given that there is no evidence for what the best way for combining adapters is, I will try to run some experiments using both methods to get more clarity on this topic. Let me know if you have any suggestions or ideas for this!
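
A sketch of what the dispatch on a combination_type argument could look like for (IA)³ vectors (hypothetical helper, names and default chosen for illustration only):

import torch

def combine_ia3(vectors: list[torch.Tensor], weights: list[float], combination_type: str = "mul") -> torch.Tensor:
    if combination_type == "linear":
        # weighted element-wise sum of the adapters' vectors
        return sum(w * v for w, v in zip(weights, vectors))
    if combination_type == "mul":
        # multiplicative merge, matching the output of multiple active adapters
        merged = torch.ones_like(vectors[0])
        for w, v in zip(weights, vectors):
            merged = merged * (w * v)
        return merged
    raise ValueError(f"Unknown combination_type: {combination_type}")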

BenjaminBossan 1 year ago

That makes sense! But as discussed above, it is not how it works in LoRA by default, is it?

Yes, but we cannot really compare the two as I mentioned. E.g. it would not make sense to have an "svd" method for IA³, so I think we shouldn't put too much stress on consistency here.

I will try to run some experiments using both methods to get more clarity on this topic. Let me know if you have any suggestions or ideas for this!

That would be fantastic. Loading and combining multiple LoRAs is mostly a thing in image generation AFAIK, so that's probably what I would investigate, but I'm not sure how well IA³ lends itself to image generation in general.

alexrs Add add_weighted_adapter to ia3
149aeb35
alexrs Remove unused scaling from IA3
31237bda
alexrs Improve delete_adapter
eeb1b381
alexrs fix style
0e55abc0
alexrs revert test
e98e1480
alexrs add tests
46b3ad89
alexrs rebase on main
7f567a0f
alexrs Feedback from PR
43483b7a
alexrs force pushed from 096f412f to 43483b7a 1 year ago
alexrs commented on 2023-10-10
BenjaminBossan commented on 2023-10-12
BenjaminBossan 1 year ago

Sorry for the late reply. I had written a review but somehow forgot to send it.

Could you please also fix the merge conflict?

pacman100 commented on 2023-11-09

pacman100 1 year ago (edited)

Thank you @alexrs for adding new utils for IA3 🤗! The discussion between you and @BenjaminBossan was quite interesting and insightful. I do agree that combining IA3 adapters needs to be thought through from first principles, as IA3 uses a multiplicative operation. At the same time, a linear combination of IA3 adapters is a nice feature. I believe we can support it, provided there is clear documentation explaining how it differs from having multiple active adapters and that it is a weighted-average combination of IA3 adapters. Thank you as well for adding support for deleting IA3 adapters, which is useful when working with multiple adapters.

github-actions closed this 1 year ago
