Generally, this looks like a reasonable change. Do you have a small script to reproduce the error? I wonder if we can craft a small unit test.
@BenjaminBossan
Hello! Yeah, I have used accelerate to launch training:
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file=config.yaml --main_process_port=12355 train.py --output_dir=./save
Accelerate config:
# config.yaml
compute_environment: LOCAL_MACHINE
debug: true
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: NO_PREFETCH
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Training script:
# train.py
import argparse
import logging
import os
from functools import partial
import pandas as pd
import torch
import torch.distributed as dist
from datasets import Dataset
from peft import LoraConfig
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
logging.basicConfig(level=logging.DEBUG, filename="logs.log", format="%(asctime)s %(levelname)s %(message)s")
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['TORCH_NCCL_ENABLE_MONITORING'] = '0'
def get_data(tokenizer):
    data = [{
        'user_message': "Hi, how are you?",
        'model_message': "I'm good, thanks. How about you?"
    }] * 20
    data = Dataset.from_list(data)
    data = data.train_test_split(train_size=0.7, shuffle=True, seed=42)
    tmp = data['test'].train_test_split(test_size=0.6, shuffle=True, seed=143)
    data['validation'] = tmp['train']
    data['test'] = tmp['test']

    def tokenize(x):
        messages = [
            {'role': 'user', "content": x['user_message']},
            {'role': 'assistant', "content": x['model_message']},
        ]
        text = tokenizer.decode(tokenizer.apply_chat_template(messages))
        result = tokenizer(text, return_tensors='pt')
        sep = '<|im_start|>assistant\n'
        input_text = text.split(sep)[0] + sep
        input_len = len(tokenizer(input_text)['input_ids'])
        result['labels'] = result['input_ids'].clone().detach()
        result['labels'][:, :input_len] = -100
        return {k: v.tolist()[0] for k, v in result.items()}

    tokenized_datasets = data.map(
        tokenize,
        remove_columns=['user_message', 'model_message'],
    )
    tokenized_datasets.set_format('torch')
    return tokenized_datasets
def collate_fn(data, pad_token_id):
    input_ids, labels = tuple([x[key] for x in data] for key in ('input_ids', 'labels'))
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
    return {
        'input_ids': input_ids,
        'labels': labels,
        'attention_mask': input_ids.ne(pad_token_id) * 1
    }
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(f'trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}')
def training_function(args):
    model_name = 'Qwen/Qwen2.5-1.5B-Instruct'
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        gradient_checkpointing=True,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_train_epochs=1,
        save_strategy='no',
        seed=42,
        data_seed=42,
        optim='adamw_8bit'
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    data = get_data(tokenizer)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        return_dict=True,
    )
    model.add_adapter(LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=['q_proj', 'k_proj']
    ))
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=data['train'],
        eval_dataset=data['validation'],
        data_collator=partial(collate_fn, pad_token_id=tokenizer.pad_token_id),
    )
    if trainer.accelerator.is_main_process:
        print_trainable_parameters(model)
    trainer.train()
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model()
def main():
    parser = argparse.ArgumentParser(description='Main training script.')
    parser.add_argument(
        '--output_dir',
        type=str,
        default='.',
        help='Optional save directory where all checkpoint folders will be stored. Default is the current working directory.'
    )
    args = parser.parse_args()
    training_function(args)

if __name__ == '__main__':
    main()
Environment:
accelerate==1.1.1
torch==2.5.1+cu124
pandas==2.2.3
peft==0.13.2
datasets==3.1.0
transformers==4.46.3
tqdm==4.67.1
If you run this command without the proposed changes, it will crash after half an hour with the following error:
[rank0]:[E1218 12:49:09.086536415 ProcessGroupNCCL.cpp:616] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=517, OpType=_ALLGATHER_BASE, NumelIn=11718912, NumelOut=46875648, Timeout(ms)=1800000) ran for 1800046 milliseconds before timing out.
[rank0]:[E1218 12:49:09.086960221 ProcessGroupNCCL.cpp:1785] [PG ID 0 PG GUID 0(default_pg) Rank 0] Exception (either an error or timeout) detected by watchdog at work: 517, last enqueued NCCL work: 517, last completed NCCL work: 516.
[rank0]:[E1218 12:49:09.630929998 ProcessGroupNCCL.cpp:1834] [PG ID 0 PG GUID 0(default_pg) Rank 0] Timeout at NCCL work: 517, last enqueued NCCL work: 517, last completed NCCL work: 516.
[rank0]:[E1218 12:49:09.630952723 ProcessGroupNCCL.cpp:630] [Rank 0] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank0]:[E1218 12:49:09.630959550 ProcessGroupNCCL.cpp:636] [Rank 0] To avoid data inconsistency, we are taking the entire process down.
[rank0]:[E1218 12:49:09.632257057 ProcessGroupNCCL.cpp:1595] [PG ID 0 PG GUID 0(default_pg) Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=517, OpType=_ALLGATHER_BASE, NumelIn=11718912, NumelOut=46875648, Timeout(ms)=1800000) ran for 1800046 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f3ce429f446 in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7f3ce55b2772 in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f3ce55b9bb3 in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f3ce55bb61d in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7f3d3401d5c0 in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x8609 (0x7f3d3f9d0609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #6: clone + 0x43 (0x7f3d3fb0a133 in /lib/x86_64-linux-gnu/libc.so.6)
terminate called after throwing an instance of 'c10::DistBackendError'
what(): [PG ID 0 PG GUID 0(default_pg) Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=517, OpType=_ALLGATHER_BASE, NumelIn=11718912, NumelOut=46875648, Timeout(ms)=1800000) ran for 1800046 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f3ce429f446 in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7f3ce55b2772 in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f3ce55b9bb3 in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f3ce55bb61d in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7f3d3401d5c0 in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x8609 (0x7f3d3f9d0609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #6: clone + 0x43 (0x7f3d3fb0a133 in /lib/x86_64-linux-gnu/libc.so.6)
Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1601 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f3ce429f446 in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe4271b (0x7f3ce522871b in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0x145c0 (0x7f3d3401d5c0 in /mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x8609 (0x7f3d3f9d0609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #4: clone + 0x43 (0x7f3d3fb0a133 in /lib/x86_64-linux-gnu/libc.so.6)
E1218 12:49:10.390000 407160 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: -6) local_rank: 0 (pid: 407232) of binary: /mnt/data/a.kudisov/transformers/.venv/bin/python
Traceback (most recent call last):
File "/mnt/data/a.kudisov/transformers/.venv/bin/accelerate", line 8, in <module>
sys.exit(main())
^^^^^^
File "/mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
args.func(args)
File "/mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/accelerate/commands/launch.py", line 1155, in launch_command
multi_gpu_launcher(args)
File "/mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/accelerate/commands/launch.py", line 793, in multi_gpu_launcher
distrib_run.run(args)
File "/mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/distributed/run.py", line 910, in run
elastic_launch(
File "/mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/data/a.kudisov/transformers/.venv/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
=======================================================
train.py FAILED
-------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
-------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2024-12-18_12:49:10
host : ...
rank : 0 (local_rank: 0)
exitcode : -6 (pid: 407232)
error_file: <N/A>
traceback : Signal 6 (SIGABRT) received by PID 407232
=======================================================
But if you apply the proposed code changes, the program completes successfully in a few seconds.
Thanks for the code. I can confirm that the process hangs when PEFT tries to call model.state_dict(), and that your solution resolves the problem.
Crafting a unit test from this would require a small model and a lowered timeout, or else a failing test would hang for too long. It would also need to run in a multi-GPU environment. Maybe not worth the effort.
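If we ever decide it is worth the effort, a rough sketch of the timeout part is below (the tiny checkpoint and the test harness around it are assumptions on my part, and the test would still have to be gated to multi-GPU runners):

# Hedged sketch: use a tiny model and lower the process-group timeout so that a
# regression fails within minutes instead of hanging for the default 30 minutes.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./test_output",
    per_device_train_batch_size=1,
    num_train_epochs=1,
    save_strategy="no",
    ddp_timeout=120,  # distributed timeout in seconds (default is 1800)
)
# The reproducer's model_name would be swapped for a tiny test checkpoint
# (for example an hf-internal-testing Qwen2 model; the exact repo name is a
# placeholder here), and the test would be skipped unless several GPUs are available.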
What we should probably do in PEFT is make sure that the state_dict can be collected correctly even when FSDP is being used. I tried to use the corresponding code from accelerate, but that still did not resolve the issue, and I'm not sure why.
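For reference, the kind of accelerate helper I mean is Accelerator.get_state_dict; roughly the following idea (a sketch with the setup elided, not the exact code I tried):

from accelerate import Accelerator
from peft import get_peft_model_state_dict

accelerator = Accelerator()
model = ...  # the FSDP-wrapped model returned by accelerator.prepare(...), after training

# Every rank must call get_state_dict(): gathering the unsharded weights under
# FSDP performs collectives. PEFT then never calls model.state_dict() itself.
full_state_dict = accelerator.get_state_dict(model)
peft_state_dict = get_peft_model_state_dict(
    accelerator.unwrap_model(model), state_dict=full_state_dict
)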
Hopefully @muellerzr can help here.
@BenjaminBossan @muellerzr
Hello! Could we merge this code? What are the next steps?
From my side, I don't have anything to add. Hopefully, Zach can share his thoughts, though he is a bit busy this week.
@muellerzr Could you please take a look at this PR?
I think that it should be safe to merge!
LGTM
What does this PR do?
The FSDP + PEFT (LoRA) training process crashes (Watchdog caught collective operation timeout) when the trainer tries to get the PEFT state dict (via the get_peft_model_state_dict function from huggingface/peft) without providing it with the model's state dict. The function then tries to collect this data on its own and freezes because it has no information about FSDP. This PR addresses this issue.
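In other words, the full state dict has to be gathered with every rank participating and then passed to get_peft_model_state_dict explicitly. A minimal sketch of that idea under raw FSDP (an illustration of the approach with a hypothetical helper name, not the literal diff in this PR):

from peft import get_peft_model_state_dict
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

def gather_peft_state_dict(fsdp_model):
    # Every rank has to enter this block: gathering a FULL_STATE_DICT issues
    # collectives, which is exactly the call that times out when it runs unguarded.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        full_state_dict = fsdp_model.state_dict()
    # Hand the gathered dict to PEFT explicitly so it does not try to call
    # model.state_dict() on its own without the FSDP context.
    return get_peft_model_state_dict(fsdp_model, state_dict=full_state_dict)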
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@SunMarc
@muellerzr