peft
adds multiple adapters to a peft model
#133
Closed

edbeeching wants to merge 6 commits into huggingface:main from edbeeching:multiple-adapters
edbeeching
edbeeching commented 2 years ago 🎉 3 ❤️ 5

Draft PR to add multiple adapters to a base model.
Implemented for 8bit layers at the moment, I will attempt the implementation for merged layers tomorrow.

See examples/multiple_adapters.py for a very brief demo.

Early feedback would be great! cc @pacman100

pacman100
pacman100 commented on 2023-02-25 (edited)

Hello, thank you @edbeeching for initiating the work on supporting multiple LoRA modules for a given base model, super helpful for a lot of use cases 🔥.

Here are my thoughts:

  1. The current implementation is great for getting started quickly with multiple LoRA modules. However, it isn't generic in the following sense:
    a. Users won't be able to apply LoRA layers to a different number of layers of the base model, for example, if a user wants one LoRA module to target only the last 6 layers and another LoRA module to target only the last 3 layers. Here, the same target layers are applied to all LoRA modules.
    b. Users won't be able to have finer control over things like LoRA rank r and dropout across different LoRA modules.

Point (a) is difficult to implement in a clean way, and that is fine, as LoRA adapters are usually applied to all layers; this is different from last_k-layer tuning.

Point (b) can be implemented by taking a list of ranks r, alphas and dropouts each corresponding to one set of LoRA modules.
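
A minimal, self-contained sketch of the config shape point (b) describes, with one hyperparameter entry per adapter. The class and field names (n_adapters, tuple-valued r / lora_alpha / lora_dropout) are assumptions taken from this thread, not the existing LoraConfig API:

from dataclasses import dataclass
from typing import Tuple

# Illustrative only: a config carrying one r / alpha / dropout value per adapter.
# Field names are assumptions from this discussion, not the peft LoraConfig API.
@dataclass
class MultiLoraConfigSketch:
    n_adapters: int = 2
    r: Tuple[int, ...] = (8, 16)                   # rank per adapter
    lora_alpha: Tuple[int, ...] = (16, 32)         # scaling per adapter
    lora_dropout: Tuple[float, ...] = (0.1, 0.05)  # dropout per adapter
    target_modules: Tuple[str, ...] = ("q_proj", "v_proj")

config = MultiLoraConfigSketch()
assert len(config.r) == config.n_adapters  # one rank per set of LoRA modules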

Conversation marked as resolved

src/peft/peft_model.py
306         self.base_model.enable_adapter_layers_index(index)
307
308     def disable_adapter_index(self, index=0):
309         self.base_model.enable_adapter_layers_index(index)

pacman100 commented 2 years ago
Suggested change:
-        self.base_model.enable_adapter_layers_index(index)
+        self.base_model.disable_adapter_layers_index(index)

edbeeching
edbeeching commented 2 years ago

I have updated the PR to use Tuples for r, dropout and alpha.

I was thinking that a better API for the user could be something like the following:

lora_config = LoraConfig(task_type="CAUSAL_LM")

model = prepare_model_for_int8_training(model)
model = get_peft_model(model, lora_config)

adapter_config_a = AdapterConfig(r=8, target_modules=["q_proj", "v_proj"], ...)
model.add_adapter("finetune_task_a", adapter_config_a)
model.enable_adapter("finetune_task_a")
# train on task a

adapter_config_b = AdapterConfig(r=16, target_modules="a regex", ...)
model.add_adapter("finetune_task_b", adapter_config_b)
model.enable_adapter("finetune_task_b")
# train on task b

This should resolve point (a). The adapter layers could contain ModuleDicts rather than ModuleLists, to make it clearer to the user which adapter they are enabling; see the sketch below. All the _find_and_replace() logic would take place during the add_adapter() call.
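
A rough, self-contained sketch of the ModuleDict idea for a single LoRA Linear layer; the class and method names here are illustrative only, not the proposed peft implementation:

import torch.nn as nn

class MultiAdapterLinear(nn.Module):
    # Illustrative only: LoRA weights live in ModuleDicts keyed by adapter name,
    # so it is explicit which adapter is currently enabled.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.lora_A = nn.ModuleDict()  # adapter name -> down-projection
        self.lora_B = nn.ModuleDict()  # adapter name -> up-projection
        self.scaling = {}
        self.active_adapter = None

    def add_adapter(self, name, r, alpha):
        self.lora_A[name] = nn.Linear(self.base.in_features, r, bias=False)
        self.lora_B[name] = nn.Linear(r, self.base.out_features, bias=False)
        nn.init.zeros_(self.lora_B[name].weight)  # new adapter starts as a no-op delta
        self.scaling[name] = alpha / r

    def set_adapter(self, name):
        self.active_adapter = name

    def forward(self, x):
        result = self.base(x)
        name = self.active_adapter
        if name is not None and name in self.lora_A:
            result = result + self.lora_B[name](self.lora_A[name](x)) * self.scaling[name]
        return result

A top-level add_adapter() on the PeftModel could then run the _find_and_replace() logic and call the per-layer add_adapter() for each targeted module.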

This is a larger change. Beyond the scope of this PR, what do you think?

As for this PR, the next step would be to update the MergedLinear layers. Let me know what you think of the PR so far and then I can work on that.

pacman100
pacman100 commented 2 years ago

I have updated the PR to use Tuples for r, dropout and alpha.

Hello, thank you for this change. However, this isn't backward compatible and would be a breaking change. What I envisioned was that r, dropout and alpha would be Union[Tuple[dtype], dtype] where dtype is int/float. In the config's post_init, when a tuple is passed its length should equal the n_adapters value; otherwise a tuple of that length is created, e.g. self.r = tuple([self.r] * self.n_adapters).
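
A minimal sketch of that post_init normalization as a self-contained dataclass; the field names are assumptions based on this thread, not the released LoraConfig:

from dataclasses import dataclass
from typing import Tuple, Union

# Illustrative sketch of the proposed backward-compatible broadcast: a scalar
# stays valid and is expanded to one value per adapter in __post_init__.
@dataclass
class ConfigSketch:
    n_adapters: int = 2
    r: Union[Tuple[int, ...], int] = 8
    lora_alpha: Union[Tuple[int, ...], int] = 16
    lora_dropout: Union[Tuple[float, ...], float] = 0.1

    def __post_init__(self):
        for name in ("r", "lora_alpha", "lora_dropout"):
            value = getattr(self, name)
            if isinstance(value, tuple):
                assert len(value) == self.n_adapters, f"{name} needs n_adapters entries"
            else:
                setattr(self, name, tuple([value] * self.n_adapters))

assert ConfigSketch().r == (8, 8)            # scalar broadcast, backward compatible
assert ConfigSketch(r=(8, 16)).r == (8, 16)  # explicit per-adapter ranks kept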

Please let me know your thoughts.

Also, the change you are suggesting to enable (a) is a bigger change and indeed a better way; I'll look into it.

edbeeching
edbeeching commented 2 years ago 👍 1

Good point, I have updated the PR.

edbeeching
edbeeching commented 2 years ago

Regarding the larger change I was suggesting, I don't think it is a lot of work, but it would probably be a breaking change.

edbeeching adds multiple adapters for 8bit lora layers
1922fc37
edbeeching adds tuples for r, dropout and alpha
779befb0
edbeeching makes additional adapters backwards compatible
25fc7222
edbeeching Adds multiple adapter support to merged layers, linear layers
60e79af6
edbeeching force-pushed from bc4a4988 to 60e79af6 2 years ago
edbeeching adds tests
3fc947d8
edbeeching fixing tests
03568ca7
edbeeching
edbeeching commented 2 years ago

@pacman100, I have updated the PR to include multiple adapter support for the other merged layers etc.
I started writing some tests and noticed that all the original tests fail on my machine; do they pass on yours? Is there a CI workflow set up for this repo yet? If not, I can create a PR for it.

pacman100
pacman100 commented 2 years ago

Hello @edbeeching, thank you for the recent changes. I'll check the current tests tomorrow and fix the failing ones.

It would be great if you could raise a PR for the CI integration 🤗

edbeeching marked this pull request as ready for review 2 years ago
totaltube
totaltube commented 2 years ago 👍 3

It would be great to be able to train adapters independently as usual and then load them together, with the ability to quickly switch between them.
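
(For context, a hypothetical workflow in the spirit of the API proposed earlier in this thread; add_adapter and enable_adapter are the names suggested above, not a released interface, and the adapter names and configs are placeholders.)

# Hypothetical usage, following the add_adapter / enable_adapter naming proposed
# earlier in this thread (not a released API); names and configs are placeholders.
model.add_adapter("task_a", adapter_config_a)
model.enable_adapter("task_a")
# ... train on task a ...

model.add_adapter("task_b", adapter_config_b)
model.enable_adapter("task_b")
# ... train on task b ...

# Later: both adapters stay attached to the same base model, so switching
# between them does not require reloading the base weights.
model.enable_adapter("task_a")
model.enable_adapter("task_b")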

edbeeching
edbeeching commented 2 years ago

Closing as superseded by #263.

edbeeching closed this 2 years ago
