huggingface/transformers
Support saving only PEFT adapter in checkpoints when using PEFT + FSDP #28297
Merged

AjayP13 · 1 year ago (edited)

What does this PR do?

Currently, when PEFT + FSDP is used, both the full model weights (pytorch_model_fsdp.bin) and the PEFT adapter weights (adapter_model.safetensors) are saved in every checkpoint, leading to unnecessary disk usage and slower training from writing large files.

These changes ensure only the PEFT adapters are saved/loaded by:

  • Using the newly added adapter_only parameter on save_fsdp_model and load_fsdp_model. This ensures pytorch_model_fsdp.bin contains only the PEFT adapter weights rather than the full model weights.
  • See the related PR in accelerate: huggingface/accelerate#2321
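The save path change can be sketched as follows. This is illustrative only, not the actual Trainer code: save_fsdp_model is a stub standing in for accelerate's real function, and the _save_fsdp_checkpoint helper is a hypothetical name introduced for this sketch.

```python
saved = {}

def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, adapter_only=False):
    # Stub standing in for accelerate's real save_fsdp_model; it only
    # records which kind of state dict was requested.
    saved["adapter_only"] = adapter_only

def _save_fsdp_checkpoint(model, output_dir, is_peft_model):
    # Request the adapter-only state dict for PEFT models, so
    # pytorch_model_fsdp.bin holds just the adapter weights.
    save_fsdp_model(None, None, model, output_dir, adapter_only=is_peft_model)

_save_fsdp_checkpoint(model=object(), output_dir="ckpt", is_peft_model=True)
print(saved["adapter_only"])  # True
```

For a non-PEFT model the flag stays False and the full state dict is saved, so the non-PEFT code path is unchanged.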

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@pacman100 @ArthurZucker @amyeroberts

younesbelkada approved these changes on 2024-01-08

Makes sense, thanks for the multiple fixes!

src/transformers/trainer.py
 2181:     load_result = model.load_state_dict(state_dict, strict=True)
 2182: else:
-2183:     if _is_peft_model(model):
+2183:     if _is_peft_model(unwrap_model(model)):
younesbelkada · 1 year ago

Weird that we never caught bugs here; in any case, it is safe to put unwrap_model(model) here IMO.

amyeroberts · 1 year ago

Indeed - @AjayP13 can we add a test which passes with these fixes but fails on main?

younesbelkada · 1 year ago

yes that would be great if possible!

AjayP13 · 1 year ago (edited)

Unfortunately, I don't have consistent access to a multi-GPU machine (which this test would require, since FSDP needs at least 2 GPUs); however, I did test this fix while running a job on a multi-GPU machine.

The PEFT + FSDP pathway seems largely untested, though; there is one FSDP test here. It would be great to get tests for this pathway in general, but I probably won't be able to on my side. If someone at @huggingface working on FSDP could, that would be great.

younesbelkada · 1 year ago

@AjayP13 the condition self.is_fsdp_enabled and not _is_peft_model(unwrap_model(model)) ensures that inside the else block at line 2182 we do not have a PEFT model, so the condition if _is_peft_model(unwrap_model(model)): will never be met. I think we can remove that whole block; what do you think?

@amyeroberts indeed, I think PEFT + FSDP is not well tested. I can take care of adding proper tests to the PEFT test suite, since we already test against transformers main there: https://github.com/huggingface/peft/blob/main/tests/test_gpu_examples.py. We test PEFT + Trainer extensively, and I can add a test for FSDP since we have access to multi-GPU environments on the PEFT side. Let me know what you think.

pacman100 commented on 2024-01-09

Hello, saving both the adapters and the full weights is required to enable resuming the finetuning when using FSDP, so that is necessary. If you don't need the flexibility of being able to resume training, please use the save_only_model training argument, which skips saving the full model checkpoint, optimizer, scheduler, and RNG states.

Regarding loading the best model, can you make sure the loaded model performs as expected? Reloading during finetuning requires the checkpoint to be in a format compatible with the FSDP wrapper.
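The save_only_model trade-off described above can be sketched as follows. CheckpointPolicy is a hypothetical stand-in that just models what a checkpoint contains, not the real TrainingArguments class.

```python
from dataclasses import dataclass

@dataclass
class CheckpointPolicy:
    # Mirrors the trade-off of TrainingArguments.save_only_model: skipping
    # optimizer/scheduler/RNG state saves disk but forfeits resumability.
    save_only_model: bool = False

    def files_saved(self):
        files = ["model weights"]
        if not self.save_only_model:
            # Full checkpoints also persist the training state,
            # which is what makes resume_from_checkpoint possible.
            files += ["optimizer state", "scheduler state", "rng states"]
        return files

    @property
    def can_resume(self):
        return not self.save_only_model

print(CheckpointPolicy(save_only_model=True).files_saved())  # ['model weights']
```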

AjayP13 · 1 year ago (edited)

@pacman100 Thanks for the insight on this.

Edit: @pacman100 is right, I just tested this PR again:

  • PEFT + FSDP + Load Best Checkpoint at End: does work
  • PEFT + FSDP + Resume From Checkpoint: does not work
AjayP13 · 1 year ago (edited)

@pacman100 @younesbelkada - The PR is ready for review again, and this time it supports FSDP + PEFT + resuming. It now relies on a PR I made in accelerate: huggingface/accelerate#2321.

@younesbelkada I was able to rent a multi-GPU machine for development and testing. I've attached a test (branch_test.zip) that may help create an actual test case for the CI system. You can see in the test output log that pytorch_model_fsdp.bin is now only 12.15 MB (vs. the size of the full model weights).

Here is the test information:

Summary of test:

  1. Pull in changes from transformers PR and accelerate PR via requirements.txt
  2. Train for 99 epochs
  3. Resume from Checkpoint
  4. Train for 1 Additional Epoch
  5. Load Best Checkpoint
  6. Test Final Trained Adapter Weights
Test Dataset
dataset = Dataset.from_dict({
  "text": [
      "Input: A Output: 14",
      "Input: B Output: 52",
      "Input: C Output: 83",
      "Input: D Output: 57",
      "Input: E Output: 38",
      "Input: F Output: 54",
  ]
})
Output Log of Training / Test
Installing requirements.txt...
========= preparing dataset ============
========= loading models + tokenizer ============
========= preparing for training ============
========= training ============
{'loss': 6.8567, 'learning_rate': 0.000494949494949495, 'epoch': 1.0}
{'eval_loss': 5.881069183349609, 'eval_runtime': 0.06, 'eval_samples_per_second': 99.934, 'eval_steps_per_second': 16.656, 'epoch': 1.0}
{'loss': 5.986, 'learning_rate': 0.0004898989898989899, 'epoch': 2.0}
{'eval_loss': 5.117019176483154, 'eval_runtime': 0.0795, 'eval_samples_per_second': 75.427, 'eval_steps_per_second': 12.571, 'epoch': 2.0}
{'loss': 5.2125, 'learning_rate': 0.0004848484848484849, 'epoch': 3.0}
{'eval_loss': 4.231423854827881, 'eval_runtime': 0.0595, 'eval_samples_per_second': 100.84, 'eval_steps_per_second': 16.807, 'epoch': 3.0}
{'loss': 4.6937, 'learning_rate': 0.0004797979797979798, 'epoch': 4.0}
{'eval_loss': 3.4078128337860107, 'eval_runtime': 0.0651, 'eval_samples_per_second': 92.211, 'eval_steps_per_second': 15.369, 'epoch': 4.0}
{'loss': 4.2886, 'learning_rate': 0.00047474747474747476, 'epoch': 5.0}
{'eval_loss': 2.664665937423706, 'eval_runtime': 0.0593, 'eval_samples_per_second': 101.123, 'eval_steps_per_second': 16.854, 'epoch': 5.0}
{'loss': 2.9141, 'learning_rate': 0.0004696969696969697, 'epoch': 6.0}
{'eval_loss': 2.032987117767334, 'eval_runtime': 0.0612, 'eval_samples_per_second': 97.965, 'eval_steps_per_second': 16.328, 'epoch': 6.0}
{'loss': 2.2988, 'learning_rate': 0.0004646464646464646, 'epoch': 7.0}
{'eval_loss': 1.591854453086853, 'eval_runtime': 0.0731, 'eval_samples_per_second': 82.099, 'eval_steps_per_second': 13.683, 'epoch': 7.0}
{'loss': 1.8455, 'learning_rate': 0.00045959595959595964, 'epoch': 8.0}
{'eval_loss': 1.2650986909866333, 'eval_runtime': 0.0669, 'eval_samples_per_second': 89.719, 'eval_steps_per_second': 14.953, 'epoch': 8.0}
{'loss': 1.7772, 'learning_rate': 0.00045454545454545455, 'epoch': 9.0}
{'eval_loss': 1.072060227394104, 'eval_runtime': 0.0671, 'eval_samples_per_second': 89.479, 'eval_steps_per_second': 14.913, 'epoch': 9.0}
{'loss': 1.4851, 'learning_rate': 0.0004494949494949495, 'epoch': 10.0}
{'eval_loss': 0.9297663569450378, 'eval_runtime': 0.0744, 'eval_samples_per_second': 80.613, 'eval_steps_per_second': 13.436, 'epoch': 10.0}
{'loss': 1.4598, 'learning_rate': 0.0004444444444444444, 'epoch': 11.0}
{'eval_loss': 0.8234553933143616, 'eval_runtime': 0.0691, 'eval_samples_per_second': 86.777, 'eval_steps_per_second': 14.463, 'epoch': 11.0}
{'loss': 1.3115, 'learning_rate': 0.0004393939393939394, 'epoch': 12.0}
{'eval_loss': 0.7762313485145569, 'eval_runtime': 0.0601, 'eval_samples_per_second': 99.775, 'eval_steps_per_second': 16.629, 'epoch': 12.0}
{'loss': 0.8095, 'learning_rate': 0.0004343434343434344, 'epoch': 13.0}
{'eval_loss': 0.7559001445770264, 'eval_runtime': 0.0611, 'eval_samples_per_second': 98.177, 'eval_steps_per_second': 16.363, 'epoch': 13.0}
{'loss': 0.8534, 'learning_rate': 0.0004292929292929293, 'epoch': 14.0}
{'eval_loss': 0.749554455280304, 'eval_runtime': 0.0649, 'eval_samples_per_second': 92.517, 'eval_steps_per_second': 15.42, 'epoch': 14.0}
{'loss': 0.8644, 'learning_rate': 0.00042424242424242425, 'epoch': 15.0}
{'eval_loss': 0.7470491528511047, 'eval_runtime': 0.0728, 'eval_samples_per_second': 82.379, 'eval_steps_per_second': 13.73, 'epoch': 15.0}
{'loss': 1.017, 'learning_rate': 0.00041919191919191916, 'epoch': 16.0}
{'eval_loss': 0.7447641491889954, 'eval_runtime': 0.0675, 'eval_samples_per_second': 88.867, 'eval_steps_per_second': 14.811, 'epoch': 16.0}
{'loss': 1.0205, 'learning_rate': 0.0004141414141414142, 'epoch': 17.0}
{'eval_loss': 0.7410345673561096, 'eval_runtime': 0.0605, 'eval_samples_per_second': 99.164, 'eval_steps_per_second': 16.527, 'epoch': 17.0}
{'loss': 1.0018, 'learning_rate': 0.00040909090909090913, 'epoch': 18.0}
{'eval_loss': 0.7377716898918152, 'eval_runtime': 0.0716, 'eval_samples_per_second': 83.775, 'eval_steps_per_second': 13.962, 'epoch': 18.0}
{'loss': 0.7863, 'learning_rate': 0.00040404040404040404, 'epoch': 19.0}
{'eval_loss': 0.7327172756195068, 'eval_runtime': 0.0611, 'eval_samples_per_second': 98.14, 'eval_steps_per_second': 16.357, 'epoch': 19.0}
{'loss': 0.7273, 'learning_rate': 0.000398989898989899, 'epoch': 20.0}
{'eval_loss': 0.7267041206359863, 'eval_runtime': 0.067, 'eval_samples_per_second': 89.487, 'eval_steps_per_second': 14.915, 'epoch': 20.0}
{'loss': 1.0861, 'learning_rate': 0.0003939393939393939, 'epoch': 21.0}
{'eval_loss': 0.7202563285827637, 'eval_runtime': 0.0623, 'eval_samples_per_second': 96.246, 'eval_steps_per_second': 16.041, 'epoch': 21.0}
{'loss': 1.0336, 'learning_rate': 0.0003888888888888889, 'epoch': 22.0}
{'eval_loss': 0.7143564820289612, 'eval_runtime': 0.0599, 'eval_samples_per_second': 100.172, 'eval_steps_per_second': 16.695, 'epoch': 22.0}
{'loss': 0.7619, 'learning_rate': 0.00038383838383838383, 'epoch': 23.0}
{'eval_loss': 0.7085524201393127, 'eval_runtime': 0.0694, 'eval_samples_per_second': 86.504, 'eval_steps_per_second': 14.417, 'epoch': 23.0}
{'loss': 1.0004, 'learning_rate': 0.0003787878787878788, 'epoch': 24.0}
{'eval_loss': 0.704052746295929, 'eval_runtime': 0.0658, 'eval_samples_per_second': 91.239, 'eval_steps_per_second': 15.206, 'epoch': 24.0}
{'loss': 0.6674, 'learning_rate': 0.00037373737373737375, 'epoch': 25.0}
{'eval_loss': 0.7021067142486572, 'eval_runtime': 0.0721, 'eval_samples_per_second': 83.256, 'eval_steps_per_second': 13.876, 'epoch': 25.0}
{'loss': 1.0418, 'learning_rate': 0.0003686868686868687, 'epoch': 26.0}
{'eval_loss': 0.7003771662712097, 'eval_runtime': 0.0624, 'eval_samples_per_second': 96.231, 'eval_steps_per_second': 16.038, 'epoch': 26.0}
{'loss': 0.7463, 'learning_rate': 0.00036363636363636367, 'epoch': 27.0}
{'eval_loss': 0.6978425979614258, 'eval_runtime': 0.076, 'eval_samples_per_second': 78.91, 'eval_steps_per_second': 13.152, 'epoch': 27.0}
{'loss': 0.9374, 'learning_rate': 0.0003585858585858586, 'epoch': 28.0}
{'eval_loss': 0.6954618096351624, 'eval_runtime': 0.0707, 'eval_samples_per_second': 84.814, 'eval_steps_per_second': 14.136, 'epoch': 28.0}
{'loss': 0.6975, 'learning_rate': 0.00035353535353535354, 'epoch': 29.0}
{'eval_loss': 0.6914337277412415, 'eval_runtime': 0.0755, 'eval_samples_per_second': 79.493, 'eval_steps_per_second': 13.249, 'epoch': 29.0}
{'loss': 0.7536, 'learning_rate': 0.0003484848484848485, 'epoch': 30.0}
{'eval_loss': 0.6854903697967529, 'eval_runtime': 0.0596, 'eval_samples_per_second': 100.615, 'eval_steps_per_second': 16.769, 'epoch': 30.0}
{'loss': 0.758, 'learning_rate': 0.00034343434343434346, 'epoch': 31.0}
{'eval_loss': 0.6792461276054382, 'eval_runtime': 0.0597, 'eval_samples_per_second': 100.564, 'eval_steps_per_second': 16.761, 'epoch': 31.0}
{'loss': 0.8828, 'learning_rate': 0.0003383838383838384, 'epoch': 32.0}
{'eval_loss': 0.6729412078857422, 'eval_runtime': 0.0691, 'eval_samples_per_second': 86.891, 'eval_steps_per_second': 14.482, 'epoch': 32.0}
{'loss': 0.9425, 'learning_rate': 0.0003333333333333333, 'epoch': 33.0}
{'eval_loss': 0.6672318577766418, 'eval_runtime': 0.0715, 'eval_samples_per_second': 83.967, 'eval_steps_per_second': 13.994, 'epoch': 33.0}
{'loss': 0.8739, 'learning_rate': 0.0003282828282828283, 'epoch': 34.0}
{'eval_loss': 0.6612274050712585, 'eval_runtime': 0.0691, 'eval_samples_per_second': 86.887, 'eval_steps_per_second': 14.481, 'epoch': 34.0}
{'loss': 0.6459, 'learning_rate': 0.00032323232323232324, 'epoch': 35.0}
{'eval_loss': 0.6533860564231873, 'eval_runtime': 0.069, 'eval_samples_per_second': 86.928, 'eval_steps_per_second': 14.488, 'epoch': 35.0}
{'loss': 0.71, 'learning_rate': 0.0003181818181818182, 'epoch': 36.0}
{'eval_loss': 0.6435515284538269, 'eval_runtime': 0.0622, 'eval_samples_per_second': 96.408, 'eval_steps_per_second': 16.068, 'epoch': 36.0}
{'loss': 0.6217, 'learning_rate': 0.00031313131313131316, 'epoch': 37.0}
{'eval_loss': 0.6283969283103943, 'eval_runtime': 0.0641, 'eval_samples_per_second': 93.673, 'eval_steps_per_second': 15.612, 'epoch': 37.0}
{'loss': 0.8286, 'learning_rate': 0.00030808080808080807, 'epoch': 38.0}
{'eval_loss': 0.6098210215568542, 'eval_runtime': 0.0664, 'eval_samples_per_second': 90.421, 'eval_steps_per_second': 15.07, 'epoch': 38.0}
{'loss': 0.8709, 'learning_rate': 0.00030303030303030303, 'epoch': 39.0}
{'eval_loss': 0.5880406498908997, 'eval_runtime': 0.0605, 'eval_samples_per_second': 99.113, 'eval_steps_per_second': 16.519, 'epoch': 39.0}
{'loss': 0.6882, 'learning_rate': 0.00029797979797979794, 'epoch': 40.0}
{'eval_loss': 0.5603983402252197, 'eval_runtime': 0.0663, 'eval_samples_per_second': 90.523, 'eval_steps_per_second': 15.087, 'epoch': 40.0}
{'loss': 0.8101, 'learning_rate': 0.00029292929292929295, 'epoch': 41.0}
{'eval_loss': 0.535556972026825, 'eval_runtime': 0.0738, 'eval_samples_per_second': 81.351, 'eval_steps_per_second': 13.558, 'epoch': 41.0}
{'loss': 0.988, 'learning_rate': 0.0002878787878787879, 'epoch': 42.0}
{'eval_loss': 0.5133206248283386, 'eval_runtime': 0.0652, 'eval_samples_per_second': 91.956, 'eval_steps_per_second': 15.326, 'epoch': 42.0}
{'loss': 0.7791, 'learning_rate': 0.0002828282828282828, 'epoch': 43.0}
{'eval_loss': 0.49268361926078796, 'eval_runtime': 0.0677, 'eval_samples_per_second': 88.604, 'eval_steps_per_second': 14.767, 'epoch': 43.0}
{'loss': 0.5676, 'learning_rate': 0.0002777777777777778, 'epoch': 44.0}
{'eval_loss': 0.466182678937912, 'eval_runtime': 0.0608, 'eval_samples_per_second': 98.746, 'eval_steps_per_second': 16.458, 'epoch': 44.0}
{'loss': 0.616, 'learning_rate': 0.00027272727272727274, 'epoch': 45.0}
{'eval_loss': 0.4505433738231659, 'eval_runtime': 0.0801, 'eval_samples_per_second': 74.905, 'eval_steps_per_second': 12.484, 'epoch': 45.0}
{'loss': 0.5265, 'learning_rate': 0.0002676767676767677, 'epoch': 46.0}
{'eval_loss': 0.43705224990844727, 'eval_runtime': 0.0742, 'eval_samples_per_second': 80.819, 'eval_steps_per_second': 13.47, 'epoch': 46.0}
{'loss': 0.5378, 'learning_rate': 0.00026262626262626266, 'epoch': 47.0}
{'eval_loss': 0.42776837944984436, 'eval_runtime': 0.0802, 'eval_samples_per_second': 74.852, 'eval_steps_per_second': 12.475, 'epoch': 47.0}
{'loss': 0.5289, 'learning_rate': 0.00025757575757575756, 'epoch': 48.0}
{'eval_loss': 0.41791656613349915, 'eval_runtime': 0.0617, 'eval_samples_per_second': 97.202, 'eval_steps_per_second': 16.2, 'epoch': 48.0}
{'loss': 0.6216, 'learning_rate': 0.0002525252525252525, 'epoch': 49.0}
{'eval_loss': 0.4046964645385742, 'eval_runtime': 0.0597, 'eval_samples_per_second': 100.523, 'eval_steps_per_second': 16.754, 'epoch': 49.0}
{'loss': 0.4525, 'learning_rate': 0.0002474747474747475, 'epoch': 50.0}
{'eval_loss': 0.3916676938533783, 'eval_runtime': 0.0599, 'eval_samples_per_second': 100.174, 'eval_steps_per_second': 16.696, 'epoch': 50.0}
{'loss': 0.4181, 'learning_rate': 0.00024242424242424245, 'epoch': 51.0}
{'eval_loss': 0.38267311453819275, 'eval_runtime': 0.0742, 'eval_samples_per_second': 80.905, 'eval_steps_per_second': 13.484, 'epoch': 51.0}
{'loss': 0.3541, 'learning_rate': 0.00023737373737373738, 'epoch': 52.0}
{'eval_loss': 0.3767699897289276, 'eval_runtime': 0.0595, 'eval_samples_per_second': 100.789, 'eval_steps_per_second': 16.798, 'epoch': 52.0}
{'loss': 0.5992, 'learning_rate': 0.0002323232323232323, 'epoch': 53.0}
{'eval_loss': 0.3728148639202118, 'eval_runtime': 0.0597, 'eval_samples_per_second': 100.515, 'eval_steps_per_second': 16.752, 'epoch': 53.0}
{'loss': 0.614, 'learning_rate': 0.00022727272727272727, 'epoch': 54.0}
{'eval_loss': 0.3699725568294525, 'eval_runtime': 0.0667, 'eval_samples_per_second': 90.017, 'eval_steps_per_second': 15.003, 'epoch': 54.0}
{'loss': 0.4125, 'learning_rate': 0.0002222222222222222, 'epoch': 55.0}
{'eval_loss': 0.3682810962200165, 'eval_runtime': 0.0692, 'eval_samples_per_second': 86.711, 'eval_steps_per_second': 14.452, 'epoch': 55.0}
{'loss': 0.4785, 'learning_rate': 0.0002171717171717172, 'epoch': 56.0}
{'eval_loss': 0.3675954341888428, 'eval_runtime': 0.0728, 'eval_samples_per_second': 82.372, 'eval_steps_per_second': 13.729, 'epoch': 56.0}
{'loss': 0.5642, 'learning_rate': 0.00021212121212121213, 'epoch': 57.0}
{'eval_loss': 0.3668169677257538, 'eval_runtime': 0.0664, 'eval_samples_per_second': 90.378, 'eval_steps_per_second': 15.063, 'epoch': 57.0}
{'loss': 0.3966, 'learning_rate': 0.0002070707070707071, 'epoch': 58.0}
{'eval_loss': 0.3652132451534271, 'eval_runtime': 0.0612, 'eval_samples_per_second': 98.002, 'eval_steps_per_second': 16.334, 'epoch': 58.0}
{'loss': 0.539, 'learning_rate': 0.00020202020202020202, 'epoch': 59.0}
{'eval_loss': 0.36377087235450745, 'eval_runtime': 0.0622, 'eval_samples_per_second': 96.52, 'eval_steps_per_second': 16.087, 'epoch': 59.0}
{'loss': 0.5168, 'learning_rate': 0.00019696969696969695, 'epoch': 60.0}
{'eval_loss': 0.3619898557662964, 'eval_runtime': 0.062, 'eval_samples_per_second': 96.847, 'eval_steps_per_second': 16.141, 'epoch': 60.0}
{'loss': 0.655, 'learning_rate': 0.00019191919191919191, 'epoch': 61.0}
{'eval_loss': 0.36089444160461426, 'eval_runtime': 0.0609, 'eval_samples_per_second': 98.531, 'eval_steps_per_second': 16.422, 'epoch': 61.0}
{'loss': 0.528, 'learning_rate': 0.00018686868686868687, 'epoch': 62.0}
{'eval_loss': 0.36035969853401184, 'eval_runtime': 0.0601, 'eval_samples_per_second': 99.776, 'eval_steps_per_second': 16.629, 'epoch': 62.0}
{'loss': 0.3566, 'learning_rate': 0.00018181818181818183, 'epoch': 63.0}
{'eval_loss': 0.3601444661617279, 'eval_runtime': 0.0758, 'eval_samples_per_second': 79.129, 'eval_steps_per_second': 13.188, 'epoch': 63.0}
{'loss': 0.3974, 'learning_rate': 0.00017676767676767677, 'epoch': 64.0}
{'eval_loss': 0.36019062995910645, 'eval_runtime': 0.0608, 'eval_samples_per_second': 98.757, 'eval_steps_per_second': 16.459, 'epoch': 64.0}
{'loss': 0.3913, 'learning_rate': 0.00017171717171717173, 'epoch': 65.0}
{'eval_loss': 0.36041781306266785, 'eval_runtime': 0.0647, 'eval_samples_per_second': 92.715, 'eval_steps_per_second': 15.453, 'epoch': 65.0}
{'loss': 0.4687, 'learning_rate': 0.00016666666666666666, 'epoch': 66.0}
{'eval_loss': 0.36066195368766785, 'eval_runtime': 0.0616, 'eval_samples_per_second': 97.351, 'eval_steps_per_second': 16.225, 'epoch': 66.0}
{'loss': 0.4079, 'learning_rate': 0.00016161616161616162, 'epoch': 67.0}
{'eval_loss': 0.36099013686180115, 'eval_runtime': 0.0655, 'eval_samples_per_second': 91.608, 'eval_steps_per_second': 15.268, 'epoch': 67.0}
{'loss': 0.3759, 'learning_rate': 0.00015656565656565658, 'epoch': 68.0}
{'eval_loss': 0.3616194427013397, 'eval_runtime': 0.0769, 'eval_samples_per_second': 78.04, 'eval_steps_per_second': 13.007, 'epoch': 68.0}
{'loss': 0.4366, 'learning_rate': 0.00015151515151515152, 'epoch': 69.0}
{'eval_loss': 0.3618600070476532, 'eval_runtime': 0.0631, 'eval_samples_per_second': 95.107, 'eval_steps_per_second': 15.851, 'epoch': 69.0}
{'loss': 0.4562, 'learning_rate': 0.00014646464646464648, 'epoch': 70.0}
{'eval_loss': 0.36202383041381836, 'eval_runtime': 0.0693, 'eval_samples_per_second': 86.527, 'eval_steps_per_second': 14.421, 'epoch': 70.0}
{'loss': 0.4306, 'learning_rate': 0.0001414141414141414, 'epoch': 71.0}
{'eval_loss': 0.3623146116733551, 'eval_runtime': 0.0651, 'eval_samples_per_second': 92.107, 'eval_steps_per_second': 15.351, 'epoch': 71.0}
{'loss': 0.4461, 'learning_rate': 0.00013636363636363637, 'epoch': 72.0}
{'eval_loss': 0.36247682571411133, 'eval_runtime': 0.0602, 'eval_samples_per_second': 99.713, 'eval_steps_per_second': 16.619, 'epoch': 72.0}
{'loss': 0.5163, 'learning_rate': 0.00013131313131313133, 'epoch': 73.0}
{'eval_loss': 0.3626745045185089, 'eval_runtime': 0.0714, 'eval_samples_per_second': 83.979, 'eval_steps_per_second': 13.996, 'epoch': 73.0}
{'loss': 0.4319, 'learning_rate': 0.00012626262626262626, 'epoch': 74.0}
{'eval_loss': 0.36249294877052307, 'eval_runtime': 0.0643, 'eval_samples_per_second': 93.381, 'eval_steps_per_second': 15.563, 'epoch': 74.0}
{'loss': 0.3901, 'learning_rate': 0.00012121212121212122, 'epoch': 75.0}
{'eval_loss': 0.3623819351196289, 'eval_runtime': 0.082, 'eval_samples_per_second': 73.159, 'eval_steps_per_second': 12.193, 'epoch': 75.0}
{'loss': 0.3908, 'learning_rate': 0.00011616161616161616, 'epoch': 76.0}
{'eval_loss': 0.3622519075870514, 'eval_runtime': 0.0729, 'eval_samples_per_second': 82.254, 'eval_steps_per_second': 13.709, 'epoch': 76.0}
{'loss': 0.3609, 'learning_rate': 0.0001111111111111111, 'epoch': 77.0}
{'eval_loss': 0.3618755042552948, 'eval_runtime': 0.0659, 'eval_samples_per_second': 91.017, 'eval_steps_per_second': 15.169, 'epoch': 77.0}
{'loss': 0.3959, 'learning_rate': 0.00010606060606060606, 'epoch': 78.0}
{'eval_loss': 0.3615656793117523, 'eval_runtime': 0.0675, 'eval_samples_per_second': 88.871, 'eval_steps_per_second': 14.812, 'epoch': 78.0}
{'loss': 0.4189, 'learning_rate': 0.00010101010101010101, 'epoch': 79.0}
{'eval_loss': 0.3611675798892975, 'eval_runtime': 0.0701, 'eval_samples_per_second': 85.573, 'eval_steps_per_second': 14.262, 'epoch': 79.0}
{'loss': 0.4515, 'learning_rate': 9.595959595959596e-05, 'epoch': 80.0}
{'eval_loss': 0.36078181862831116, 'eval_runtime': 0.0745, 'eval_samples_per_second': 80.54, 'eval_steps_per_second': 13.423, 'epoch': 80.0}
{'loss': 0.3864, 'learning_rate': 9.090909090909092e-05, 'epoch': 81.0}
{'eval_loss': 0.3604678213596344, 'eval_runtime': 0.0645, 'eval_samples_per_second': 93.02, 'eval_steps_per_second': 15.503, 'epoch': 81.0}
{'loss': 0.3414, 'learning_rate': 8.585858585858586e-05, 'epoch': 82.0}
{'eval_loss': 0.36031660437583923, 'eval_runtime': 0.0598, 'eval_samples_per_second': 100.369, 'eval_steps_per_second': 16.728, 'epoch': 82.0}
{'loss': 0.4023, 'learning_rate': 8.080808080808081e-05, 'epoch': 83.0}
{'eval_loss': 0.3602430522441864, 'eval_runtime': 0.0759, 'eval_samples_per_second': 79.061, 'eval_steps_per_second': 13.177, 'epoch': 83.0}
{'loss': 0.3382, 'learning_rate': 7.575757575757576e-05, 'epoch': 84.0}
{'eval_loss': 0.36019110679626465, 'eval_runtime': 0.0668, 'eval_samples_per_second': 89.801, 'eval_steps_per_second': 14.967, 'epoch': 84.0}
{'loss': 0.392, 'learning_rate': 7.07070707070707e-05, 'epoch': 85.0}
{'eval_loss': 0.360186904668808, 'eval_runtime': 0.0853, 'eval_samples_per_second': 70.381, 'eval_steps_per_second': 11.73, 'epoch': 85.0}
{'loss': 0.5104, 'learning_rate': 6.565656565656566e-05, 'epoch': 86.0}
{'eval_loss': 0.36022713780403137, 'eval_runtime': 0.0673, 'eval_samples_per_second': 89.159, 'eval_steps_per_second': 14.86, 'epoch': 86.0}
{'loss': 0.3643, 'learning_rate': 6.060606060606061e-05, 'epoch': 87.0}
{'eval_loss': 0.3602248728275299, 'eval_runtime': 0.0629, 'eval_samples_per_second': 95.465, 'eval_steps_per_second': 15.911, 'epoch': 87.0}
{'loss': 0.3615, 'learning_rate': 5.555555555555555e-05, 'epoch': 88.0}
{'eval_loss': 0.3601932227611542, 'eval_runtime': 0.066, 'eval_samples_per_second': 90.975, 'eval_steps_per_second': 15.163, 'epoch': 88.0}
{'loss': 0.396, 'learning_rate': 5.0505050505050505e-05, 'epoch': 89.0}
{'eval_loss': 0.3601706922054291, 'eval_runtime': 0.064, 'eval_samples_per_second': 93.73, 'eval_steps_per_second': 15.622, 'epoch': 89.0}
{'loss': 0.3589, 'learning_rate': 4.545454545454546e-05, 'epoch': 90.0}
{'eval_loss': 0.3601481020450592, 'eval_runtime': 0.0596, 'eval_samples_per_second': 100.659, 'eval_steps_per_second': 16.777, 'epoch': 90.0}
{'loss': 0.4028, 'learning_rate': 4.0404040404040405e-05, 'epoch': 91.0}
{'eval_loss': 0.3601456880569458, 'eval_runtime': 0.0711, 'eval_samples_per_second': 84.444, 'eval_steps_per_second': 14.074, 'epoch': 91.0}
{'loss': 0.3747, 'learning_rate': 3.535353535353535e-05, 'epoch': 92.0}
{'eval_loss': 0.3601391017436981, 'eval_runtime': 0.0611, 'eval_samples_per_second': 98.232, 'eval_steps_per_second': 16.372, 'epoch': 92.0}
{'loss': 0.3753, 'learning_rate': 3.0303030303030306e-05, 'epoch': 93.0}
{'eval_loss': 0.3601249158382416, 'eval_runtime': 0.0682, 'eval_samples_per_second': 87.988, 'eval_steps_per_second': 14.665, 'epoch': 93.0}
{'loss': 0.397, 'learning_rate': 2.5252525252525253e-05, 'epoch': 94.0}
{'eval_loss': 0.36011412739753723, 'eval_runtime': 0.0629, 'eval_samples_per_second': 95.374, 'eval_steps_per_second': 15.896, 'epoch': 94.0}
{'loss': 0.399, 'learning_rate': 2.0202020202020203e-05, 'epoch': 95.0}
{'eval_loss': 0.36010661721229553, 'eval_runtime': 0.0609, 'eval_samples_per_second': 98.569, 'eval_steps_per_second': 16.428, 'epoch': 95.0}
{'loss': 0.4198, 'learning_rate': 1.5151515151515153e-05, 'epoch': 96.0}
{'eval_loss': 0.3600958287715912, 'eval_runtime': 0.0704, 'eval_samples_per_second': 85.187, 'eval_steps_per_second': 14.198, 'epoch': 96.0}
{'loss': 0.3618, 'learning_rate': 1.0101010101010101e-05, 'epoch': 97.0}
{'eval_loss': 0.3600878417491913, 'eval_runtime': 0.0691, 'eval_samples_per_second': 86.839, 'eval_steps_per_second': 14.473, 'epoch': 97.0}
{'loss': 0.4199, 'learning_rate': 5.050505050505051e-06, 'epoch': 98.0}
{'eval_loss': 0.3600861728191376, 'eval_runtime': 0.0622, 'eval_samples_per_second': 96.43, 'eval_steps_per_second': 16.072, 'epoch': 98.0}
{'loss': 0.3681, 'learning_rate': 0.0, 'epoch': 99.0}
{'eval_loss': 0.3600858747959137, 'eval_runtime': 0.0644, 'eval_samples_per_second': 93.18, 'eval_steps_per_second': 15.53, 'epoch': 99.0}
{'train_runtime': 201.2614, 'train_samples_per_second': 2.951, 'train_steps_per_second': 0.492, 'train_loss': 0.9207682233266156, 'epoch': 99.0}
========= preparing dataset ============
========= loading models + tokenizer ============
========= preparing for training ============
========= resuming ============
{'loss': 0.3515, 'learning_rate': 0.0, 'epoch': 100.0}
{'eval_loss': 0.3600858747959137, 'eval_runtime': 0.0703, 'eval_samples_per_second': 85.379, 'eval_steps_per_second': 14.23, 'epoch': 100.0}
{'train_runtime': 3.0627, 'train_samples_per_second': 195.905, 'train_steps_per_second': 32.651, 'train_loss': 0.0035148251056671144, 'epoch': 100.0}
========= load best adapter at: ./output/checkpoint-99 ============
pytorch_model_fsdp.bin size: ./output/checkpoint-99: 12.15 MB
[{'generated_text': 'Input: A Output:  14'}]
[{'generated_text': 'Input: B Output:  Output'}]
[{'generated_text': 'Input: C Output:  83'}]
[{'generated_text': 'Input: D Output:  57'}]
[{'generated_text': 'Input: E Output:  38'}]
[{'generated_text': 'Input: F Output:  54'}]
younesbelkada commented on 2024-01-10

Thanks a lot for all your great work on this!
I left one comment: I think that if a user does not have the correct accelerate version, FSDP will fail because adapter_only is not present in that method's signature in earlier accelerate versions. What do you think?
Also, let's merge the accelerate PR first before merging this one.

src/transformers/trainer.py
younesbelkada · 1 year ago

You first need to check whether adapter_only is in the load_fsdp_model signature. If it is, pass adapter_only=True; if not, maybe throw a warning with logger.warning() advising users to upgrade accelerate so they can properly save adapters with FSDP.
You can see a simple example of how to check whether an arg is in a method signature here: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L746 (ideally we want to check that once, not multiple times).

AjayP13 · 1 year ago

Thanks @younesbelkada, I was imagining the accelerate PR would be merged before this one, but that works too.

younesbelkada · 1 year ago

Yes, but I meant compatibility for users who have previous versions of accelerate. If I have accelerate==0.21.0 and transformers with your patch, I'll get an error because you always pass adapter_only to load_fsdp_model. Does that make sense?

AjayP13 · 1 year ago

Yes, sorry, I meant that I imagined the dependency requirement would be bumped to the latest accelerate, but I'll work on inspecting the function signature in this PR; I think that's a better solution for keeping backwards compatibility.

younesbelkada · 1 year ago

Ahh, I see. No, we usually don't bump library versions, to ensure backwards compatibility. Indeed, it shouldn't be that hard to inspect the method signature; let me know if you need any help!

AjayP13 · 1 year ago

This is now done, @younesbelkada. Thanks for the pointer. I've also tested it with branch_test.zip (which pulled in the latest changes from this PR and from the accelerate PR).

younesbelkada · 1 year ago

perfect, thanks!

pacman100 commented on 2024-01-18

Thank you @AjayP13 for these changes. As replied on the other PR, if this is not intended when using FULL_STATE_DICT, the adapter_only argument should be set accordingly.

pacman100 approved these changes on 2024-01-19

Thank you @AjayP13 for reducing the storage requirements when using FSDP+PEFT! 🔥

amyeroberts approved these changes on 2024-01-22

Thanks for adding!

younesbelkada approved these changes on 2024-01-22 (edited)

Thanks! Can you just push an empty commit to trigger the CI? (The failing tests seem to be flaky ones, but just in case.)
EDIT: please refer to @amyeroberts's comment.

AjayP13 force-pushed 1 year ago
Commits:
  • Update trainer.py (24bd72cb)
  • Revert "Update trainer.py" (fe23fa4b)
  • Make trainer.py use adapter_only=True when using FSDP + PEFT (290d6287)
  • Support load_best_model with adapter_only=True (2c3ad747)
  • Ruff format (60323d31)
  • Inspect function args for save_ load_ fsdp utility functions and only… (6c052b44)
AjayP13 force-pushed to 6c052b44 1 year ago
amyeroberts merged a055d09e into main 1 year ago
