diff --git a/configs/datasets/collections/chat_medium.py b/configs/datasets/collections/chat_medium.py
index 909cd21f..557d1455 100644
--- a/configs/datasets/collections/chat_medium.py
+++ b/configs/datasets/collections/chat_medium.py
@@ -1,7 +1,7 @@
 from mmengine.config import read_base
 
 with read_base():
-    from ..mmlu.mmlu_gen_a484b3 import mmlu_datasets
+    from ..mmlu.mmlu_gen_4d595a import mmlu_datasets
     from ..ceval.ceval_gen_5f30c7 import ceval_datasets
     from ..agieval.agieval_gen_64afd3 import agieval_datasets
     from ..GaokaoBench.GaokaoBench_gen_5cfe9e import GaokaoBench_datasets
diff --git a/configs/datasets/collections/chat_small.py b/configs/datasets/collections/chat_small.py
index 11ac216c..6314e46c 100644
--- a/configs/datasets/collections/chat_small.py
+++ b/configs/datasets/collections/chat_small.py
@@ -1,7 +1,7 @@
 from mmengine.config import read_base
 
 with read_base():
-    from ..mmlu.mmlu_gen_a484b3 import mmlu_datasets
+    from ..mmlu.mmlu_gen_4d595a import mmlu_datasets
     from ..ceval.ceval_gen_5f30c7 import ceval_datasets
     from ..bbh.bbh_gen_5b92b0 import bbh_datasets
     from ..CLUE_CMRC.CLUE_CMRC_gen_1bd3c8 import CMRC_datasets
diff --git a/configs/datasets/collections/leaderboard/qwen_chat.py b/configs/datasets/collections/leaderboard/qwen_chat.py
index 22e9555c..7d02ee7c 100644
--- a/configs/datasets/collections/leaderboard/qwen_chat.py
+++ b/configs/datasets/collections/leaderboard/qwen_chat.py
@@ -3,7 +3,7 @@ from mmengine.config import read_base
 with read_base():
     from ...ceval.ceval_gen_5f30c7 import ceval_datasets
     from ...agieval.agieval_mixed_2f14ad import agieval_datasets
-    from ...mmlu.mmlu_gen_a484b3 import mmlu_datasets
+    from ...mmlu.mmlu_gen_4d595a import mmlu_datasets
     from ...cmmlu.cmmlu_gen_c13365 import cmmlu_datasets
     from ...GaokaoBench.GaokaoBench_gen_5cfe9e import GaokaoBench_datasets
     from ...ARC_c.ARC_c_ppl_2ef631 import ARC_c_datasets
diff --git a/configs/datasets/mmlu/mmlu_gen.py b/configs/datasets/mmlu/mmlu_gen.py
index 1a13e563..157ee329 100644
--- a/configs/datasets/mmlu/mmlu_gen.py
+++ b/configs/datasets/mmlu/mmlu_gen.py
@@ -1,4 +1,4 @@
 from mmengine.config import read_base
 
 with read_base():
-    from .mmlu_gen_a484b3 import mmlu_datasets  # noqa: F401, F403
+    from .mmlu_gen_4d595a import mmlu_datasets  # noqa: F401, F403
diff --git a/configs/datasets/mmlu/mmlu_gen_4d595a.py b/configs/datasets/mmlu/mmlu_gen_4d595a.py
new file mode 100644
index 00000000..6b81299f
--- /dev/null
+++ b/configs/datasets/mmlu/mmlu_gen_4d595a.py
@@ -0,0 +1,124 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import FixKRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.datasets import MMLUDataset
+from opencompass.utils.text_postprocessors import first_capital_postprocess
+
+# None of the mmlu datasets in huggingface is correctly parsed, so we use our own dataset reader
+# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar
+
+mmlu_reader_cfg = dict(
+    input_columns=["input", "A", "B", "C", "D"],
+    output_column="target",
+    train_split='dev')
+
+mmlu_all_sets = [
+    "college_biology",
+    "college_chemistry",
+    "college_computer_science",
+    "college_mathematics",
+    "college_physics",
+    "electrical_engineering",
+    "astronomy",
+    "anatomy",
+    "abstract_algebra",
+    "machine_learning",
+    "clinical_knowledge",
+    "global_facts",
+    "management",
+    "nutrition",
+    "marketing",
+    "professional_accounting",
+    "high_school_geography",
+    "international_law",
+    "moral_scenarios",
+    "computer_security",
+    "high_school_microeconomics",
+    "professional_law",
+    "medical_genetics",
+    "professional_psychology",
+    "jurisprudence",
+    "world_religions",
+    "philosophy",
+    "virology",
+    "high_school_chemistry",
+    "public_relations",
+    "high_school_macroeconomics",
+    "human_sexuality",
+    "elementary_mathematics",
+    "high_school_physics",
+    "high_school_computer_science",
+    "high_school_european_history",
+    "business_ethics",
+    "moral_disputes",
+    "high_school_statistics",
+    "miscellaneous",
+    "formal_logic",
+    "high_school_government_and_politics",
+    "prehistory",
+    "security_studies",
+    "high_school_biology",
+    "logical_fallacies",
+    "high_school_world_history",
+    "professional_medicine",
+    "high_school_mathematics",
+    "college_medicine",
+    "high_school_us_history",
+    "sociology",
+    "econometrics",
+    "high_school_psychology",
+    "human_aging",
+    "us_foreign_policy",
+    "conceptual_physics",
+]
+
+mmlu_datasets = []
+for _name in mmlu_all_sets:
+    _hint = f'There is a single choice question about {_name.replace("_", " ")}. Answer the question by replying A, B, C or D.'
+    mmlu_infer_cfg = dict(
+        ice_template=dict(
+            type=PromptTemplate,
+            template=dict(round=[
+                dict(
+                    role="HUMAN",
+                    prompt=
+                    f"{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: "
+                ),
+                dict(role="BOT", prompt="{target}\n")
+            ]),
+        ),
+        prompt_template=dict(
+            type=PromptTemplate,
+            template=dict(
+                begin="</E>",
+                round=[
+                    dict(
+                        role="HUMAN",
+                        prompt=
+                        f"{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: "
+                    ),
+                ],
+            ),
+            ice_token="</E>",
+        ),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=GenInferencer),
+    )
+
+    mmlu_eval_cfg = dict(
+        evaluator=dict(type=AccEvaluator),
+        pred_postprocessor=dict(type=first_capital_postprocess))
+
+    mmlu_datasets.append(
+        dict(
+            abbr=f"lukaemon_mmlu_{_name}",
+            type=MMLUDataset,
+            path="./data/mmlu/",
+            name=_name,
+            reader_cfg=mmlu_reader_cfg,
+            infer_cfg=mmlu_infer_cfg,
+            eval_cfg=mmlu_eval_cfg,
+        ))
+
+del _name, _hint
diff --git a/configs/eval_lightllm.py b/configs/eval_lightllm.py
index da7589e9..02847dfc 100644
--- a/configs/eval_lightllm.py
+++ b/configs/eval_lightllm.py
@@ -14,11 +14,12 @@ models = [
         abbr='LightllmAPI',
         type=LightllmAPI,
         url='http://localhost:8080/generate',
-        max_out_len=1024,
-        batch_size=8,
+        max_seq_len=2048,
+        batch_size=32,
         generation_kwargs=dict(
             do_sample=False,
             ignore_eos=False,
+            max_new_tokens=1024
         ),
     ),
 ]
@@ -27,7 +28,7 @@ infer = dict(
     partitioner=dict(type=NaivePartitioner),
     runner=dict(
         type=LocalRunner,
-        max_num_workers=8,
+        max_num_workers=32,
         task=dict(type=OpenICLInferTask),
     ),
 )
diff --git a/opencompass/models/lightllm_api.py b/opencompass/models/lightllm_api.py
index f5e0b719..eb752ae9 100644
--- a/opencompass/models/lightllm_api.py
+++ b/opencompass/models/lightllm_api.py
@@ -2,6 +2,7 @@ import json
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional
 
+import numpy as np
 import requests
 
 from opencompass.registry import MODELS
@@ -32,8 +33,8 @@ class LightllmAPI(BaseAPIModel):
                          generation_kwargs=generation_kwargs)
         self.logger = get_logger()
         self.url = url
-        self.do_sample = self.generation_kwargs.get('do_sample', False)
-        self.ignore_eos = self.generation_kwargs.get('ignore_eos', False)
+        self.generation_kwargs = generation_kwargs
+        self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
 
     def generate(self, inputs: List[str], max_out_len: int,
                  **kwargs) -> List[str]:
@@ -52,7 +53,7 @@ class LightllmAPI(BaseAPIModel):
         with ThreadPoolExecutor() as executor:
             results = list(
                 executor.map(self._generate, inputs,
-                             [max_out_len] * len(inputs)))
+                             [self.max_out_len] * len(inputs)))
         return results
 
     def _generate(self, input: str, max_out_len: int) -> str:
@@ -61,10 +62,7 @@ class LightllmAPI(BaseAPIModel):
             self.wait()
             header = {'content-type': 'application/json'}
             try:
-                data = dict(inputs=input,
-                            parameters=dict(do_sample=self.do_sample,
-                                            ignore_eos=self.ignore_eos,
-                                            max_new_tokens=max_out_len))
+                data = dict(inputs=input, parameters=self.generation_kwargs)
                 raw_response = requests.post(self.url,
                                              headers=header,
                                              data=json.dumps(data))
@@ -85,3 +83,68 @@ class LightllmAPI(BaseAPIModel):
         raise RuntimeError('Calling LightllmAPI failed after retrying for '
                            f'{max_num_retries} times. Check the logs for '
                            'details.')
+
+    def get_ppl(self, inputs: List[str], max_out_len: int,
+                **kwargs) -> List[float]:
+        """Compute perplexity scores given a list of inputs.
+
+        Args:
+            inputs (List[str]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[float]: A list of perplexity scores, one per input.
+        """
+
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._get_ppl, inputs,
+                             [self.max_out_len] * len(inputs)))
+        return np.array(results)
+
+    def _get_ppl(self, input: str, max_out_len: int) -> float:
+        max_num_retries = 0
+        if max_out_len is None:
+            max_out_len = 1
+        while max_num_retries < self.retry:
+            self.wait()
+            header = {'content-type': 'application/json'}
+            try:
+                data = dict(inputs=input, parameters=self.generation_kwargs)
+                raw_response = requests.post(self.url,
+                                             headers=header,
+                                             data=json.dumps(data))
+            except requests.ConnectionError:
+                self.logger.error('Got connection error, retrying...')
+                continue
+            try:
+                response = raw_response.json()
+
+                assert ('prompt_token_ids' in response and 'prompt_logprobs'
+                        in response), 'prompt_token_ids and prompt_logprobs \
+                    must be in the output. \
+                    Please consider adding \
+                    --return_all_prompt_logprobs argument \
+                    when starting your lightllm service.'
+
+                prompt_token_ids = response['prompt_token_ids'][1:]
+                prompt_logprobs = [
+                    item[1] for item in response['prompt_logprobs']
+                ]
+                logprobs = [
+                    item[str(token_id)] for token_id, item in zip(
+                        prompt_token_ids, prompt_logprobs)
+                ]
+                if len(logprobs) == 0:
+                    return 0.0
+                ce_loss = -sum(logprobs) / len(logprobs)
+                return ce_loss
+            except requests.JSONDecodeError:
+                self.logger.error('JsonDecode error, got',
+                                  str(raw_response.content))
+            max_num_retries += 1
+        raise RuntimeError('Calling LightllmAPI failed after retrying for '
+                           f'{max_num_retries} times. Check the logs for '
+                           'details.')
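The scoring step added in `_get_ppl` reduces to averaging the negative log-probabilities that lightllm reports for the prompt tokens, after dropping the first token (which has no logprob). Below is a small self-contained sketch of that arithmetic; the response payload is fabricated for illustration, and in practice the `prompt_token_ids`/`prompt_logprobs` fields only appear when the lightllm service is started with `--return_all_prompt_logprobs`, as the new assertion notes.

```python
# Fabricated response payload with the shape _get_ppl expects.
response = {
    'prompt_token_ids': [1, 319, 1139, 3431],   # first id (e.g. BOS) is dropped
    'prompt_logprobs': [                        # one (pos, {token_id: logprob}) per token
        (1, {'319': -2.3}),
        (2, {'1139': -0.7}),
        (3, {'3431': -1.5}),
    ],
}

prompt_token_ids = response['prompt_token_ids'][1:]
prompt_logprobs = [item[1] for item in response['prompt_logprobs']]
logprobs = [
    item[str(token_id)]
    for token_id, item in zip(prompt_token_ids, prompt_logprobs)
]
# Mean negative log-likelihood per prompt token; PPL-style evaluation
# compares this value across answer candidates.
ce_loss = -sum(logprobs) / len(logprobs) if logprobs else 0.0
print(ce_loss)  # 1.5 for the made-up numbers above
```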