[Feature] Support G-Pass@k and LiveMathBench (#1772)

* support G-Pass@k and LiveMathBench

* fix bugs

* fix comments of GPassKEvaluator

* update saved details of GPassKEvaluator

* update saved details of GPassKEvaluator

* fix eval api configs & update openai_api for ease of debugging

* update huggingface path

* fix method name of G-Pass@k

* fix default value of eval_model_name

* refactor G-Pass@k evaluator

* log generation params for each backend

* fix evaluation resume

* add NotImplementedError
Junnan Liu 2024-12-30 16:59:39 +08:00 committed by GitHub
parent 42b54d6bb8
commit 8e8d4f1c64
10 changed files with 524 additions and 324 deletions

View File

@@ -1,4 +1,4 @@
from mmengine.config import read_base
with read_base():
from .livemathbench_gen_caed8f import livemathbench_datasets # noqa: F401, F403
from .livemathbench_gen_9befbf import livemathbench_datasets # noqa: F401, F403

View File

@@ -0,0 +1,51 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.livemathbench import LiveMathBenchDataset, LiveMathBenchEvaluator
livemathbench_dataset = dict(
type=LiveMathBenchDataset,
path='',
k=16,
replication=3,
dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
dataset_languages=['cn', 'en'],
cot=True,
version='202412',
abbr='LiveMathBench-v202412',
reader_cfg=dict(
input_columns=['prompt'],
output_column='answer'
),
infer_cfg=dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{prompt}'),
]
)
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(
type=GenInferencer,
max_out_len=8192
),
),
eval_cfg=dict(
evaluator=dict(
type=LiveMathBenchEvaluator,
model_name='',
url=[],
use_extract_model=False,
extract_url=[],
extract_model_name='',
k=[4, 8, 16],
replication=3,
thresholds=[0.0, 0.25, 0.5, 0.75, 1.0]
)
)
)
livemathbench_datasets = [livemathbench_dataset]
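The judge fields above are intentionally left empty. A minimal sketch of filling them in at experiment time, assuming an OpenAI-compatible judge endpoint is available (the model name and URL below are placeholders, not part of this commit):

# Hypothetical override; with path='' the loader falls back to the Hugging
# Face dataset 'opencompass/LiveMathBench' (see LiveMathBenchDataset.load
# below).
livemathbench_dataset['eval_cfg']['evaluator'].update(
    model_name='Qwen2.5-72B-Instruct',   # placeholder judge model
    url=['http://127.0.0.1:23333/v1'],   # placeholder endpoint
)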

View File

@@ -1,45 +1,55 @@
import concurrent.futures
import os
import re
import warnings
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from functools import partial
from itertools import product
from typing import Any, Dict, List
from typing import Any, Callable, Dict, List, Union
import jsonlines
import mmengine
import numpy as np
from datasets import Dataset
from datasets import Dataset, load_dataset
from opencompass.datasets.math import MATHAgentEvaluator, math_postprocess_v2
from opencompass.models import OpenAISDK
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.openicl.icl_evaluator import GPassKEvaluator
from opencompass.openicl.icl_inferencer.icl_base_inferencer import \
dump_results_dict
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, MODELS
from opencompass.utils import get_data_path
from ..base import BaseDataset
from .prompts import (EXTRACT_PROMPT_CN, EXTRACT_PROMPT_EN, JUDGE_PROMPT_CN,
JUDGE_PROMPT_EN, PROMPT_CN, PROMPT_EN)
from .utils import extract_judge_label
@LOAD_DATASET.register_module()
class LiveMathBenchDataset(BaseDataset):
@staticmethod
def load(
path: str,
k: int,
n: int,
dataset_splits: List[str] = [
'AIMC', 'CEE', 'CMO', 'MATH500', 'AIME2024'
],
dataset_languages: List[str] = ['cn', 'en'],
) -> List[Dict[str, Any]]:
def load(path: str,
k: Union[int, List[int]],
replication: int,
dataset_splits: List[str] = [
'CNMO',
'CCEE',
'AMC',
'WLPMC',
],
dataset_languages: List[str] = ['cn', 'en'],
cot: bool = True,
version: str = '202412') -> List[Dict[str, Any]]:
dataset = []
dataset_info = {}
path = get_data_path(path)
if path != '':
path = get_data_path(path)
head, tail = os.path.split(path)
path = os.path.join(head, f'{tail}-{version}')
for split, language in product(dataset_splits, dataset_languages):
file_path = os.path.join(path, f'{split}_{language}.jsonl')
if not os.path.exists(file_path):
continue
dataset_info[f'{split}_{language}'] = {
'single-choice': 0,
'multiple-choice': 0,
@@ -52,36 +62,57 @@ class LiveMathBenchDataset(BaseDataset):
'填空': 'fill-in-the-blank',
'问答': 'problem-solving'
}
with jsonlines.open(file_path, 'r') as file:
for example_idx, example in enumerate(file):
dataset_info[f'{split}_{language}'][
example['question_type'] if language == 'en' else
question_type_mapping[example['question_type']]] += 1
prompt = PROMPT_EN if language == 'en' else PROMPT_CN
example.update({
'dataset_key':
f'{split}_{language}_{example_idx}',
'prompt':
prompt.format(question_type=example['question_type'],
question=example['question'] +
('' if 'options' not in example else
' '.join(example['options']))),
'k':
k,
'n':
n
})
for idx in range(k * n):
duplicated_example = deepcopy(example)
duplicated_example.update({'duplicated_idx': idx})
dataset.append(duplicated_example)
if path != '':
file_path = os.path.join(path, f'{split}_{language}.jsonl')
if not os.path.exists(file_path):
continue
examples = []
with jsonlines.open(file_path, 'r') as file:
for example in file:
examples.append(example)
else:
hf_dataset = load_dataset(
'opencompass/LiveMathBench',
f'v{version}_{split}_{language}')['test']
examples = []
for example in hf_dataset:
examples.append(example)
for example_idx, example in enumerate(examples):
dataset_info[f'{split}_{language}'][
example['question_type'] if language == 'en' else
question_type_mapping[example['question_type']]] += 1
prompt = PROMPT_EN if language == 'en' else PROMPT_CN
if not cot:
if language == 'cn':
prompt = prompt.replace(',请逐步推理', '')
else:
prompt = prompt.replace(
', please reasoning step by step', '')
example.update({
'subdivision':
f'{split}_{language}',
'idx':
str(example_idx),
'prompt':
prompt.format(question_type=example['question_type'],
question=example['question'] +
('' if 'options' not in example else
' '.join(example['options']))),
})
max_k = k if isinstance(k, int) else max(k)
for idx in range(max_k * replication):
duplicated_example = deepcopy(example)
duplicated_example.update({'replication_idx': idx})
dataset.append(duplicated_example)
return Dataset.from_list(dataset)
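# With the config above (k=16 in the dataset config, replication=3), every
# question is duplicated max_k * replication = 48 times, one row per
# generation, each carrying its replication_idx.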
@ICL_EVALUATORS.register_module()
class LiveMathBenchEvaluator(BaseEvaluator):
class LiveMathBenchEvaluator(GPassKEvaluator):
api_meta_template = dict(round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
@@ -90,72 +121,125 @@ class LiveMathBenchEvaluator(BaseEvaluator):
def __init__(self,
model_name,
url,
with_postprocess=True,
use_extract_model=False,
post_url=[],
post_model_name='',
**kwargs):
extract_url=[],
extract_model_name='',
k: Union[int, List[int]] = 16,
replication: int = 3,
thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]):
super().__init__(k, replication, thresholds)
if isinstance(url, str):
url = [url]
self.model = [
MODELS.build(
dict(
type=OpenAISDK,
path=model_name,
openai_api_base=url,
key='EMPTY',
query_per_second=128,
meta_template=self.api_meta_template,
temperature=kwargs.get('temperature', 0.001),
max_seq_len=kwargs.get('max_tokens', 16384),
)) for url in url
]
self.with_postprocess = with_postprocess
self.use_extract_model = use_extract_model
self.post_url = post_url
self.post_model_name = post_model_name
def batch_response(self, models: List[OpenAISDK],
inputs: List[str]) -> List[str]:
batch_num = len(models)
batch_size = (len(inputs) + batch_num - 1) // batch_num
result_responses = []
with concurrent.futures.ThreadPoolExecutor(
max_workers=batch_num) as executor:
futures = [
executor.submit(models[i].generate,
inputs[i * batch_size:(i + 1) * batch_size])
for i in range(batch_num)
]
for response in executor.map(lambda f: f.result(), futures):
result_responses.extend(response)
return result_responses
def postprocess(self, questions: List[str], predictions: List[str],
question_types: List[str],
languages: List[str]) -> List[str]:
if self.use_extract_model:
assert len(self.post_url) > 0 and self.post_model_name != ''
post_model = [
if model_name == '' or len(url) == 0:
warnings.warn('Unable to leverage LLM-as-judge and will fall '
'back to the rule-based judge due to incomplete '
'parameters; this may cause performance degradation. '
'Check `model_name` or `url` of the evaluator if you '
'do not want this.')
self.judge_models = []
else:
self.judge_models = [
MODELS.build(
dict(
type=OpenAISDK,
path=self.post_model_name,
path=model_name,
openai_api_base=_url,
key='EMPTY',
query_per_second=2,
retry=5,
meta_template=self.api_meta_template,
temperature=0.0,
max_seq_len=16384,
)) for _url in url
]
self.use_extract_model = use_extract_model
self.extract_url = extract_url
self.extract_model_name = extract_model_name
self.extract_output_handler = LiveMathBenchOutputHandler()
self.judge_output_handler = LiveMathBenchOutputHandler()
def batch_infer(self, models: List[OpenAISDK], inputs: List[str],
completed_indexes: set,
output_handler: 'LiveMathBenchOutputHandler',
postprocess: Callable) -> List[str]:
batch_size = 16
batch_num = (len(inputs) + batch_size - 1) // batch_size
all_indexes = [i for i in range(len(inputs))]
indexes = [i for i in all_indexes if i not in completed_indexes]
inputs = [inputs[i] for i in indexes]
result_responses = []
result_indexes = []
def thread_worker(inputs, max_out_len, temperature, indexes, model):
return model.generate(inputs, max_out_len,
temperature), inputs, indexes
if len(indexes) > 0:
with ThreadPoolExecutor(max_workers=len(models)) as pool:
tasks = [
pool.submit(
partial(thread_worker, model=models[i % len(models)]),
inputs[i * batch_size:(i + 1) * batch_size], 8192, 0.0,
indexes[i * batch_size:(i + 1) * batch_size])
for i in range(batch_num)
]
for completed_task in as_completed(tasks):
responses, current_inputs, indexes = completed_task.result(
)
for input, response, index in zip(current_inputs,
responses, indexes):
output_handler.save(
index,
prompt=input,
response=response,
postprocess_response=postprocess(response))
result_responses.append(postprocess(response))
result_indexes.append(index)
output_handler.write_to_json()
return [
output_handler.output_dict[str(i)]['postprocess_response']
for i in all_indexes
]
def extract(self, questions: List[str], predictions: List[str],
question_types: List[str], languages: List[str]) -> List[str]:
# extract answer by model
if self.use_extract_model:
assert len(self.extract_url) > 0 and self.extract_model_name != ''
extract_models = [
MODELS.build(
dict(
type=OpenAISDK,
path=self.extract_model_name,
openai_api_base=url,
key='EMPTY',
query_per_second=2,
retry=5,
meta_template=self.api_meta_template,
temperature=0.01,
temperature=0.0,
max_seq_len=1024,
)) for url in self.post_url
)) for url in self.extract_url
]
completed_indexes = set()
mmengine.mkdir_or_exist(self.output_dir)
tmp_json_file_path = os.path.join(self.output_dir,
'tmp_extract.json')
self.extract_output_handler.save_file_path = tmp_json_file_path
if os.path.exists(tmp_json_file_path):
tmp_dict = mmengine.load(tmp_json_file_path)
self.extract_output_handler.output_dict = tmp_dict
for index in tmp_dict:
completed_indexes.add(int(index))
input_prompts = []
for question, prediction, question_type, language in zip(
questions, predictions, question_types, languages):
for question, prediction, question_type, language in zip(
questions, predictions, question_types, languages):
prompt = (EXTRACT_PROMPT_EN
if language == 'en' else EXTRACT_PROMPT_CN)
input_prompts.append(
@@ -163,245 +247,125 @@ class LiveMathBenchEvaluator(BaseEvaluator):
response=prediction,
question_type=question_type))
result_responses = self.batch_response(post_model, input_prompts)
results = self.batch_infer(extract_models,
input_prompts,
completed_indexes,
self.extract_output_handler,
postprocess=lambda x: x)
return result_responses
return results
def last_boxed_only_string(string):
idx = string.rfind('\\boxed')
if idx < 0:
idx = string.rfind('\\fbox')
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == '{':
num_left_braces_open += 1
if string[i] == '}':
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx is None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]
return retval
def remove_boxed(s):
left = '\\boxed{'
try:
assert s[:len(left)] == left
assert s[-1] == '}'
return s[len(left):-1]
except Exception:
return None
def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
boxed_str = last_boxed_only_string(pred_str)
if boxed_str is None:
return None
answer = remove_boxed(boxed_str)
if answer is None:
return None
if strip_double_curly_brace:
match = re.match('^\{(.*)\}$', answer) # noqa: W605
if match:
answer = match.group(1)
return answer
predictions = [
extract_boxed_answer(prediction) for prediction in predictions
# extract answer in \\boxed{}
results = [
math_postprocess_v2(prediction) for prediction in predictions
]
return predictions
return results
def extract_boxed_answer(self, text):
match = re.findall(r'\\boxed{(.+?)}', text)
if match:
return match[-1]
return None
def score(self, predictions, references, origin_prompt, test_set):
def judge(self, predictions, references, test_set):
if len(predictions) != len(references):
return {'error': 'preds and refrs have different length'}
raise ValueError('preds and refrs have different length')
completed_indexes = set()
mmengine.mkdir_or_exist(self.output_dir)
tmp_json_file_path = os.path.join(self.output_dir, 'tmp_judge.json')
self.judge_output_handler.save_file_path = tmp_json_file_path
if os.path.exists(tmp_json_file_path):
tmp_dict = mmengine.load(tmp_json_file_path)
self.judge_output_handler.output_dict = tmp_dict
for index in tmp_dict:
completed_indexes.add(int(index))
questions = test_set['question']
question_types = test_set['question_type']
languages = [key.split('_')[1] for key in test_set['dataset_key']]
languages = [key.split('_')[1] for key in test_set['subdivision']]
if self.with_postprocess:
predictions = self.postprocess(questions, predictions,
question_types, languages)
predictions = self.extract(questions, predictions, question_types,
languages)
inputs = []
for prediction, reference, question, language in zip(
predictions, references, questions, languages):
prompt = JUDGE_PROMPT_EN if language == 'en' else JUDGE_PROMPT_CN
inputs.append(
prompt.format(answer=prediction,
gold_answer=reference,
question=question))
result_responses = self.batch_response(self.model, inputs)
results = [
self.extract_boxed_answer(result) == 'yes'
for result in result_responses
]
if len(self.judge_models) > 0:
inputs = []
for prediction, reference, question, language in zip(
predictions, references, questions, languages):
prompt = (JUDGE_PROMPT_EN
if language == 'en' else JUDGE_PROMPT_CN)
inputs.append(
prompt.format(answer=prediction,
gold_answer=reference,
question=question))
K = test_set['k'][0]
N = test_set['n'][0]
key2example = {}
for example, result_response, result, prediction in zip(
test_set, result_responses, results, predictions):
if example['dataset_key'] not in key2example:
key2example[example['dataset_key']] = []
example.update({
'eval_response': result_response,
'prediction': prediction,
'correct': result
})
key2example[example['dataset_key']].append(example)
for key in key2example:
key2example[key] = [
key2example[key][i * K:(i + 1) * K] for i in range(N)
labels = self.batch_infer(
self.judge_models, inputs, completed_indexes,
self.judge_output_handler, lambda x:
(1 if extract_judge_label(x) == 'yes' else 0))
else:
is_equiv = MATHAgentEvaluator(version='v2').is_equiv
labels = [
1 if is_equiv(prediction, reference) else 0
for prediction, reference in zip(predictions, references)
]
return labels
count = []
total_pass_num = []
details = []
all_dataset = set()
for key, examples in key2example.items():
detail = OrderedDict()
detail['question'] = examples[0][0]['question']
detail['answer'] = examples[0][0]['answer']
detail['responses'] = []
detail['dataset'] = '_'.join(key.split('_')[:-1])
all_dataset.add('_'.join(key.split('_')[:-1]))
if_pass_list = []
for single_run_examples in examples:
detail['responses'].append([])
if_pass_list.append([])
for example in single_run_examples:
detail['responses'][-1].append({
'prediction':
example['prediction'],
'eval_response':
example['eval_response']
})
if_pass_list[-1].append(1.0 if example['correct'] else 0.0)
def preprocess(self, predictions, references, test_set):
return self.judge(predictions, references, test_set)
if_pass_list = [
sorted(if_pass, reverse=True) for if_pass in if_pass_list
]
if_pass_list = np.array(if_pass_list)
i = 1
while i <= K:
detail.update({
f'pass-rate@{i}':
if_pass_list[:, :i].mean(axis=1).mean(axis=0).item(),
f'pass-rate@{i}/std':
if_pass_list[:, :i].mean(axis=1).std(axis=0).item(),
f'pass@{i}':
np.ceil(
if_pass_list[:, :i].mean(axis=1)).mean(axis=0).item(),
f'pass@{i}/std':
np.ceil(
if_pass_list[:, :i].mean(axis=1)).std(axis=0).item(),
})
i = i * 2
def group(self, predictions, labels, test_set):
example2replications = {}
for example, label, prediction in zip(test_set, labels, predictions):
example_abbr = f"{example['subdivision']}_{example['idx']}"
if example_abbr not in example2replications:
example2replications[example_abbr] = []
example.update({'prediction': prediction, 'label': label})
example2replications[example_abbr].append(example)
for _, replications in example2replications.items():
assert len(replications) == self.n, \
f'expected {self.n} replications, got {len(replications)}'
return example2replications
for threshold in [0.5, 0.75, 1.0]:
detail.update({
f'{K}-pass@{threshold}':
np.floor(
np.where(
if_pass_list.mean(axis=1) >= threshold, 1.0,
0.0).mean(axis=0))
})
def reduce(self, details) -> Dict[str, Any]:
"""Aggregate the overall metrics.
count.append(np.ones_like(if_pass_list).sum(axis=1))
total_pass_num.append(if_pass_list.sum(axis=1))
Return:
A dict contains overall metrics, like:
{'details': details for each example, 'G-Pass@16': xxx}
"""
g_passk_details = OrderedDict()
g_passk_details['details'] = details
details.append(detail)
all_dataset = set([detail['subdivision'] for detail in details])
detailed_result = OrderedDict()
detailed_result['details'] = details
i = 1
while i <= K:
detailed_result.update({
f'pass-rate@{i}':
100. *
np.mean([detail[f'pass-rate@{i}'] for detail in details]),
f'pass-rate@{i}/std':
100. *
np.mean([detail[f'pass-rate@{i}/std'] for detail in details]),
f'pass@{i}':
100. * np.mean([detail[f'pass@{i}'] for detail in details]),
f'pass@{i}/std':
100. * np.mean([detail[f'pass@{i}/std'] for detail in details])
})
for d in sorted(list(all_dataset)):
detailed_result.update({
f'{d}/pass-rate@{i}':
100. * np.mean([
detail[f'pass-rate@{i}']
for detail in details if detail['dataset'] == d
]),
f'{d}/pass-rate@{i}/std':
100. * np.mean([
detail[f'pass-rate@{i}/std']
for detail in details if detail['dataset'] == d
]),
f'{d}/pass@{i}':
100. * np.mean([
detail[f'pass@{i}']
for detail in details if detail['dataset'] == d
]),
f'{d}/pass@{i}/std':
100. * np.mean([
detail[f'pass@{i}/std']
for detail in details if detail['dataset'] == d
for k in self.k:
for subdivision in sorted(list(all_dataset)):
for threshold in self.thresholds:
g_passk_details[
f'{subdivision}/G-Pass@{k}_{threshold}'] = \
100. * np.mean(
[
detail[f'G-Pass@{k}_{threshold}']
for detail in details
if detail['subdivision'] == subdivision
])
g_passk_details[f'{subdivision}/mG-Pass@{k}'] = 100. * np.mean(
[
detail[f'mG-Pass@{k}'] for detail in details
if detail['subdivision'] == subdivision
])
})
i = i * 2
for threshold in [0.5, 0.75, 1.0]:
detailed_result.update({
f'{K}-pass@{threshold}':
100. * np.mean([
detail[f'{K}-pass@{threshold}'] for detail in details
])
})
detailed_result.update({
f'{K}-pass@{threshold}/std':
100. * np.mean([
detail[f'{K}-pass@{threshold}'] for detail in details
])
})
for d in sorted(list(all_dataset)):
for threshold in self.thresholds:
g_passk_details[f'G-Pass@{k}_{threshold}'] = 100. * np.mean(
[detail[f'G-Pass@{k}_{threshold}'] for detail in details])
g_passk_details[f'mG-Pass@{k}'] = 100. * np.mean(
[detail[f'mG-Pass@{k}'] for detail in details])
for threshold in [0.5, 0.75, 1.0]:
detailed_result.update({
f'{d}/{K}-pass@{threshold}':
100. * np.mean([
detail[f'{K}-pass@{threshold}']
for detail in details if detail['dataset'] == d
])
})
detailed_result.update({
f'{d}/{K}-pass@{threshold}/std':
100. * np.mean([
detail[f'{K}-pass@{threshold}']
for detail in details if detail['dataset'] == d
])
})
return g_passk_details
return detailed_result
class LiveMathBenchOutputHandler:
output_dict = {}
save_file_path = ''
def write_to_json(self):
"""Dump the result to a json file."""
dump_results_dict(self.output_dict, self.save_file_path)
def save(self, idx, **kwargs):
self.output_dict[str(idx)] = kwargs
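The temporary files written by this handler are what the evaluation-resume fix relies on. A sketch of their layout, with purely illustrative contents:

# tmp_extract.json / tmp_judge.json as dumped by write_to_json:
# {
#     "0": {"prompt": "...", "response": "...", "postprocess_response": "..."},
#     "1": {"prompt": "...", "response": "...", "postprocess_response": "..."}
# }
# On a rerun, extract()/judge() reload this dict, add every integer key to
# completed_indexes, and batch_infer() only queries the indexes that are
# still missing before reading all answers back from output_dict.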

View File

@@ -0,0 +1,10 @@
import re
def extract_judge_label(text):
if isinstance(text, str):
match = re.findall(r'\\boxed{(.+?)}', text)
if match:
return match[-1]
return None
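A quick usage sketch of the helper above; the judge prompts ask the model to wrap its verdict in \boxed{}, and the last boxed span is taken as the label (example strings are illustrative):

# extract_judge_label('The answers agree, so \\boxed{yes}')  -> 'yes'
# extract_judge_label('Values differ: \\boxed{no}')          -> 'no'
# extract_judge_label(None)                                  -> None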

View File

@@ -106,7 +106,7 @@ def _format_with_fast_chat_template(inputs: List[str], name: str='vicuna'):
elif item['role'] == 'system':
continue
else:
raise ValueError(f'Unknown role {item["role"]}')
raise ValueError(f"Unknown role {item['role']}")
template.append_message(template.roles[1], None)
outputs.append(template.get_prompt())
return outputs
@@ -474,6 +474,8 @@ class HuggingFacewithChatTemplate(BaseModel):
if min_out_len is not None:
generation_kwargs['min_new_tokens'] = min_out_len
generation_kwargs['pad_token_id'] = self.tokenizer.pad_token_id
self.logger.info('Generation Args of Huggingface: ')
self.logger.info(generation_kwargs)
# step-2: conduct model forward to generate output
outputs = self.model.generate(**tokens, **generation_kwargs)

View File

@@ -516,10 +516,13 @@ class OpenAISDK(OpenAI):
# support multiple api_base for acceleration
if isinstance(openai_api_base, List):
openai_api_base = random.choice(openai_api_base)
self.openai_api_base = random.choice(openai_api_base)
else:
self.openai_api_base = openai_api_base
if self.proxy_url is None:
self.openai_client = OpenAI(base_url=openai_api_base, api_key=key)
self.openai_client = OpenAI(base_url=self.openai_api_base,
api_key=key)
else:
proxies = {
'http://': self.proxy_url,
@@ -527,7 +530,7 @@ class OpenAISDK(OpenAI):
}
self.openai_client = OpenAI(
base_url=openai_api_base,
base_url=self.openai_api_base,
api_key=key,
http_client=httpx.Client(proxies=proxies))
if self.verbose:
@@ -617,8 +620,8 @@ class OpenAISDK(OpenAI):
'Successfully get response from OpenAI API')
try:
self.logger.info(responses)
except Exception as e: # noqa F841
pass
except Exception:
pass # noqa F841
if not responses.choices:
self.logger.error(
'Response is empty, it is an internal server error \
@@ -635,13 +638,18 @@ class OpenAISDK(OpenAI):
if (status_code is not None
and status_code in self.status_code_mappings):
error_message = self.status_code_mappings[status_code]
self.logger.error(
f'error occurs at {self.openai_api_base}')
self.logger.info(f'Status Code: {status_code}, \n'
f'Original Error Message: {e}, \n'
f'Return Message: {error_message} ')
return error_message
else:
self.logger.error(
f'error occurs at {self.openai_api_base}')
self.logger.error(e)
except Exception as e:
self.logger.error(f'error occurs at {self.openai_api_base}')
self.logger.error(e)
num_retries += 1
raise RuntimeError('Calling OpenAI API failed after retrying for '

View File

@@ -128,10 +128,7 @@ class TurboMindModelwithChatTemplate(BaseModel):
gen_config['max_new_tokens'] = max_out_len
if min_out_len is not None:
gen_config['min_new_tokens'] = min_out_len
if do_sample or ('do_sample' in self.gen_config and self.gen_config['do_sample']):
gen_config['top_k'] = 40
gen_config['temperature'] = temperature
else:
if not (do_sample or ('do_sample' in self.gen_config and self.gen_config['do_sample'])):
if self.version_info >= (0, 6, 0):
gen_config['do_sample'] = False
else:
@@ -140,6 +137,8 @@ class TurboMindModelwithChatTemplate(BaseModel):
from lmdeploy import GenerationConfig
gen_config = {k: v for k, v in gen_config.items() if hasattr(GenerationConfig, k)}
gen_config = GenerationConfig(**gen_config)
self.logger.info('Generation Config of LMdeploy: ')
self.logger.info(gen_config)
results = []
outputs = self.pipe(messages, gen_config=gen_config, do_preprocess=False)

View File

@@ -108,6 +108,8 @@ class VLLMwithChatTemplate(BaseModel):
sampling_kwargs.update(self.generation_kwargs)
sampling_kwargs.update(kwargs)
sampling_kwargs = SamplingParams(**sampling_kwargs)
self.logger.info('Sampling Params of vLLM: ')
self.logger.info(sampling_kwargs)
outputs = self.model.generate(messages, sampling_kwargs)

View File

@@ -4,6 +4,7 @@ from .icl_base_evaluator import BaseEvaluator # noqa
from .icl_bpc_evaluator import BPCEvaluator # noqa
from .icl_circular_evaluator import CircularEvaluator # noqa
from .icl_em_evaluator import EMEvaluator # noqa
from .icl_gpassk_evaluator import GPassKEvaluator # noqa
from .icl_hf_evaluator import * # noqa
from .icl_jieba_rouge_evaluator import JiebaRougeEvaluator # noqa
from .icl_misc_evaluator import AverageInferencePPLEvaluator # noqa

View File

@@ -0,0 +1,163 @@
from abc import abstractmethod
from typing import Any, Dict, List, Union
import numpy as np
from scipy.stats import hypergeom
from opencompass.registry import ICL_EVALUATORS
from .icl_base_evaluator import BaseEvaluator
def compute_pass_at_k(n, c, k):
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
def _compute_g_pass_at_k(n, c, k, m):
if m > min(c, k) or k > n or c < 0 or n <= 0 or m < 0:
return 0.0
return hypergeom.sf(m - 1, n, c, k)
def compute_g_pass_at_k(n, c, k, t):
m = max(int(np.ceil(k * t)), 1)
return _compute_g_pass_at_k(n, c, k, m)
def compute_mg_pass_at_k(n, c, k):
l, r = int(np.ceil(k * 0.5)), k
mg_pass_at_k = 0.0
for i in range(l + 1, r + 1):
mg_pass_at_k += _compute_g_pass_at_k(n, c, k, i)
mg_pass_at_k = 2 * mg_pass_at_k / k
return mg_pass_at_k
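# Reading the estimators above with n total generations per question, c of
# them judged correct, and X ~ Hypergeom(n, c, k):
#   G-Pass@k_t = P(X >= m), m = max(ceil(k * t), 1)   (hypergeom.sf(m - 1, n, c, k))
#   mG-Pass@k  = (2 / k) * sum of P(X >= i) for i in (ceil(k / 2), k]
# Illustrative calls (arbitrary numbers, not from any benchmark run):
#   compute_g_pass_at_k(n=48, c=24, k=16, t=0.5)  # P(at least 8 of 16 correct)
#   compute_mg_pass_at_k(n=48, c=24, k=16)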
@ICL_EVALUATORS.register_module()
class GPassKEvaluator(BaseEvaluator):
"""Evaluator for computing the G-Pass@k Metric.
This evaluator performs the following steps:
1. Invokes task-specific `preprocess` on predictions to
assign a consistency label to each prediction and its
corresponding reference.
2. Calculates metrics for each input example based on
these labels.
3. Aggregates the overall metrics through a task-specific
`reduce`.
Args:
k (int or list of int): Number of predictions to be
considered in G-Pass@k. It can be a single integer
(e.g., `k=16` computes G-Pass@16) or a list of
integers (e.g., `[4, 8, 16]` computes G-Pass@4,
G-Pass@8, and G-Pass@16).
replication (int): Controls the number of generations
used to estimate G-Pass@k. The total number of
generations is determined by multiplying the
maximum of `k` with `replication`. This parameter
should be a single integer.
thresholds (list of float): A list of floating-point
numbers that define the thresholds for the G-Pass@k
metric.
"""
def __init__(
self,
k: Union[int, List[int]] = 16,
replication: int = 3,
thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]) -> None:
super().__init__()
if isinstance(k, int):
k = [k]
self.k = k
self.replication = replication
self.n = max(k) * replication
self.thresholds = thresholds
@property
def output_dir(self):
# please see opencompass/opencompass/tasks/openicl_eval.py Line 197-200
return self._out_dir
@abstractmethod
def preprocess(self, predictions, references, test_set) -> None:
"""Perform operations on predictions before computing metrics, for
example, do answer_extraction and model_judge in mathematical reasoning
task.
Return:
labels: A list contains the label which indicates whether
prediction is consistency with reference at each position.
"""
raise NotImplementedError
@abstractmethod
def group(self, predictions, labels, test_set) -> Dict[str, Any]:
"""Group the predictions and references.
Return:
A dict contains the grouped predictions and references.
"""
raise NotImplementedError
@abstractmethod
def reduce(self, details) -> Dict[str, Any]:
"""Aggregate the overall metrics.
Return:
A dict contains overall metrics, like:
{'details': details for each example, 'G-Pass@16': xxx}
"""
raise NotImplementedError
def score(self, predictions, references, test_set) -> Dict[str, Any]:
"""Compute G-Pass@k metrics.
Return:
A dict contains metrics for each dataset sample and
overall metrics reduced by `self.reduce`, like:
{'details': details for each example, 'G-Pass@16': xxx}
"""
labels = self.preprocess(predictions, references, test_set)
grouped_examples = self.group(predictions, labels, test_set)
details = []
total_pass_num, count = 0, 0
for example_abbr, examples in grouped_examples.items():
detail = {
k: v
for k, v in examples[0].items()
if k not in ['prediction', 'label']
}
detail.update({
'predictions': [{
'prediction': example['prediction'],
'label': example['label']
} for example in examples],
})
current_example_labels = [e['label'] for e in examples]
c = int(np.sum(current_example_labels))
for k in self.k:
for threshold in self.thresholds:
detail[f'G-Pass@{k}_{threshold}'] = compute_g_pass_at_k(
n=self.n, c=c, k=k, t=threshold)
detail[f'mG-Pass@{k}'] = compute_mg_pass_at_k(n=self.n,
c=c,
k=k)
count += self.n
total_pass_num += c
details.append(detail)
return self.reduce(details)
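To make the contract of the three hooks concrete, here is a minimal sketch of a subclass for a plain exact-match task. The class name, the assumption that the test set provides an `idx` column, and the exact-match labelling are all illustrative, not part of this commit:

from typing import Any, Dict, List

import numpy as np

from opencompass.openicl.icl_evaluator import GPassKEvaluator


class ExactMatchGPassK(GPassKEvaluator):
    """Toy subclass: label by exact string match, group by example idx."""

    def preprocess(self, predictions, references, test_set) -> List[int]:
        # One 0/1 consistency label per replicated generation.
        return [1 if str(p).strip() == str(r).strip() else 0
                for p, r in zip(predictions, references)]

    def group(self, predictions, labels, test_set) -> Dict[str, Any]:
        # Collect the n = max(k) * replication generations of each example;
        # assumes the dataset was replicated accordingly and has an 'idx' column.
        groups: Dict[str, list] = {}
        for example, label, prediction in zip(test_set, labels, predictions):
            example = dict(example)
            example.update({'prediction': prediction, 'label': label})
            groups.setdefault(str(example['idx']), []).append(example)
        return groups

    def reduce(self, details) -> Dict[str, Any]:
        # Average the per-example G-Pass@k_t / mG-Pass@k values computed in score().
        result: Dict[str, Any] = {'details': details}
        for k in self.k:
            for t in self.thresholds:
                key = f'G-Pass@{k}_{t}'
                result[key] = 100. * np.mean([d[key] for d in details])
            result[f'mG-Pass@{k}'] = 100. * np.mean(
                [d[f'mG-Pass@{k}'] for d in details])
        return result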