Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Support G-Pass@k and LiveMathBench (#1772)
* support G-Pass@k and LiveMathBench
* fix bugs
* fix comments of GPassKEvaluator
* update saved details of GPassKEvaluator
* fix eval api configs & update openai_api for ease of debugging
* update huggingface path
* fix method name of G-Pass@k
* fix default value of eval_model_name
* refactor G-Pass@k evaluator
* log generation params for each backend
* fix evaluation resume
* add NotImplementedError
This commit is contained in:
parent 42b54d6bb8
commit 8e8d4f1c64
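For readers scanning the diff below: G-Pass@k with threshold t is the probability that, when k generations are drawn from the n sampled ones (of which c were judged correct), at least ceil(t * k) of them are correct; the new icl_gpassk_evaluator.py computes it with the hypergeometric survival function. A small self-contained illustration with made-up numbers, where only the formula mirrors the added code:

import numpy as np
from scipy.stats import hypergeom


def g_pass_at_k(n, c, k, t):
    # n sampled generations, c judged correct, threshold t in [0, 1]
    m = max(int(np.ceil(k * t)), 1)
    if m > min(c, k) or k > n or c < 0 or n <= 0:
        return 0.0
    return hypergeom.sf(m - 1, n, c, k)  # P(at least m of k draws are correct)


# e.g. 48 generations per problem (max(k)=16 times replication=3), 30 judged correct
print(round(g_pass_at_k(n=48, c=30, k=16, t=0.5), 4))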
@@ -1,4 +1,4 @@
 from mmengine.config import read_base
 
 with read_base():
-    from .livemathbench_gen_caed8f import livemathbench_datasets  # noqa: F401, F403
+    from .livemathbench_gen_9befbf import livemathbench_datasets  # noqa: F401, F403
@@ -0,0 +1,51 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+
+from opencompass.datasets.livemathbench import LiveMathBenchDataset, LiveMathBenchEvaluator
+
+
+livemathbench_dataset = dict(
+    type=LiveMathBenchDataset,
+    path='',
+    k=16,
+    replication=3,
+    dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
+    dataset_languages=['cn', 'en'],
+    cot=True,
+    version='202412',
+    abbr='LiveMathBench-v202412',
+    reader_cfg=dict(
+        input_columns=['prompt'],
+        output_column='answer'
+    ),
+    infer_cfg=dict(
+        prompt_template=dict(
+            type=PromptTemplate,
+            template=dict(
+                round=[
+                    dict(role='HUMAN', prompt='{prompt}'),
+                ]
+            )
+        ),
+        retriever=dict(type=ZeroRetriever),
+        inferencer=dict(
+            type=GenInferencer,
+            max_out_len=8192
+        ),
+    ),
+    eval_cfg=dict(
+        evaluator=dict(
+            type=LiveMathBenchEvaluator,
+            model_name='',
+            url=[],
+            use_extract_model=False,
+            extract_url=[],
+            extract_model_name='',
+            k=[4, 8, 16],
+            replication=3,
+            thresholds=[0.0, 0.25, 0.5, 0.75, 1.0]
+        )
+    )
+)
+livemathbench_datasets = [livemathbench_dataset]
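For context, a rough sketch of how this dataset config is typically consumed from a top-level evaluation config; the relative import path, judge model name and endpoint below are placeholders, not part of this commit:

# sketch only; adjust the import path to your local config layout
from mmengine.config import read_base

with read_base():
    from .livemathbench_gen_9befbf import livemathbench_datasets  # noqa: F401, F403

datasets = [*livemathbench_datasets]

# the judge / extractor endpoints are filled in at this level
datasets[0]['eval_cfg']['evaluator'].update(
    model_name='Qwen2.5-72B-Instruct',   # placeholder judge model
    url=['http://127.0.0.1:23333/v1'],   # placeholder OpenAI-compatible endpoint
)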
@@ -1,45 +1,55 @@
-import concurrent.futures
 import os
-import re
+import warnings
 from collections import OrderedDict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from copy import deepcopy
+from functools import partial
 from itertools import product
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List, Union
 
 import jsonlines
+import mmengine
 import numpy as np
-from datasets import Dataset
+from datasets import Dataset, load_dataset
 
+from opencompass.datasets.math import MATHAgentEvaluator, math_postprocess_v2
 from opencompass.models import OpenAISDK
-from opencompass.openicl.icl_evaluator import BaseEvaluator
+from opencompass.openicl.icl_evaluator import GPassKEvaluator
+from opencompass.openicl.icl_inferencer.icl_base_inferencer import \
+    dump_results_dict
 from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, MODELS
 from opencompass.utils import get_data_path
 
 from ..base import BaseDataset
 from .prompts import (EXTRACT_PROMPT_CN, EXTRACT_PROMPT_EN, JUDGE_PROMPT_CN,
                       JUDGE_PROMPT_EN, PROMPT_CN, PROMPT_EN)
+from .utils import extract_judge_label
 
 
 @LOAD_DATASET.register_module()
 class LiveMathBenchDataset(BaseDataset):
 
     @staticmethod
-    def load(
-        path: str,
-        k: int,
-        n: int,
-        dataset_splits: List[str] = [
-            'AIMC', 'CEE', 'CMO', 'MATH500', 'AIME2024'
-        ],
-        dataset_languages: List[str] = ['cn', 'en'],
-    ) -> List[Dict[str, Any]]:
+    def load(path: str,
+             k: Union[int, List[int]],
+             replication: int,
+             dataset_splits: List[str] = [
+                 'CNMO',
+                 'CCEE',
+                 'AMC',
+                 'WLPMC',
+             ],
+             dataset_languages: List[str] = ['cn', 'en'],
+             cot: bool = True,
+             version: str = '202412') -> List[Dict[str, Any]]:
         dataset = []
         dataset_info = {}
-        path = get_data_path(path)
+        if path != '':
+            path = get_data_path(path)
+            head, tail = os.path.split(path)
+            path = os.path.join(head, f'{tail}-{version}')
         for split, language in product(dataset_splits, dataset_languages):
-            file_path = os.path.join(path, f'{split}_{language}.jsonl')
-            if not os.path.exists(file_path):
-                continue
             dataset_info[f'{split}_{language}'] = {
                 'single-choice': 0,
                 'multiple-choice': 0,
@@ -52,36 +62,57 @@ class LiveMathBenchDataset(BaseDataset):
                 '填空': 'fill-in-the-blank',
                 '问答': 'problem-solving'
             }
-            with jsonlines.open(file_path, 'r') as file:
-                for example_idx, example in enumerate(file):
-                    dataset_info[f'{split}_{language}'][
-                        example['question_type'] if language == 'en' else
-                        question_type_mapping[example['question_type']]] += 1
-
-                    prompt = PROMPT_EN if language == 'en' else PROMPT_CN
-                    example.update({
-                        'dataset_key':
-                        f'{split}_{language}_{example_idx}',
-                        'prompt':
-                        prompt.format(question_type=example['question_type'],
-                                      question=example['question'] +
-                                      ('' if 'options' not in example else
-                                       ' '.join(example['options']))),
-                        'k':
-                        k,
-                        'n':
-                        n
-                    })
-                    for idx in range(k * n):
-                        duplicated_example = deepcopy(example)
-                        duplicated_example.update({'duplicated_idx': idx})
-                        dataset.append(duplicated_example)
+            if path != '':
+                file_path = os.path.join(path, f'{split}_{language}.jsonl')
+                if not os.path.exists(file_path):
+                    continue
+                examples = []
+                with jsonlines.open(file_path, 'r') as file:
+                    for example in file:
+                        examples.append(example)
+            else:
+                hf_dataset = load_dataset(
+                    'opencompass/LiveMathBench',
+                    f'v{version}_{split}_{language}')['test']
+                examples = []
+                for example in hf_dataset:
+                    examples.append(example)
+
+            for example_idx, example in enumerate(examples):
+                dataset_info[f'{split}_{language}'][
+                    example['question_type'] if language == 'en' else
+                    question_type_mapping[example['question_type']]] += 1
+
+                prompt = PROMPT_EN if language == 'en' else PROMPT_CN
+                if not cot:
+                    if language == 'cn':
+                        prompt = prompt.replace(',请逐步推理', '')
+                    else:
+                        prompt = prompt.replace(
+                            ', please reasoning step by step', '')
+                example.update({
+                    'subdivision':
+                    f'{split}_{language}',
+                    'idx':
+                    str(example_idx),
+                    'prompt':
+                    prompt.format(question_type=example['question_type'],
+                                  question=example['question'] +
+                                  ('' if 'options' not in example else
+                                   ' '.join(example['options']))),
+                })
+                max_k = k if isinstance(k, int) else max(k)
+                for idx in range(max_k * replication):
+                    duplicated_example = deepcopy(example)
+                    duplicated_example.update({'replication_idx': idx})
+                    dataset.append(duplicated_example)
 
         return Dataset.from_list(dataset)
 
 
 @ICL_EVALUATORS.register_module()
-class LiveMathBenchEvaluator(BaseEvaluator):
+class LiveMathBenchEvaluator(GPassKEvaluator):
     api_meta_template = dict(round=[
         dict(role='HUMAN', api_role='HUMAN'),
         dict(role='BOT', api_role='BOT', generate=True),
@@ -90,72 +121,125 @@ class LiveMathBenchEvaluator(BaseEvaluator):
     def __init__(self,
                  model_name,
                  url,
-                 with_postprocess=True,
                  use_extract_model=False,
-                 post_url=[],
-                 post_model_name='',
-                 **kwargs):
+                 extract_url=[],
+                 extract_model_name='',
+                 k: Union[int, List[int]] = 16,
+                 replication: int = 3,
+                 thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]):
+        super().__init__(k, replication, thresholds)
+
         if isinstance(url, str):
             url = [url]
 
-        self.model = [
-            MODELS.build(
-                dict(
-                    type=OpenAISDK,
-                    path=model_name,
-                    openai_api_base=url,
-                    key='EMPTY',
-                    query_per_second=128,
-                    meta_template=self.api_meta_template,
-                    temperature=kwargs.get('temperature', 0.001),
-                    max_seq_len=kwargs.get('max_tokens', 16384),
-                )) for url in url
-        ]
-        self.with_postprocess = with_postprocess
-        self.use_extract_model = use_extract_model
-        self.post_url = post_url
-        self.post_model_name = post_model_name
-
-    def batch_response(self, models: List[OpenAISDK],
-                       inputs: List[str]) -> List[str]:
-        batch_num = len(models)
-        batch_size = (len(inputs) + batch_num - 1) // batch_num
-        result_responses = []
-
-        with concurrent.futures.ThreadPoolExecutor(
-                max_workers=batch_num) as executor:
-            futures = [
-                executor.submit(models[i].generate,
-                                inputs[i * batch_size:(i + 1) * batch_size])
-                for i in range(batch_num)
-            ]
-            for response in executor.map(lambda f: f.result(), futures):
-                result_responses.extend(response)
-
-        return result_responses
-
-    def postprocess(self, questions: List[str], predictions: List[str],
-                    question_types: List[str],
-                    languages: List[str]) -> List[str]:
-        if self.use_extract_model:
-            assert len(self.post_url) > 0 and self.post_model_name != ''
-            post_model = [
-                MODELS.build(
-                    dict(
-                        type=OpenAISDK,
-                        path=self.post_model_name,
-                        openai_api_base=url,
-                        key='EMPTY',
-                        query_per_second=2,
-                        meta_template=self.api_meta_template,
-                        temperature=0.01,
-                        max_seq_len=1024,
-                    )) for url in self.post_url
-            ]
+        if model_name == '' or len(url) == 0:
+            warnings.warn('Unable to leverage LLM-as-judge and back up to '
+                          'rule-based judge due to incomplete parameters, '
+                          'this may cause performance degradation, check '
+                          '`model_name` or `url` of evaluator if you do '
+                          'not want to do this.')
+            self.judge_models = []
+        else:
+            self.judge_models = [
+                MODELS.build(
+                    dict(
+                        type=OpenAISDK,
+                        path=model_name,
+                        openai_api_base=_url,
+                        key='EMPTY',
+                        query_per_second=2,
+                        retry=5,
+                        meta_template=self.api_meta_template,
+                        temperature=0.0,
+                        max_seq_len=16384,
+                    )) for _url in url
+            ]
+        self.use_extract_model = use_extract_model
+        self.extract_url = extract_url
+        self.extract_model_name = extract_model_name
+
+        self.extract_output_handler = LiveMathBenchOutputHandler()
+        self.judge_output_handler = LiveMathBenchOutputHandler()
+
+    def batch_infer(self, models: List[OpenAISDK], inputs: List[str],
+                    completed_indexes: set,
+                    output_handler: 'LiveMathBenchOutputHandler',
+                    postprocess: Callable) -> List[str]:
+        batch_size = 16
+        batch_num = (len(inputs) + batch_size - 1) // batch_size
+        all_indexes = [i for i in range(len(inputs))]
+        indexes = [i for i in all_indexes if i not in completed_indexes]
+        inputs = [inputs[i] for i in indexes]
+        result_responses = []
+        result_indexes = []
+
+        def thread_worker(inputs, max_out_len, temperature, indexes, model):
+            return model.generate(inputs, max_out_len,
+                                  temperature), inputs, indexes
+
+        if len(indexes) > 0:
+            with ThreadPoolExecutor(max_workers=len(models)) as pool:
+                tasks = [
+                    pool.submit(
+                        partial(thread_worker, model=models[i % len(models)]),
+                        inputs[i * batch_size:(i + 1) * batch_size], 8192, 0.0,
+                        indexes[i * batch_size:(i + 1) * batch_size])
+                    for i in range(batch_num)
+                ]
+                for completed_task in as_completed(tasks):
+                    responses, current_inputs, indexes = completed_task.result(
+                    )
+                    for input, response, index in zip(current_inputs,
+                                                      responses, indexes):
+                        output_handler.save(
+                            index,
+                            prompt=input,
+                            response=response,
+                            postprocess_response=postprocess(response))
+                        result_responses.append(postprocess(response))
+                        result_indexes.append(index)
+                    output_handler.write_to_json()
+
+        return [
+            output_handler.output_dict[str(i)]['postprocess_response']
+            for i in all_indexes
+        ]
+
+    def extract(self, questions: List[str], predictions: List[str],
+                question_types: List[str], languages: List[str]) -> List[str]:
+
+        # extract answer by model
+        if self.use_extract_model:
+            assert len(self.extract_url) > 0 and self.extract_model_name != ''
+            extract_models = [
+                MODELS.build(
+                    dict(
+                        type=OpenAISDK,
+                        path=self.extract_model_name,
+                        openai_api_base=url,
+                        key='EMPTY',
+                        query_per_second=2,
+                        retry=5,
+                        meta_template=self.api_meta_template,
+                        temperature=0.0,
+                        max_seq_len=1024,
+                    )) for url in self.extract_url
+            ]
+
+            completed_indexes = set()
+            mmengine.mkdir_or_exist(self.output_dir)
+            tmp_json_file_path = os.path.join(self.output_dir,
+                                              'tmp_extract.json')
+            self.extract_output_handler.save_file_path = tmp_json_file_path
+            if os.path.exists(tmp_json_file_path):
+                tmp_dict = mmengine.load(tmp_json_file_path)
+                self.extract_output_handler.output_dict = tmp_dict
+                for index in tmp_dict:
+                    completed_indexes.add(int(index))
 
             input_prompts = []
             for question, prediction, question_type, language in zip(
                     questions, predictions, question_types, languages):
                 prompt = (EXTRACT_PROMPT_EN
                           if language == 'en' else EXTRACT_PROMPT_CN)
                 input_prompts.append(
@@ -163,245 +247,125 @@ class LiveMathBenchEvaluator(BaseEvaluator):
                                   response=prediction,
                                   question_type=question_type))
 
-            result_responses = self.batch_response(post_model, input_prompts)
-            return result_responses
-
-        def last_boxed_only_string(string):
-            idx = string.rfind('\\boxed')
-            if idx < 0:
-                idx = string.rfind('\\fbox')
-                if idx < 0:
-                    return None
-
-            i = idx
-            right_brace_idx = None
-            num_left_braces_open = 0
-            while i < len(string):
-                if string[i] == '{':
-                    num_left_braces_open += 1
-                if string[i] == '}':
-                    num_left_braces_open -= 1
-                    if num_left_braces_open == 0:
-                        right_brace_idx = i
-                        break
-                i += 1
-
-            if right_brace_idx is None:
-                retval = None
-            else:
-                retval = string[idx:right_brace_idx + 1]
-
-            return retval
-
-        def remove_boxed(s):
-            left = '\\boxed{'
-            try:
-                assert s[:len(left)] == left
-                assert s[-1] == '}'
-                return s[len(left):-1]
-            except Exception:
-                return None
-
-        def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
-            boxed_str = last_boxed_only_string(pred_str)
-            if boxed_str is None:
-                return None
-            answer = remove_boxed(boxed_str)
-            if answer is None:
-                return None
-            if strip_double_curly_brace:
-                match = re.match('^\{(.*)\}$', answer)  # noqa: W605
-                if match:
-                    answer = match.group(1)
-            return answer
-
-        predictions = [
-            extract_boxed_answer(prediction) for prediction in predictions
-        ]
-        return predictions
-
-    def extract_boxed_answer(self, text):
-        match = re.findall(r'\\boxed{(.+?)}', text)
-        if match:
-            return match[-1]
-
-        return None
-
-    def score(self, predictions, references, origin_prompt, test_set):
-        if len(predictions) != len(references):
-            return {'error': 'preds and refrs have different length'}
-
-        questions = test_set['question']
-        question_types = test_set['question_type']
-        languages = [key.split('_')[1] for key in test_set['dataset_key']]
-
-        if self.with_postprocess:
-            predictions = self.postprocess(questions, predictions,
-                                           question_types, languages)
-
-        inputs = []
-        for prediction, reference, question, language in zip(
-                predictions, references, questions, languages):
-            prompt = JUDGE_PROMPT_EN if language == 'en' else JUDGE_PROMPT_CN
-            inputs.append(
-                prompt.format(answer=prediction,
-                              gold_answer=reference,
-                              question=question))
-        result_responses = self.batch_response(self.model, inputs)
-        results = [
-            self.extract_boxed_answer(result) == 'yes'
-            for result in result_responses
-        ]
-
-        K = test_set['k'][0]
-        N = test_set['n'][0]
-        key2example = {}
-
-        for example, result_response, result, prediction in zip(
-                test_set, result_responses, results, predictions):
-            if example['dataset_key'] not in key2example:
-                key2example[example['dataset_key']] = []
-            example.update({
-                'eval_response': result_response,
-                'prediction': prediction,
-                'correct': result
-            })
-            key2example[example['dataset_key']].append(example)
-        for key in key2example:
-            key2example[key] = [
-                key2example[key][i * K:(i + 1) * K] for i in range(N)
-            ]
-
-        count = []
-        total_pass_num = []
-        details = []
-        all_dataset = set()
-        for key, examples in key2example.items():
-            detail = OrderedDict()
-            detail['question'] = examples[0][0]['question']
-            detail['answer'] = examples[0][0]['answer']
-            detail['responses'] = []
-            detail['dataset'] = '_'.join(key.split('_')[:-1])
-            all_dataset.add('_'.join(key.split('_')[:-1]))
-            if_pass_list = []
-            for single_run_examples in examples:
-                detail['responses'].append([])
-                if_pass_list.append([])
-                for example in single_run_examples:
-                    detail['responses'][-1].append({
-                        'prediction': example['prediction'],
-                        'eval_response': example['eval_response']
-                    })
-                    if_pass_list[-1].append(1.0 if example['correct'] else 0.0)
-
-            if_pass_list = [
-                sorted(if_pass, reverse=True) for if_pass in if_pass_list
-            ]
-            if_pass_list = np.array(if_pass_list)
-            i = 1
-            while i <= K:
-                detail.update({
-                    f'pass-rate@{i}':
-                    if_pass_list[:, :i].mean(axis=1).mean(axis=0).item(),
-                    f'pass-rate@{i}/std':
-                    if_pass_list[:, :i].mean(axis=1).std(axis=0).item(),
-                    f'pass@{i}':
-                    np.ceil(
-                        if_pass_list[:, :i].mean(axis=1)).mean(axis=0).item(),
-                    f'pass@{i}/std':
-                    np.ceil(
-                        if_pass_list[:, :i].mean(axis=1)).std(axis=0).item(),
-                })
-                i = i * 2
-
-            for threshold in [0.5, 0.75, 1.0]:
-                detail.update({
-                    f'{K}-pass@{threshold}':
-                    np.floor(
-                        np.where(
-                            if_pass_list.mean(axis=1) >= threshold, 1.0,
-                            0.0).mean(axis=0))
-                })
-
-            count.append(np.ones_like(if_pass_list).sum(axis=1))
-            total_pass_num.append(if_pass_list.sum(axis=1))
-
-            details.append(detail)
-
-        detailed_result = OrderedDict()
-        detailed_result['details'] = details
-
-        i = 1
-        while i <= K:
-            detailed_result.update({
-                f'pass-rate@{i}':
-                100. *
-                np.mean([detail[f'pass-rate@{i}'] for detail in details]),
-                f'pass-rate@{i}/std':
-                100. *
-                np.mean([detail[f'pass-rate@{i}/std'] for detail in details]),
-                f'pass@{i}':
-                100. * np.mean([detail[f'pass@{i}'] for detail in details]),
-                f'pass@{i}/std':
-                100. * np.mean([detail[f'pass@{i}/std'] for detail in details])
-            })
-            for d in sorted(list(all_dataset)):
-                detailed_result.update({
-                    f'{d}/pass-rate@{i}':
-                    100. * np.mean([
-                        detail[f'pass-rate@{i}']
-                        for detail in details if detail['dataset'] == d
-                    ]),
-                    f'{d}/pass-rate@{i}/std':
-                    100. * np.mean([
-                        detail[f'pass-rate@{i}/std']
-                        for detail in details if detail['dataset'] == d
-                    ]),
-                    f'{d}/pass@{i}':
-                    100. * np.mean([
-                        detail[f'pass@{i}']
-                        for detail in details if detail['dataset'] == d
-                    ]),
-                    f'{d}/pass@{i}/std':
-                    100. * np.mean([
-                        detail[f'pass@{i}/std']
-                        for detail in details if detail['dataset'] == d
-                    ])
-                })
-            i = i * 2
-
-        for threshold in [0.5, 0.75, 1.0]:
-            detailed_result.update({
-                f'{K}-pass@{threshold}':
-                100. * np.mean(
-                    [detail[f'{K}-pass@{threshold}'] for detail in details])
-            })
-            detailed_result.update({
-                f'{K}-pass@{threshold}/std':
-                100. * np.mean(
-                    [detail[f'{K}-pass@{threshold}'] for detail in details])
-            })
-        for d in sorted(list(all_dataset)):
-            for threshold in [0.5, 0.75, 1.0]:
-                detailed_result.update({
-                    f'{d}/{K}-pass@{threshold}':
-                    100. * np.mean([
-                        detail[f'{K}-pass@{threshold}']
-                        for detail in details if detail['dataset'] == d
-                    ])
-                })
-                detailed_result.update({
-                    f'{d}/{K}-pass@{threshold}/std':
-                    100. * np.mean([
-                        detail[f'{K}-pass@{threshold}']
-                        for detail in details if detail['dataset'] == d
-                    ])
-                })
-
-        return detailed_result
+            results = self.batch_infer(extract_models,
+                                       input_prompts,
+                                       completed_indexes,
+                                       self.extract_output_handler,
+                                       postprocess=lambda x: x)
+            return results
+
+        # extract answer in \\boxed{}
+        results = [
+            math_postprocess_v2(prediction) for prediction in predictions
+        ]
+        return results
+
+    def judge(self, predictions, references, test_set):
+        if len(predictions) != len(references):
+            raise ValueError('preds and refrs have different length')
+
+        completed_indexes = set()
+        mmengine.mkdir_or_exist(self.output_dir)
+        tmp_json_file_path = os.path.join(self.output_dir, 'tmp_judge.json')
+        self.judge_output_handler.save_file_path = tmp_json_file_path
+        if os.path.exists(tmp_json_file_path):
+            tmp_dict = mmengine.load(tmp_json_file_path)
+            self.judge_output_handler.output_dict = tmp_dict
+            for index in tmp_dict:
+                completed_indexes.add(int(index))
+
+        questions = test_set['question']
+        question_types = test_set['question_type']
+        languages = [key.split('_')[1] for key in test_set['subdivision']]
+
+        predictions = self.extract(questions, predictions, question_types,
+                                   languages)
+
+        if len(self.judge_models) > 0:
+            inputs = []
+            for prediction, reference, question, language in zip(
+                    predictions, references, questions, languages):
+                prompt = (JUDGE_PROMPT_EN
+                          if language == 'en' else JUDGE_PROMPT_CN)
+                inputs.append(
+                    prompt.format(answer=prediction,
+                                  gold_answer=reference,
+                                  question=question))
+
+            labels = self.batch_infer(
+                self.judge_models, inputs, completed_indexes,
+                self.judge_output_handler, lambda x:
+                (1 if extract_judge_label(x) == 'yes' else 0))
+        else:
+            is_equiv = MATHAgentEvaluator(version='v2').is_equiv
+            labels = [
+                1 if is_equiv(prediction, reference) else 0
+                for prediction, reference in zip(predictions, references)
+            ]
+        return labels
+
+    def preprocess(self, predictions, references, test_set):
+        return self.judge(predictions, references, test_set)
+
+    def group(self, predictions, labels, test_set):
+        example2replications = {}
+        for example, label, prediction in zip(test_set, labels, predictions):
+            example_abbr = f"{example['subdivision']}_{example['idx']}"
+            if example_abbr not in example2replications:
+                example2replications[example_abbr] = []
+            example.update({'prediction': prediction, 'label': label})
+            example2replications[example_abbr].append(example)
+        for _, replications in example2replications.items():
+            assert len(replications) == self.n, print(len(replications),
+                                                      self.n)
+        return example2replications
+
+    def reduce(self, details) -> Dict[str, Any]:
+        """Aggregate the overall metrics.
+
+        Return:
+            A dict contains overall metrics, like:
+            {'details': details for each example, 'G-Pass@16': xxx}
+        """
+        g_passk_details = OrderedDict()
+        g_passk_details['details'] = details
+
+        all_dataset = set([detail['subdivision'] for detail in details])
+
+        for k in self.k:
+            for subdivision in sorted(list(all_dataset)):
+                for threshold in self.thresholds:
+                    g_passk_details[
+                        f'{subdivision}/G-Pass@{k}_{threshold}'] = \
+                        100. * np.mean([
+                            detail[f'G-Pass@{k}_{threshold}']
+                            for detail in details
+                            if detail['subdivision'] == subdivision
+                        ])
+                g_passk_details[f'{subdivision}/mG-Pass@{k}'] = 100. * np.mean(
+                    [
+                        detail[f'mG-Pass@{k}'] for detail in details
+                        if detail['subdivision'] == subdivision
+                    ])
+
+            for threshold in self.thresholds:
+                g_passk_details[f'G-Pass@{k}_{threshold}'] = 100. * np.mean(
+                    [detail[f'G-Pass@{k}_{threshold}'] for detail in details])
+            g_passk_details[f'mG-Pass@{k}'] = 100. * np.mean(
+                [detail[f'mG-Pass@{k}'] for detail in details])
+
+        return g_passk_details
+
+
+class LiveMathBenchOutputHandler:
+    output_dict = {}
+    save_file_path = ''
+
+    def write_to_json(self):
+        """Dump the result to a json file."""
+        dump_results_dict(self.output_dict, self.save_file_path)
+
+    def save(self, idx, **kwargs):
+        self.output_dict[str(idx)] = kwargs
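To make the replication bookkeeping above concrete: with the defaults k=[4, 8, 16] and replication=3, every problem is loaded 48 times, judge produces one 0/1 label per generation, and group expects exactly self.n labelled generations under each subdivision_idx key. A tiny illustrative sketch (numbers invented, not repo code):

k = [4, 8, 16]
replication = 3
n = max(k) * replication        # 48 generations requested per problem
labels = [1] * 30 + [0] * 18    # e.g. 30 of the 48 judged correct
assert len(labels) == n
c = sum(labels)                 # this c feeds compute_g_pass_at_k(n, c, k, t)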
opencompass/datasets/livemathbench/utils.py (new file, +10 lines)
@@ -0,0 +1,10 @@
+import re
+
+
+def extract_judge_label(text):
+    if isinstance(text, str):
+        match = re.findall(r'\\boxed{(.+?)}', text)
+        if match:
+            return match[-1]
+
+    return None
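A quick illustration of what this helper returns; the sample judge response below is invented for demonstration:

from opencompass.datasets.livemathbench.utils import extract_judge_label

sample = 'The two answers agree. \\boxed{yes}'  # hypothetical judge output
print(extract_judge_label(sample))   # -> 'yes'
print(extract_judge_label(None))     # -> None (non-string input)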
@@ -106,7 +106,7 @@ def _format_with_fast_chat_template(inputs: List[str], name: str='vicuna'):
         elif item['role'] == 'system':
             continue
         else:
-            raise ValueError(f'Unknown role {item["role"]}')
+            raise ValueError(f"Unknown role {item['role']}")
     template.append_message(template.roles[1], None)
     outputs.append(template.get_prompt())
     return outputs
@@ -474,6 +474,8 @@ class HuggingFacewithChatTemplate(BaseModel):
         if min_out_len is not None:
             generation_kwargs['min_new_tokens'] = min_out_len
         generation_kwargs['pad_token_id'] = self.tokenizer.pad_token_id
+        self.logger.info('Generation Args of Huggingface: ')
+        self.logger.info(generation_kwargs)
 
         # step-2: conduct model forward to generate output
         outputs = self.model.generate(**tokens, **generation_kwargs)
@@ -516,10 +516,13 @@ class OpenAISDK(OpenAI):
 
         # support multiple api_base for acceleration
         if isinstance(openai_api_base, List):
-            openai_api_base = random.choice(openai_api_base)
+            self.openai_api_base = random.choice(openai_api_base)
+        else:
+            self.openai_api_base = openai_api_base
 
         if self.proxy_url is None:
-            self.openai_client = OpenAI(base_url=openai_api_base, api_key=key)
+            self.openai_client = OpenAI(base_url=self.openai_api_base,
+                                        api_key=key)
         else:
             proxies = {
                 'http://': self.proxy_url,
@@ -527,7 +530,7 @@ class OpenAISDK(OpenAI):
             }
 
             self.openai_client = OpenAI(
-                base_url=openai_api_base,
+                base_url=self.openai_api_base,
                 api_key=key,
                 http_client=httpx.Client(proxies=proxies))
         if self.verbose:
@@ -617,8 +620,8 @@ class OpenAISDK(OpenAI):
                     'Successfully get response from OpenAI API')
             try:
                 self.logger.info(responses)
-            except Exception as e:  # noqa F841
-                pass
+            except Exception:
+                pass  # noqa F841
             if not responses.choices:
                 self.logger.error(
                     'Response is empty, it is an internal server error \
@@ -635,13 +638,18 @@ class OpenAISDK(OpenAI):
                 if (status_code is not None
                         and status_code in self.status_code_mappings):
                     error_message = self.status_code_mappings[status_code]
+                    self.logger.error(
+                        f'error occurs at {self.openai_api_base}')
                     self.logger.info(f'Status Code: {status_code}, \n'
                                      f'Original Error Message: {e}, \n'
                                      f'Return Message: {error_message} ')
                     return error_message
                 else:
+                    self.logger.error(
+                        f'error occurs at {self.openai_api_base}')
                     self.logger.error(e)
             except Exception as e:
+                self.logger.error(f'error occurs at {self.openai_api_base}')
                 self.logger.error(e)
             num_retries += 1
         raise RuntimeError('Calling OpenAI API failed after retrying for '
|
|||||||
gen_config['max_new_tokens'] = max_out_len
|
gen_config['max_new_tokens'] = max_out_len
|
||||||
if min_out_len is not None:
|
if min_out_len is not None:
|
||||||
gen_config['min_new_tokens'] = min_out_len
|
gen_config['min_new_tokens'] = min_out_len
|
||||||
if do_sample or ('do_sample' in self.gen_config and self.gen_config['do_sample']):
|
if not(do_sample or ('do_sample' in self.gen_config and self.gen_config['do_sample'])):
|
||||||
gen_config['top_k'] = 40
|
|
||||||
gen_config['temperature'] = temperature
|
|
||||||
else:
|
|
||||||
if self.version_info >= (0, 6, 0):
|
if self.version_info >= (0, 6, 0):
|
||||||
gen_config['do_sample'] = False
|
gen_config['do_sample'] = False
|
||||||
else:
|
else:
|
||||||
@ -140,6 +137,8 @@ class TurboMindModelwithChatTemplate(BaseModel):
|
|||||||
from lmdeploy import GenerationConfig
|
from lmdeploy import GenerationConfig
|
||||||
gen_config = {k: v for k, v in gen_config.items() if hasattr(GenerationConfig, k)}
|
gen_config = {k: v for k, v in gen_config.items() if hasattr(GenerationConfig, k)}
|
||||||
gen_config = GenerationConfig(**gen_config)
|
gen_config = GenerationConfig(**gen_config)
|
||||||
|
self.logger.info('Generation Config of LMdeploy: ')
|
||||||
|
self.logger.info(gen_config)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
outputs = self.pipe(messages, gen_config=gen_config, do_preprocess=False)
|
outputs = self.pipe(messages, gen_config=gen_config, do_preprocess=False)
|
||||||
|
@@ -108,6 +108,8 @@ class VLLMwithChatTemplate(BaseModel):
         sampling_kwargs.update(self.generation_kwargs)
         sampling_kwargs.update(kwargs)
         sampling_kwargs = SamplingParams(**sampling_kwargs)
+        self.logger.info('Sampling Params of vLLM: ')
+        self.logger.info(sampling_kwargs)
 
         outputs = self.model.generate(messages, sampling_kwargs)
 
@@ -4,6 +4,7 @@ from .icl_base_evaluator import BaseEvaluator  # noqa
 from .icl_bpc_evaluator import BPCEvaluator  # noqa
 from .icl_circular_evaluator import CircularEvaluator  # noqa
 from .icl_em_evaluator import EMEvaluator  # noqa
+from .icl_gpassk_evaluator import GPassKEvaluator  # noqa
 from .icl_hf_evaluator import *  # noqa
 from .icl_jieba_rouge_evaluator import JiebaRougeEvaluator  # noqa
 from .icl_misc_evaluator import AverageInferencePPLEvaluator  # noqa
opencompass/openicl/icl_evaluator/icl_gpassk_evaluator.py (new file, +163 lines)
@@ -0,0 +1,163 @@
+from abc import abstractmethod
+from typing import Any, Dict, List, Union
+
+import numpy as np
+from scipy.stats import hypergeom
+
+from opencompass.registry import ICL_EVALUATORS
+
+from .icl_base_evaluator import BaseEvaluator
+
+
+def compute_pass_at_k(n, c, k):
+    if n - c < k:
+        return 1.0
+    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
+
+
+def _compute_g_pass_at_k(n, c, k, m):
+    if m > min(c, k) or k > n or c < 0 or n <= 0 or m < 0:
+        return 0.0
+    return hypergeom.sf(m - 1, n, c, k)
+
+
+def compute_g_pass_at_k(n, c, k, t):
+    m = max(int(np.ceil(k * t)), 1)
+    return _compute_g_pass_at_k(n, c, k, m)
+
+
+def compute_mg_pass_at_k(n, c, k):
+    l, r = int(np.ceil(k * 0.5)), k
+
+    mg_pass_at_k = 0.0
+    for i in range(l + 1, r + 1):
+        mg_pass_at_k += _compute_g_pass_at_k(n, c, k, i)
+    mg_pass_at_k = 2 * mg_pass_at_k / k
+
+    return mg_pass_at_k
+
+
+@ICL_EVALUATORS.register_module()
+class GPassKEvaluator(BaseEvaluator):
+    """Evaluator for computing the G-Pass@k metric.
+
+    This evaluator performs the following steps:
+        1. Invokes the task-specific `preprocess` on predictions to
+           assign a consistency label to each prediction and its
+           corresponding reference.
+        2. Calculates metrics for each input example based on
+           these labels.
+        3. Aggregates the overall metrics through the task-specific
+           `reduce`.
+
+    Args:
+        k (int or list of int): Number of predictions to be
+            considered in G-Pass@k. It can be a single integer
+            (e.g., `k=16` computes G-Pass@16) or a list of
+            integers (e.g., `[4, 8, 16]` computes G-Pass@4,
+            G-Pass@8 and G-Pass@16).
+        replication (int): Controls the number of generations
+            used to estimate G-Pass@k. The total number of
+            generations is determined by multiplying the
+            maximum of `k` with `replication`. This parameter
+            should be a single integer.
+        thresholds (list of float): A list of floating-point
+            numbers that define the thresholds for the G-Pass@k
+            metric.
+    """
+
+    def __init__(
+            self,
+            k: Union[int, List[int]] = 16,
+            replication: int = 3,
+            thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]) -> None:
+        super().__init__()
+
+        if isinstance(k, int):
+            k = [k]
+
+        self.k = k
+        self.replication = replication
+        self.n = max(k) * replication
+        self.thresholds = thresholds
+
+    @property
+    def output_dir(self):
+        # please see opencompass/opencompass/tasks/openicl_eval.py Line 197-200
+        return self._out_dir
+
+    @abstractmethod
+    def preprocess(self, predictions, references, test_set) -> None:
+        """Perform operations on predictions before computing metrics, for
+        example, answer extraction and model judging in mathematical
+        reasoning tasks.
+
+        Return:
+            labels: A list containing, for each position, a label that
+            indicates whether the prediction is consistent with the
+            reference.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def group(self, predictions, labels, test_set) -> Dict[str, Any]:
+        """Group the predictions and references.
+
+        Return:
+            A dict containing the grouped predictions and references.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def reduce(self, details) -> Dict[str, Any]:
+        """Aggregate the overall metrics.
+
+        Return:
+            A dict containing overall metrics, like:
+            {'details': details for each example, 'G-Pass@16': xxx}
+        """
+        raise NotImplementedError
+
+    def score(self, predictions, references, test_set) -> Dict[str, Any]:
+        """Compute G-Pass@k metrics.
+
+        Return:
+            A dict containing metrics for each dataset sample and
+            overall metrics reduced by `self.reduce`, like:
+            {'details': details for each example, 'G-Pass@16': xxx}
+        """
+        labels = self.preprocess(predictions, references, test_set)
+        grouped_examples = self.group(predictions, labels, test_set)
+
+        details = []
+        total_pass_num, count = 0, 0
+        for example_abbr, examples in grouped_examples.items():
+            detail = {
+                k: v
+                for k, v in examples[0].items()
+                if k not in ['prediction', 'label']
+            }
+            detail.update({
+                'predictions': [{
+                    'prediction': example['prediction'],
+                    'label': example['label']
+                } for example in examples],
+            })
+
+            current_example_labels = [e['label'] for e in examples]
+            c = int(np.sum(current_example_labels))
+
+            for k in self.k:
+                for threshold in self.thresholds:
+                    detail[f'G-Pass@{k}_{threshold}'] = compute_g_pass_at_k(
+                        n=self.n, c=c, k=k, t=threshold)
+                detail[f'mG-Pass@{k}'] = compute_mg_pass_at_k(n=self.n,
+                                                              c=c,
+                                                              k=k)
+            count += self.n
+            total_pass_num += c
+
+            details.append(detail)
+
+        return self.reduce(details)
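For orientation, a minimal, hypothetical subclass showing the contract the three abstract hooks fill in; the exact-match labelling and the 'idx' grouping key are illustrative assumptions, not code from this commit:

from collections import OrderedDict

import numpy as np

from opencompass.openicl.icl_evaluator import GPassKEvaluator


class ExactMatchGPassKEvaluator(GPassKEvaluator):

    def preprocess(self, predictions, references, test_set):
        # one 0/1 consistency label per generation
        return [int(p == r) for p, r in zip(predictions, references)]

    def group(self, predictions, labels, test_set):
        # collect the n = max(k) * replication generations of each problem
        groups = {}
        for example, label, prediction in zip(test_set, labels, predictions):
            key = example['idx']  # assumes the dataset carries an 'idx' field
            example.update({'prediction': prediction, 'label': label})
            groups.setdefault(key, []).append(example)
        return groups

    def reduce(self, details):
        # average per-example G-Pass@k values into overall numbers
        result = OrderedDict(details=details)
        for k in self.k:
            for t in self.thresholds:
                result[f'G-Pass@{k}_{t}'] = 100. * np.mean(
                    [d[f'G-Pass@{k}_{t}'] for d in details])
            result[f'mG-Pass@{k}'] = 100. * np.mean(
                [d[f'mG-Pass@{k}'] for d in details])
        return result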