transformers
Trainer - add cache clearing and the option for batched eval metrics computation
#28769
Merged


FoamoftheSea
FoamoftheSea1 year ago (edited 1 year ago)❤ 6👀 1

What does this PR do?

This PR does two things which are necessary for using the Trainer in resource-constrained environments (like my RTX-3070Ti machine):

  1. Add cache clearing in the training and evaluation loops.
    • This reduces peak GPU load and prevents CUDA OOM errors when running near capacity.
  2. Add a Trainer arg batch_eval_metrics for batched eval metrics computation (see the sketch after this list).
    • When working with limited RAM, storing all logits across the entire evaluation set may not be feasible. A user in this situation can pass True to batch_eval_metrics and write a compute_metrics function that updates running metrics at the batch level to prevent OOM errors on large eval sets. This is particularly useful for vision transformers.
    • Previous functionality is unaltered if the option is not set to True.
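A minimal usage sketch (illustrative only, not code from this PR; the accuracy metric and output_dir are hypothetical, and the model/dataset wiring is omitted):

import numpy as np
from transformers import TrainingArguments

correct, seen = 0, 0

def compute_metrics(eval_pred, compute_result: bool = True):
    # With batch_eval_metrics=True, the Trainer calls this after every eval
    # batch and passes compute_result=True only on the final batch.
    global correct, seen
    preds, labels = eval_pred.predictions, eval_pred.label_ids
    if hasattr(preds, "detach"):  # tensors rather than numpy arrays in the batched path
        preds, labels = preds.detach().cpu().numpy(), labels.detach().cpu().numpy()
    correct += int(np.sum(np.argmax(preds, axis=-1) == labels))
    seen += int(labels.size)
    if compute_result:
        result = {"accuracy": correct / seen}
        correct, seen = 0, 0  # reset running state for the next eval round
        return result

args = TrainingArguments(output_dir="out", batch_eval_metrics=True)
# trainer = Trainer(model=model, args=args, compute_metrics=compute_metrics, ...)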

@muellerzr

FoamoftheSea Added cache clearing for GPU efficiency.
99e94ab1
FoamoftheSea Added cache clearing for GPU efficiency.
fc570304
FoamoftheSea Added batch_eval_metrics capability
548a26f4
FoamoftheSea Merge branch 'main' into trainer-updates
fa690a8d
FoamoftheSea Ran make fixup
f6350aaf
FoamoftheSea Merge remote-tracking branch 'origin/trainer-updates' into trainer-up…
006ecd23
FoamoftheSea Fixed bug
786046e0
HuggingFaceDocBuilderDev
HuggingFaceDocBuilderDev1 year ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

ArthurZucker
ArthurZucker1 year ago
FoamoftheSea Fixed whitespace issue
7752d130
FoamoftheSea Merge branch 'main' into trainer-updates
e155b0b4
FoamoftheSea
FoamoftheSea1 year ago

Hey everyone, I tried to look at the logs for the failed tests, but I don't see any actionable error reports. Can anyone help me figure out what needs to be done for them to pass?

ArthurZucker
ArthurZucker1 year ago👍 1

The main CI is a bit broken because of the pytest package. Let's wait a bit here.

ArthurZucker
ArthurZucker1 year ago

Just re-ran the CI; you should actually rebase onto main, then it should be alright.

ArthurZucker
ArthurZucker1 year ago👍 1

BTW @SunMarc, it would be nice if you could have a look as well!

FoamoftheSea Merge branch 'main' into trainer-updates
a26de727
FoamoftheSea
FoamoftheSea1 year ago

CIs are green after merging main ✔️

SunMarc
SunMarc commented on 2024-02-21
SunMarc1 year ago

Hi @FoamoftheSea, thanks for contributing! I left a few comments to better understand how you are performing the batched metric computation. Can you also add tests to check that we get the same result with and without batched computation?

src/transformers/trainer.py
if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
    is_last_step = step == len(dataloader) - 1
    if args.include_inputs_for_metrics:
        metrics = self.compute_metrics(
            EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
            compute_result=is_last_step,
        )
    else:
        metrics = self.compute_metrics(
            EvalPrediction(predictions=preds_host, label_ids=labels_host),
            compute_result=is_last_step,
        )
SunMarc1 year ago

You can't add a compute_result argument here.

FoamoftheSea1 year ago

This code path would only be used if the user set args.batch_eval_metrics to True, so only those trying to use this feature would need to worry about the expectation for this argument. I'm definitely open to other suggestions though.

FoamoftheSea1 year ago

Once we settle on the right solution I will write a test for it 🙏

SunMarc1 year ago

I will let @ArthurZucker and @muellerzr comment on that !

muellerzr1 year ago👍 1

This seems okay to me if this path is only ever taken when a user has this enabled. (We should maybe write a snippet/clarification about it in the TrainingArguments docstring.)

FoamoftheSea1 year ago❤ 1

Sounds good! I'll work on this and an appropriate test for it some time this week

src/transformers/trainer.py
if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
    is_last_step = step == len(dataloader) - 1
    if args.include_inputs_for_metrics:
        metrics = self.compute_metrics(
            EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
            compute_result=is_last_step,
        )
    else:
        metrics = self.compute_metrics(
            EvalPrediction(predictions=preds_host, label_ids=labels_host),
            compute_result=is_last_step,
        )
SunMarc1 year ago

Are you computing and storing the batched metric inside self.compute_metrics (it would need to be a class with __call__ defined)?

FoamoftheSea1 year ago (edited 1 year ago)👍 1

I'm actually using a globally instantiated metrics class that maintains state via update and compute methods, called inside the compute_metrics function passed to the Trainer. There is probably a better solution, but I found this the least intrusive on the Trainer's current expected behavior: if the user does not enable the batched eval metrics option, they never hit the code path where the compute_result argument is used, and therefore don't have to change anything about their existing compute_metrics functions.

Here's a very basic pseudo-code example of how I use this:

import numpy as np
from typing import Optional

from transformers import Trainer


class MSEMetric:
    def __init__(self):
        self.batch_mse = []

    def update(self, preds, target):
        diff = target - preds
        batch_mse = np.mean(np.power(diff, 2))
        self.batch_mse.append(batch_mse)

    def compute(self):
        # Get result across entire eval set
        result = {"mse": np.mean(self.batch_mse)}
        # Reset batch statistics
        self.batch_mse = []
        return result


mse_metric = MSEMetric()


def compute_metrics(eval_pred, compute_result: bool = True) -> Optional[dict]:
    mse_metric.update(eval_pred.predictions, eval_pred.label_ids)

    if compute_result:
        return mse_metric.compute()


# Use this compute_metrics fn in the trainer
trainer = Trainer(..., compute_metrics=compute_metrics)

This mirrors the update and compute methodology of the metric classes in torcheval.metrics, which is where I got the inspiration (see https://pytorch.org/torcheval/main/generated/torcheval.metrics.MeanSquaredError.html).
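For comparison, here is roughly what that torcheval pattern looks like on its own (a sketch assuming torcheval is installed; see the link above):

import torch
from torcheval.metrics import MeanSquaredError

metric = MeanSquaredError()
metric.update(torch.tensor([0.9, 0.5]), torch.tensor([1.0, 0.0]))  # batch 1
metric.update(torch.tensor([0.1, 0.3]), torch.tensor([0.0, 0.5]))  # batch 2
print(metric.compute())  # MSE aggregated over all updates so far
metric.reset()           # clear state before the next eval round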

SunMarc1 year ago👍 2

Makes sense ! Thanks for explaining

src/transformers/trainer.py
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

if self.args.batch_eval_metrics:
    if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
        is_last_step = step == len(dataloader) - 1
        if args.include_inputs_for_metrics:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
                compute_result=is_last_step,
            )
        else:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds_host, label_ids=labels_host),
                compute_result=is_last_step,
            )

    del losses_host, preds_host, inputs_host, labels_host
    torch.cuda.empty_cache()
    losses_host, preds_host, inputs_host, labels_host = None, None, None, None
SunMarc1 year ago

same comment as above

FoamoftheSea
FoamoftheSea commented on 2024-02-22
src/transformers/trainer.py
    torch.cuda.empty_cache()
    losses_host, preds_host, inputs_host, labels_host = None, None, None, None

elif (
    args.eval_accumulation_steps is not None
    and (step + 1) % args.eval_accumulation_steps == 0
    and self.accelerator.sync_gradients
):
FoamoftheSea1 year ago

It looks like the condition here was outdated; I updated it to the current version in main.

Suggested change
- elif (
-     args.eval_accumulation_steps is not None
-     and (step + 1) % args.eval_accumulation_steps == 0
-     and self.accelerator.sync_gradients
- ):
+ elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
FoamoftheSea Fixed outdated condition
a12e46f8
FoamoftheSea Merge branch 'main' into trainer-updates
137bd392
SunMarc
SunMarc commented on 2024-02-22
src/transformers/trainer.py
if self.args.batch_eval_metrics:
    if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
        is_last_step = step == len(dataloader) - 1
SunMarc1 year ago👀 1

Quick note: not every dataset has __len__() defined, e.g. IterableDataset.

FoamoftheSea1 year ago

I see, we'll need to find a more robust way to identify the last step of the eval set then... If anyone has an idea let me know

muellerzr1 year ago🎉 1

You should be able to use self.accelerator.gradient_state.end_of_dataloader here :)
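A rough standalone check of that flag (a sketch assuming a recent accelerate release; the toy dataloader is just for illustration):

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()
loader = accelerator.prepare(DataLoader(torch.arange(10), batch_size=4))

for step, batch in enumerate(loader):
    # True only while processing the final batch of the prepared dataloader,
    # even when the underlying dataset has no __len__.
    is_last_step = accelerator.gradient_state.end_of_dataloader
    print(step, is_last_step)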

muellerzr1 year ago
  • adding these changes :)
SunMarc requested a review from muellerzr 1 year ago
muellerzr
muellerzr commented on 2024-02-29
muellerzr1 year ago

Thanks! Overall this seems quite handy. If we can confirm that it does reduce your memory footprint without issue then I believe that's quite alright. Replied to comments from Marc's review. Let's apply those then I can give a checkmark on my end at least!

FoamoftheSea Merge branch 'main' into trainer-updates
994e7d06
github-actions
github-actions1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

FoamoftheSea
FoamoftheSea1 year ago🎉 2❤ 2🚀 3

Commenting to keep the PR fresh. I got super busy the past couple weeks but I will finish this soon.

github-actions
github-actions1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ducha-aiki
ducha-aiki1 year ago🎉 1❤ 2

If we can confirm that it does reduce your memory footprint without issue then I believe that's quite alright.

I will check this today, thank you!

ducha-aiki
ducha-aiki1 year ago (edited 1 year ago)❤ 3

This works amazingly!
Here is the RAM consumption before:

[chart: RAM consumption before]

And after:

[chart: RAM consumption after]

The only thing: the PR needs to be updated to work with 4.40, because the all_labels logic has changed since 4.39. I haven't tried updating the PR to 4.40; I tested on its own branch only.

muellerzr
muellerzr1 year ago

@FoamoftheSea if you can resolve the merge conflicts, we can land this 🚀

muellerzr
muellerzr approved these changes on 2024-04-25
muellerzr
muellerzr commented on 2024-04-25
src/transformers/trainer.py
- # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
- if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+ if self.args.batch_eval_metrics:
+     if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
+         is_last_step = step == len(dataloader) - 1
muellerzr1 year ago

Will need it here too

SunMarc
SunMarc approved these changes on 2024-04-25
SunMarc1 year ago

Thanks for your contribution!

src/transformers/training_args.py
    for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the
    [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
    `PeftModel` from peft.

batch_eval_metrics (`Optional[bool]`):
    If set to True, evaluation will call compute_metrics at the end of each batch to accumulate statistics
    rather than saving all eval logits in memory.
SunMarc1 year ago❤ 1

It would be great to add an example too; it will drive usage! Do you have an idea where we could potentially put an example, @muellerzr? Otherwise, I think it is worth adding a link to this comment: https://github.com/huggingface/transformers/pull/28769/files#r1498479868

FoamoftheSea Updated docstrings with instructions for batch_eval_metrics. Updated …
9b7c26e8
FoamoftheSea
FoamoftheSea1 year ago❤ 1

I am looking into this. It looks like the conflict is due to different management of the variables in the latest main, which actually discards the use of the intermediate variable I was clearing from memory, so I want to double-check how that affects this change.

Also, I'm working on getting a test made for the batching functionality. Should have this ready soon.

FoamoftheSea Added first version of batch_eval_metrics Trainer test
760af34d
FoamoftheSea Merge branch 'main' into trainer-updates
4f87e817
FoamoftheSea Fixed batch_eval_metrics Trainer tests for both eval and predict
1fed1bcb
FoamoftheSea Fixed batch_eval_metrics behavior for new Trainer variables
5f44cb17
FoamoftheSea Fixed batch_eval_metrics Trainer tests
583d14b1
FoamoftheSea Ran fixup
f925e8c4
FoamoftheSea
FoamoftheSea1 year ago❤ 1

The pytests are ready and I've updated the code to work with the latest Trainer updates. I still need to run an A/B test to see if the cache clearing is still providing any benefit with the new changes from main, since they seem to have the potential to handle the same issue. I just need to finish setting up a test case. Let's hold off on merging until we have some fresh test results.

FoamoftheSea
FoamoftheSea1 year ago (edited 1 year ago)❤ 3

The test results demonstrate 2 things:

  1. batch_eval_metrics is essential when training models with large solution spaces, such as semantic segmentation, where the typical code path accumulates many large tensors in memory during the eval loop; batching the calculations avoids this.
  2. The CUDA cache clearing keeps maximum GPU utilization lower over time, and allows using a larger eval batch size to expedite the process.

Based on these results, I think that we can justify these changes.

Training details:

  • Model: nvidia/segformer-b0-finetuned-cityscapes-1024-1024
  • Dataset: Antreas/Cityscapes

System details:

  • Windows 10 Pro
  • GPU = NVidia Quadro T2000
  • CPU = Intel Xeon 2.80GHz (12 CPUs)
  • RAM = 32GB

In the following charts, we can see that the standard code path goes OOM and fails as it tries to store all of the dense logits in memory on either the GPU or the CPU (as in the case of using eval_accumulation_steps). Only the run using batch_eval_metrics survives the evaluation cycle without going OOM and failing.

Further, there is a large boost in memory efficiency during the train and eval phases when using the cache clearing; it leads to very slightly lower iteration speed, but leaves a lot of headroom for using a larger batch size during evaluation to compensate.

[W&B charts (5/5/2024): GPU memory utilization and iteration speed for the compared runs]

LysandreJik
LysandreJik1 year ago❤ 1

Awesome, great work @FoamoftheSea!

Approving the PR, @muellerzr feel free to merge at your convenience.

LysandreJik
LysandreJik approved these changes on 2024-05-06
muellerzr
muellerzr approved these changes on 2024-05-06
muellerzr1 year ago❤ 1

Thanks for all your hard work with this!

muellerzr merged df475bf8 into main 1 year ago
Reveyer
Reveyer358 days ago👀 3

Hello, I've noticed that this pull request seems to slow things down when using the Trainer, likely due to the frequent use of torch.cuda.empty_cache(). Is there a way to optimize this, or could we possibly have the option to choose whether or not to use torch.cuda.empty_cache()? @muellerzr

muellerzr
muellerzr358 days ago (edited 358 days ago)

Do you have a small benchmark for us, @Reveyer? I haven't noticed this yet while investigating another timing issue, but I would be happy to look into it.

Reveyer
Reveyer358 days ago

@muellerzr I was able to reproduce this on the official code. Here’s the command I used:

CUDA_VISIBLE_DEVICES=0 python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path facebook/bart-large \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=16 \
    --per_device_eval_batch_size=16 \
    --overwrite_output_dir \
    --predict_with_generate \
    --num_train_epochs="10" \
    --seed="42"

Here are the results with and without torch.cuda.empty_cache() removed from the trainer:

  • With torch.cuda.empty_cache() removed:
    | 30/179450 [00:22<37:48:51, 1.32it/s]
  • With torch.cuda.empty_cache():
    | 30/179450 [00:23<39:22:17, 1.27it/s]

The discrepancy is even more significant in my personal code, which utilizes Llama-3:

  • With torch.cuda.empty_cache() removed:
    | 5/89543 [00:43<216:51:58, 8.72s/it]
  • With torch.cuda.empty_cache():
    | 5/89543 [00:51<251:54:05, 10.13s/it]

I hope this helps clarify the issue. Looking forward to hearing your thoughts!

FoamoftheSea
FoamoftheSea358 days ago👀 2

@muellerzr @Reveyer I would test the speed difference on the second training iteration (after the first eval round). This change sacrificed some speed on the first training iteration for a noticeable increase in speed in the second and onward. Something about the eval round clogged up memory and made all subsequent training loops very slow; the cache emptying fixed that at a slight decrease in initial speed.

Reveyer
Reveyer358 days ago❤ 3

@FoamoftheSea Hello, I've identified that these two lines of code are key to the speed differences observed. Could you please test the impact of these two lines on your experiments? Thank you!

del inputs
torch.cuda.empty_cache()
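For anyone who wants to isolate the overhead, a quick micro-benchmark sketch (illustrative only; it requires a CUDA GPU, and the numbers will vary with allocator state and workload):

import time
import torch

x = torch.randn(4096, 4096, device="cuda")

def run(steps: int, clear_cache: bool) -> float:
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        y = x @ x  # stand-in for a step's GPU work
        del y
        if clear_cache:
            torch.cuda.empty_cache()
    torch.cuda.synchronize()
    return time.perf_counter() - start

print("without empty_cache:", run(100, clear_cache=False))
print("with empty_cache:   ", run(100, clear_cache=True))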

muellerzr
muellerzr358 days ago❤ 1

Looks like we also have a CPU leak in here too when using custom evals, so I'll be investigating this today.

FoamoftheSea
FoamoftheSea338 days ago❤ 1

@Reveyer sorry for the late response, the past few weeks have been very busy. I will find some time this week to re-run the experiment with that change.

FoamoftheSea
FoamoftheSea337 days ago❤ 1

I ran my test again on the main branch today. The results I got with and without the cache clearing in the training loop are mostly identical, except that the run with no cache clearing was slightly less memory-efficient and, in my case, very slightly slower over time, though that might just be because my GPU was already warmed up for that training run.

Since the difference here is minimal, I would have no problem reverting the cache-clearing line in the training loop, since the training loop doesn't seem to hit the memory overflow problem that the eval loop does without it.

[W&B charts (6/19/2024): GPU memory utilization and iteration speed with and without cache clearing in the training loop]

Here's the script I ran the experiment with:

from collections import Counter
from typing import Dict, Mapping, Optional, Set

import numpy as np
import torch
from datasets import load_dataset
from torchvision.transforms.functional import pil_to_tensor

from transformers import (
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
    Trainer,
    TrainingArguments,
    EvalPrediction,
)

from cityscapesscripts.helpers.labels import id2label

id2trainId = {k: v.trainId for k, v in id2label.items() if k >= 0}
trainId2label = {v.trainId: v.name for k, v in id2label.items() if k >= 0}
conversion_lookup = np.array([id2trainId[i] for i in range(len(id2trainId))])

dataset = load_dataset("Antreas/Cityscapes")
train_dataset, eval_dataset = dataset["train"], dataset["val"]

image_processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-1024-1024")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-1024-1024")

BATCH_EVAL_METRICS = True


class SegformerSemanticSegEvalMetric:
    def __init__(
            self,
            id2label: Dict[int, str],
            ignore_class_ids: Optional[Set[int]] = None,
            reduced_labels: bool = False,
            batch_eval_metrics: bool = True,
    ):
        self.total_area_intersect = Counter()
        self.total_area_union = Counter()
        self.total_label_area = Counter()
        self.ignore_class_ids = ignore_class_ids or set()
        self.reduced_labels = reduced_labels
        self.id2label = id2label
        self.batch_eval_metrics = batch_eval_metrics

    def update(self, logits: torch.FloatTensor, gt_labels: torch.LongTensor):

        # logits_tensor = torch.from_numpy(logits)
        # scale the logits to the size of the label
        logits_tensor = torch.nn.functional.interpolate(
            logits,
            size=gt_labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        gt_labels = gt_labels.detach().cpu().numpy()

        for class_id in self.id2label.keys():
            if class_id in self.ignore_class_ids:
                continue
            if self.reduced_labels:
                label_id = class_id - 1 if class_id != 0 else 255
            else:
                label_id = class_id
            pred_pixels = pred_labels == label_id
            gt_pixels = gt_labels == label_id
            class_label = self.id2label[class_id]
            self.total_area_intersect.update({class_label: np.sum(np.bitwise_and(pred_pixels, gt_pixels))})
            self.total_area_union.update({class_label: np.sum(np.bitwise_or(pred_pixels, gt_pixels))})
            self.total_label_area.update({class_label: np.sum(gt_pixels)})

    def compute(self):
        accuracies = {f"accuracy_{k}": self.total_area_intersect[k] / self.total_label_area[k] for k in self.total_area_union}
        ious = {f"iou_{k}": self.total_area_intersect[k] / self.total_area_union[k] for k in self.total_area_union}
        metrics = {
            "overall_accuracy": sum(self.total_area_intersect.values()) / sum(self.total_label_area.values()),
            "mean_accuracy": np.mean(list(accuracies.values())),
            "mean_iou": np.mean(list(ious.values())),
        }
        metrics.update(accuracies)
        metrics.update(ious)

        return metrics

    def __call__(self, eval_pred: EvalPrediction, compute_result=False):
        if self.batch_eval_metrics:
            return self._call_batched(eval_pred, compute_result)
        else:
            return self._call_nonbatched(eval_pred)

    def _call_nonbatched(self, eval_pred):
        mious = {}
        with torch.no_grad():
            logits, gt_labels = eval_pred.predictions, eval_pred.label_ids
            logits_tensor = torch.from_numpy(logits)
            # scale the logits to the size of the label
            logits_tensor = torch.nn.functional.interpolate(
                logits_tensor,
                size=gt_labels.shape[-2:],
                mode="bilinear",
                align_corners=False,
            ).argmax(dim=1)

            pred_labels = logits_tensor.detach().cpu().numpy()

            for class_id in self.id2label.keys():
                if class_id in self.ignore_class_ids:
                    continue
                if self.reduced_labels:
                    label_id = class_id - 1 if class_id != 0 else 255
                else:
                    label_id = class_id
                pred_pixels = pred_labels == label_id
                gt_pixels = gt_labels == label_id
                class_label = self.id2label[class_id]
                intersection = np.sum(np.bitwise_and(pred_pixels, gt_pixels))
                union = np.sum(np.bitwise_or(pred_pixels, gt_pixels))
                mious[class_label] = intersection / union

        return np.mean(list(mious.values()))

    def _call_batched(self, eval_pred: EvalPrediction, compute_result: bool = True) -> Optional[dict]:
        with torch.no_grad():
            self.update(eval_pred.predictions, eval_pred.label_ids)
            return self.compute() if compute_result else None


def collate_fn(features: list):

    if not isinstance(features[0], Mapping):
        features = [vars(f) for f in features]
    first = features[0]
    images = None
    semantic_masks = None

    if "semantic_segmentation" in first and first["semantic_segmentation"] is not None:
        semantic_masks = [pil_to_tensor(f["semantic_segmentation"])[0] for f in features]
    if "image" in first and first["image"] is not None:
        images = [pil_to_tensor(f["image"].convert("RGB")) for f in features]

    processed = image_processor(images=images, segmentation_maps=semantic_masks)
    labels = np.array(processed.data["labels"])
    labels = conversion_lookup[labels]

    batch = {
        "pixel_values": torch.Tensor(np.array(processed.data["pixel_values"])),
        "labels": torch.LongTensor(np.array(labels)),
    }

    return batch


training_args = TrainingArguments(
    "segformer-test",
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    save_total_limit=3,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=300,
    eval_steps=150,
    max_steps=10000,
    logging_steps=1,
    load_best_model_at_end=True,
    push_to_hub=False,
    gradient_checkpointing=False,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    use_cpu=False,
    learning_rate=0.0002,
    batch_eval_metrics=BATCH_EVAL_METRICS,
    remove_unused_columns=False,
    # eval_accumulation_steps=1 if not BATCH_EVAL_METRICS else None,
    run_name="segformer-b0-cityscapes-t4e8-batch-eval-metrics-no-cache-clear-train-loop",
)

trainer = Trainer(
    model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=SegformerSemanticSegEvalMetric(id2label=trainId2label, ignore_class_ids={255}),
    data_collator=collate_fn,
)

trainer.train()
SunMarc
SunMarc337 days ago👍 2

Hi @FoamoftheSea, thanks for re-running the experiments! From your observations, I think it would be better to remove the cache clearing in the training loop, as other users are seeing a large increase in training time. Would you like to open a PR? Otherwise, I'll do it!

manoja328
manoja32884 days ago👍 1

Any solution to this issue? My Trainer runs fine in train mode, but the same setup crashes when run in eval mode.
