Update LightllmApi and Fix mmlu bug (#738)

* Update LightllmApi and Fix mmlu bug

* checkout mmlu_gen_a484b3.py

---------

Co-authored-by: Leymore <zfz-960727@163.com>
Author: Yang Yong (committed via GitHub)
Date:   2023-12-27 13:49:08 +08:00
Parent: 34561ececb
Commit: 54345c56b7
7 changed files with 202 additions and 14 deletions


@@ -1,7 +1,7 @@
 from mmengine.config import read_base
 
 with read_base():
-    from ..mmlu.mmlu_gen_a484b3 import mmlu_datasets
+    from ..mmlu.mmlu_gen_4d595a import mmlu_datasets
     from ..ceval.ceval_gen_5f30c7 import ceval_datasets
     from ..agieval.agieval_gen_64afd3 import agieval_datasets
     from ..GaokaoBench.GaokaoBench_gen_5cfe9e import GaokaoBench_datasets


@@ -1,7 +1,7 @@
 from mmengine.config import read_base
 
 with read_base():
-    from ..mmlu.mmlu_gen_a484b3 import mmlu_datasets
+    from ..mmlu.mmlu_gen_4d595a import mmlu_datasets
     from ..ceval.ceval_gen_5f30c7 import ceval_datasets
     from ..bbh.bbh_gen_5b92b0 import bbh_datasets
     from ..CLUE_CMRC.CLUE_CMRC_gen_1bd3c8 import CMRC_datasets


@@ -3,7 +3,7 @@ from mmengine.config import read_base
 with read_base():
     from ...ceval.ceval_gen_5f30c7 import ceval_datasets
     from ...agieval.agieval_mixed_2f14ad import agieval_datasets
-    from ...mmlu.mmlu_gen_a484b3 import mmlu_datasets
+    from ...mmlu.mmlu_gen_4d595a import mmlu_datasets
     from ...cmmlu.cmmlu_gen_c13365 import cmmlu_datasets
     from ...GaokaoBench.GaokaoBench_gen_5cfe9e import GaokaoBench_datasets
     from ...ARC_c.ARC_c_ppl_2ef631 import ARC_c_datasets


@@ -1,4 +1,4 @@
 from mmengine.config import read_base
 
 with read_base():
-    from .mmlu_gen_a484b3 import mmlu_datasets  # noqa: F401, F403
+    from .mmlu_gen_4d595a import mmlu_datasets  # noqa: F401, F403


@@ -0,0 +1,124 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import FixKRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.datasets import MMLUDataset
+from opencompass.utils.text_postprocessors import first_capital_postprocess
+
+# None of the mmlu dataset in huggingface is correctly parsed, so we use our own dataset reader
+# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar
+
+mmlu_reader_cfg = dict(
+    input_columns=["input", "A", "B", "C", "D"],
+    output_column="target",
+    train_split='dev')
+
+mmlu_all_sets = [
+    "college_biology",
+    "college_chemistry",
+    "college_computer_science",
+    "college_mathematics",
+    "college_physics",
+    "electrical_engineering",
+    "astronomy",
+    "anatomy",
+    "abstract_algebra",
+    "machine_learning",
+    "clinical_knowledge",
+    "global_facts",
+    "management",
+    "nutrition",
+    "marketing",
+    "professional_accounting",
+    "high_school_geography",
+    "international_law",
+    "moral_scenarios",
+    "computer_security",
+    "high_school_microeconomics",
+    "professional_law",
+    "medical_genetics",
+    "professional_psychology",
+    "jurisprudence",
+    "world_religions",
+    "philosophy",
+    "virology",
+    "high_school_chemistry",
+    "public_relations",
+    "high_school_macroeconomics",
+    "human_sexuality",
+    "elementary_mathematics",
+    "high_school_physics",
+    "high_school_computer_science",
+    "high_school_european_history",
+    "business_ethics",
+    "moral_disputes",
+    "high_school_statistics",
+    "miscellaneous",
+    "formal_logic",
+    "high_school_government_and_politics",
+    "prehistory",
+    "security_studies",
+    "high_school_biology",
+    "logical_fallacies",
+    "high_school_world_history",
+    "professional_medicine",
+    "high_school_mathematics",
+    "college_medicine",
+    "high_school_us_history",
+    "sociology",
+    "econometrics",
+    "high_school_psychology",
+    "human_aging",
+    "us_foreign_policy",
+    "conceptual_physics",
+]
+
+mmlu_datasets = []
+for _name in mmlu_all_sets:
+    _hint = f'There is a single choice question about {_name.replace("_", " ")}. Answer the question by replying A, B, C or D.'
+    mmlu_infer_cfg = dict(
+        ice_template=dict(
+            type=PromptTemplate,
+            template=dict(round=[
+                dict(
+                    role="HUMAN",
+                    prompt=
+                    f"{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: "
+                ),
+                dict(role="BOT", prompt="{target}\n")
+            ]),
+        ),
+        prompt_template=dict(
+            type=PromptTemplate,
+            template=dict(
+                begin="</E>",
+                round=[
+                    dict(
+                        role="HUMAN",
+                        prompt=
+                        f"{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: "
+                    ),
+                ],
+            ),
+            ice_token="</E>",
+        ),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=GenInferencer),
+    )
+
+    mmlu_eval_cfg = dict(
+        evaluator=dict(type=AccEvaluator),
+        pred_postprocessor=dict(type=first_capital_postprocess))
+
+    mmlu_datasets.append(
+        dict(
+            abbr=f"lukaemon_mmlu_{_name}",
+            type=MMLUDataset,
+            path="./data/mmlu/",
+            name=_name,
+            reader_cfg=mmlu_reader_cfg,
+            infer_cfg=mmlu_infer_cfg,
+            eval_cfg=mmlu_eval_cfg,
+        ))
+
+del _name, _hint
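For reference, a top-level config consumes this file through read_base, exactly as the collection diffs above do. A minimal sketch of such an entry point (the filename and relative import path here are illustrative, not part of this commit):

from mmengine.config import read_base

with read_base():
    # Regenerated MMLU config: 5-shot via FixKRetriever, GenInferencer,
    # answers extracted with first_capital_postprocess.
    from .datasets.mmlu.mmlu_gen_4d595a import mmlu_datasets

datasets = [*mmlu_datasets]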


@@ -14,11 +14,12 @@ models = [
         abbr='LightllmAPI',
         type=LightllmAPI,
         url='http://localhost:8080/generate',
-        max_out_len=1024,
-        batch_size=8,
+        max_seq_len=2048,
+        batch_size=32,
         generation_kwargs=dict(
             do_sample=False,
             ignore_eos=False,
+            max_new_tokens=1024
         ),
     ),
 ]
@@ -27,7 +28,7 @@ infer = dict(
     partitioner=dict(type=NaivePartitioner),
     runner=dict(
         type=LocalRunner,
-        max_num_workers=8,
+        max_num_workers=32,
         task=dict(type=OpenICLInferTask),
     ),
 )
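With this change, all sampling controls, including max_new_tokens, live in generation_kwargs, which LightllmAPI forwards verbatim as the parameters field of each request (see the API diff below). A minimal sketch of the equivalent raw call, assuming a lightllm server listening on the URL above; the prompt string is illustrative:

import json
import requests

# Mirrors the payload LightllmAPI now builds: generation_kwargs is sent
# as-is under "parameters", so max_new_tokens comes from the model config.
generation_kwargs = dict(do_sample=False, ignore_eos=False, max_new_tokens=1024)
data = dict(inputs='Question: 1 + 1 = ?\nAnswer: ', parameters=generation_kwargs)
raw_response = requests.post('http://localhost:8080/generate',
                             headers={'content-type': 'application/json'},
                             data=json.dumps(data))
print(raw_response.json())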


@@ -2,6 +2,7 @@ import json
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional
 
+import numpy as np
 import requests
 
 from opencompass.registry import MODELS
@@ -32,8 +33,8 @@ class LightllmAPI(BaseAPIModel):
                          generation_kwargs=generation_kwargs)
         self.logger = get_logger()
         self.url = url
-        self.do_sample = self.generation_kwargs.get('do_sample', False)
-        self.ignore_eos = self.generation_kwargs.get('ignore_eos', False)
+        self.generation_kwargs = generation_kwargs
+        self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
 
     def generate(self, inputs: List[str], max_out_len: int,
                  **kwargs) -> List[str]:
@@ -52,7 +53,7 @@ class LightllmAPI(BaseAPIModel):
         with ThreadPoolExecutor() as executor:
             results = list(
                 executor.map(self._generate, inputs,
-                             [max_out_len] * len(inputs)))
+                             [self.max_out_len] * len(inputs)))
         return results
 
     def _generate(self, input: str, max_out_len: int) -> str:
@@ -61,10 +62,7 @@ class LightllmAPI(BaseAPIModel):
             self.wait()
             header = {'content-type': 'application/json'}
             try:
-                data = dict(inputs=input,
-                            parameters=dict(do_sample=self.do_sample,
-                                            ignore_eos=self.ignore_eos,
-                                            max_new_tokens=max_out_len))
+                data = dict(inputs=input, parameters=self.generation_kwargs)
                 raw_response = requests.post(self.url,
                                              headers=header,
                                              data=json.dumps(data))
@@ -85,3 +83,68 @@ class LightllmAPI(BaseAPIModel):
         raise RuntimeError('Calling LightllmAPI failed after retrying for '
                            f'{max_num_retries} times. Check the logs for '
                            'details.')
+
+    def get_ppl(self, inputs: List[str], max_out_len: int,
+                **kwargs) -> List[float]:
+        """Compute the token-level cross-entropy loss of each input.
+
+        Args:
+            inputs (List[str]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[float]: A list of mean negative log-probabilities.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._get_ppl, inputs,
+                             [self.max_out_len] * len(inputs)))
+        return np.array(results)
+
+    def _get_ppl(self, input: str, max_out_len: int) -> float:
+        max_num_retries = 0
+        if max_out_len is None:
+            max_out_len = 1
+        while max_num_retries < self.retry:
+            self.wait()
+            header = {'content-type': 'application/json'}
+            try:
+                data = dict(inputs=input, parameters=self.generation_kwargs)
+                raw_response = requests.post(self.url,
+                                             headers=header,
+                                             data=json.dumps(data))
+            except requests.ConnectionError:
+                self.logger.error('Got connection error, retrying...')
+                continue
+            try:
+                response = raw_response.json()
+                assert ('prompt_token_ids' in response and 'prompt_logprobs'
+                        in response), 'prompt_token_ids and prompt_logprobs \
+                    must be in the output. \
+                    Please consider adding \
+                    --return_all_prompt_logprobs argument \
+                    when starting your lightllm service.'
+                prompt_token_ids = response['prompt_token_ids'][1:]
+                prompt_logprobs = [
+                    item[1] for item in response['prompt_logprobs']
+                ]
+                logprobs = [
+                    item[str(token_id)] for token_id, item in zip(
+                        prompt_token_ids, prompt_logprobs)
+                ]
+                if len(logprobs) == 0:
+                    return 0.0
+                ce_loss = -sum(logprobs) / len(logprobs)
+                return ce_loss
+            except requests.JSONDecodeError:
+                self.logger.error('JsonDecode error, got',
+                                  str(raw_response.content))
+            max_num_retries += 1
+        raise RuntimeError('Calling LightllmAPI failed after retrying for '
+                           f'{max_num_retries} times. Check the logs for '
+                           'details.')
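The score returned per input is the mean negative log-probability of the prompt tokens, which lightllm exposes when the server is started with --return_all_prompt_logprobs. A self-contained sketch of the same arithmetic on a mocked response (token ids and logprob values are made up for illustration):

# Mocked lightllm response; real values come from the /generate endpoint.
response = {
    'prompt_token_ids': [1, 3923, 374],          # first id is the BOS token
    'prompt_logprobs': [(0, {'3923': -2.1}), (1, {'374': -0.4})],
}

# Drop the BOS token: it has no preceding context, hence no logprob entry.
prompt_token_ids = response['prompt_token_ids'][1:]
prompt_logprobs = [item[1] for item in response['prompt_logprobs']]

# Look up each token's logprob by its (stringified) token id.
logprobs = [item[str(token_id)]
            for token_id, item in zip(prompt_token_ids, prompt_logprobs)]
ce_loss = -sum(logprobs) / len(logprobs)  # mean negative log-likelihood
print(ce_loss)  # -> 1.25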