lighteval
b5b096e8 - Added custom model inference. (#437)

Commit
228 days ago
Added custom model inference. (#437) Enables the evaluation of any system in the user's control. Fixes [Issue 430](https://github.com/huggingface/lighteval/issues/430). Try with

```
python -m lighteval custom google-translate /path/to/google_translate_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10
```

google_translate_model.py

```
import logging
from typing import Optional

from tqdm import tqdm
from transformers import AutoTokenizer

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo
from lighteval.models.model_output import (
    GenerativeResponse,
    LoglikelihoodResponse,
    LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
    GreedyUntilRequest,
    LoglikelihoodRequest,
    LoglikelihoodRollingRequest,
    LoglikelihoodSingleTokenRequest,
)

logger = logging.getLogger(__name__)


class GoogleTranslateClient(LightevalModel):
    """Example custom lighteval model that answers generative requests by
    calling Google Translate via the ``googletrans`` package.

    Only ``greedy_until`` is implemented; all loglikelihood entry points
    raise ``NotImplementedError``. Translation direction is currently
    hard-coded to French -> German (see the TODO in ``greedy_until``).
    """

    def __init__(self, config, env_config) -> None:
        # ``config`` is the custom-model config parsed by the CLI; only
        # ``model`` and ``model_definition_file_path`` are read here.
        self.model = config.model
        self.model_definition_file_path = config.model_definition_file_path

        self.model_info = ModelInfo(
            model_name=config.model,
            model_sha="",
            model_dtype=None,
            model_size="",
        )

        self._tokenizer = AutoTokenizer.from_pretrained("gpt2")  # Use a dummy tokenizer for compatibility

        import httpcore

        # Needed to fix some googletrans bug: the attribute must be patched
        # onto httpcore BEFORE ``googletrans`` is imported, hence the local,
        # deliberately ordered imports here.
        # https://stackoverflow.com/questions/72796594/attributeerror-module-httpcore-has-no-attribute-synchttptransport#comment136664963_77334618
        setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
        from googletrans import Translator

        self.translator = Translator()

    def greedy_until(
        self,
        requests: list[GreedyUntilRequest],
        override_bs: Optional[int] = None,
    ) -> list[GenerativeResponse]:
        """
        Generates responses using a greedy decoding strategy until certain ending conditions are met.

        Args:
            requests (list[Request]): list of requests containing the context and ending conditions.
            override_bs (int, optional): Override the batch size for generation. Defaults to None.
                NOTE(review): currently unused — every request is translated one at a time.

        Returns:
            list[GenerativeResponse]: list of generated responses.
        """
        # Tokenize contexts first so the dataset can sort/split by length.
        for request in requests:
            request.tokenized_context = self.tok_encode(request.context)

        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
        results = []

        for _ in tqdm(
            dataset.splits_start_end_iterator(),
            total=dataset.num_dataset_splits,
            desc="Splits",
            position=0,
            disable=False,  # self.disable_tqdm,
        ):
            for r in tqdm(dataset, desc="Batch", position=1, disable=False):
                # Strip the prompt prefix so only the raw phrase is translated.
                context = r.context.replace("French phrase: ", "")
                # TODO: Get src and dest from request
                translation = self.translator.translate(context, src='fr', dest='de')

                result = translation.text
                # No logits/tokens are available from the translation API,
                # so those response fields are left empty.
                cur_response = GenerativeResponse(
                    result=result,
                    logits=None,
                    generated_tokens=[],
                    input_tokens=[],
                )
                results.append(cur_response)

        # Restore the caller's original request ordering.
        return dataset.get_original_order(results)

    @property
    def tokenizer(self):
        # The dummy GPT-2 tokenizer created in ``__init__``.
        return self._tokenizer

    def tok_encode(self, text: str):
        """Encode ``text`` to token ids with the dummy tokenizer."""
        return self.tokenizer.encode(text)

    @property
    def add_special_tokens(self) -> bool:
        return False

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length of the model."""
        return 4096

    def loglikelihood(
        self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences.

        Not supported by this translation-backed model.
        """
        raise NotImplementedError

    def loglikelihood_rolling(
        self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """This function is used to compute the log likelihood of the context for perplexity metrics.

        Not supported by this translation-backed model.
        """
        raise NotImplementedError

    def loglikelihood_single_token(
        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodSingleTokenResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences.

        Not supported by this translation-backed model.
        """
        raise NotImplementedError
```
Author
Parents
Loading