Added custom model inference. (#437)
Enables the evaluation of any system under the user's control: you point lighteval at a Python file that defines a `LightevalModel` subclass. Fixes [Issue 430](https://github.com/huggingface/lighteval/issues/430).

Try it with:
```bash
python -m lighteval custom google-translate /path/to/google_translate_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10
```
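The positional arguments are the model name to record in the results, the path to the model definition file (shown below), and the task specification, which follows lighteval's usual `suite|task|few_shot|truncate_few_shot` format.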
`google_translate_model.py`:
```python
import logging
from typing import Optional

from tqdm import tqdm
from transformers import AutoTokenizer

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo
from lighteval.models.model_output import (
    GenerativeResponse,
    LoglikelihoodResponse,
    LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
    GreedyUntilRequest,
    LoglikelihoodRequest,
    LoglikelihoodRollingRequest,
    LoglikelihoodSingleTokenRequest,
)

logger = logging.getLogger(__name__)


class GoogleTranslateClient(LightevalModel):
    def __init__(self, config, env_config) -> None:
        self.model = config.model
        self.model_definition_file_path = config.model_definition_file_path

        self.model_info = ModelInfo(
            model_name=config.model,
            model_sha="",
            model_dtype=None,
            model_size="",
        )

        # Use a dummy tokenizer for compatibility with the LightevalModel interface
        self._tokenizer = AutoTokenizer.from_pretrained("gpt2")

        import httpcore

        # Needed to work around a googletrans bug:
        # https://stackoverflow.com/questions/72796594/attributeerror-module-httpcore-has-no-attribute-synchttptransport#comment136664963_77334618
        setattr(httpcore, "SyncHTTPTransport", "AsyncHTTPProxy")
        from googletrans import Translator

        self.translator = Translator()

    def greedy_until(
        self,
        requests: list[GreedyUntilRequest],
        override_bs: Optional[int] = None,
    ) -> list[GenerativeResponse]:
        """
        Generates responses using a greedy decoding strategy until certain ending conditions are met.

        Args:
            requests (list[GreedyUntilRequest]): list of requests containing the context and ending conditions.
            override_bs (int, optional): Override the batch size for generation. Defaults to None.

        Returns:
            list[GenerativeResponse]: list of generated responses.
        """
        for request in requests:
            request.tokenized_context = self.tok_encode(request.context)

        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
        results = []

        for _ in tqdm(
            dataset.splits_start_end_iterator(),
            total=dataset.num_dataset_splits,
            desc="Splits",
            position=0,
            disable=False,  # self.disable_tqdm,
        ):
            for r in tqdm(dataset, desc="Batch", position=1, disable=False):
                # Strip the prompt prefix so only the source text is sent to the API
                context = r.context.replace("French phrase: ", "")
                # TODO: Get src and dest from request
                translation = self.translator.translate(context, src="fr", dest="de")

                result = translation.text
                cur_response = GenerativeResponse(
                    result=result,
                    logits=None,
                    generated_tokens=[],
                    input_tokens=[],
                )
                results.append(cur_response)

        return dataset.get_original_order(results)

    @property
    def tokenizer(self):
        return self._tokenizer

    def tok_encode(self, text: str):
        return self.tokenizer.encode(text)

    @property
    def add_special_tokens(self) -> bool:
        return False

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length of the model."""
        return 4096

    def loglikelihood(
        self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError

    def loglikelihood_rolling(
        self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """This function is used to compute the log likelihood of the context for perplexity metrics."""
        raise NotImplementedError

    def loglikelihood_single_token(
        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodSingleTokenResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError
```
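To sanity-check the client outside of a full evaluation run, a minimal sketch like the one below can help; the `SimpleNamespace` config is a stand-in for the config object lighteval builds from the CLI arguments, not part of the lighteval API:

```python
# Smoke test for GoogleTranslateClient, assuming the class above lives in
# google_translate_model.py next to this script.
from types import SimpleNamespace

from google_translate_model import GoogleTranslateClient

# Stand-in for the config lighteval normally constructs from the CLI arguments
config = SimpleNamespace(
    model="google-translate",
    model_definition_file_path="google_translate_model.py",
)
client = GoogleTranslateClient(config, env_config=None)

# Call the translator directly, bypassing the lighteval request objects
print(client.translator.translate("Bonjour le monde", src="fr", dest="de").text)
```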