[Feature] Support Math evaluation via judgemodel (#1094)

* support openai math evaluation

* support openai math evaluation

* support openai math evaluation

* support math llm judge

* support math llm judge
bittersweet1999 2024-04-26 14:56:23 +08:00 committed by GitHub
parent 41196c48ae
commit 6ba1c4937d
8 changed files with 311 additions and 8 deletions

View File

@@ -0,0 +1,35 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import MATHDataset, MATHEvaluator, math_postprocess
QUERY_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form ANSWER: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
{problem}
Remember to put your answer on its own line after "ANSWER:", and you do not need to use a \\boxed command.
""".strip()
math_reader_cfg = dict(input_columns=['problem'], output_column='solution')
math_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(role="HUMAN", prompt=QUERY_TEMPLATE),
])),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=512))
math_eval_cfg = dict(
evaluator=dict(type=MATHEvaluator), pred_postprocessor=dict(type=math_postprocess))
math_datasets = [
dict(
type=MATHDataset,
abbr='math',
path='./data/math/math.json',
reader_cfg=math_reader_cfg,
infer_cfg=math_infer_cfg,
eval_cfg=math_eval_cfg)
]
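Because the judgement preprocessor added later in this PR extracts whatever follows the final `ANSWER:` line, it helps to see what a rendered prompt and a compliant reply look like. A minimal sketch, using plain `str.format` purely for illustration (at run time `PromptTemplate` fills the `{problem}` placeholder from the `problem` column):

# Illustration only: str.format stands in for PromptTemplate, assuming
# QUERY_TEMPLATE from the config above is in scope.
prompt = QUERY_TEMPLATE.format(problem='Simplify $(x+1)^2 - (x^2 + 2x)$.')
print(prompt)
# A compliant model reply ends with a line such as:
#   ANSWER: 1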

View File

@@ -0,0 +1,111 @@
# Most of the code in this file is copied from https://github.com/openai/simple-evals/blob/main/math_eval.py
from mmengine.config import read_base
with read_base():
from .models.hf_llama.hf_llama3_8b_instruct import models as hf_llama3_8b_instruct_model # noqa: F401, F403
from .models.hf_internlm.hf_internlm2_chat_20b import models as hf_internlm2_chat_20b_model # noqa: F401, F403
from .models.hf_llama.hf_llama3_70b_instruct import models as hf_llama3_70b_instruct_model # noqa: F401, F403
from .datasets.math.math_llm_judge import math_datasets # noqa: F401, F403
from opencompass.models.openai_api import OpenAIAllesAPIN
from opencompass.datasets import math_judement_preprocess
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import AllObjSummarizer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_prompt_template import PromptTemplate
# ------------- Prompt Settings ----------------------------------------
eng_obj_prompt = """
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
Examples:
Expression 1: $2x+3$
Expression 2: $3+2x$
Result: [[Correct]]
Expression 1: 3/2
Expression 2: 1.5
Result: [[Correct]]
Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$
Result: [[Incorrect]]
Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$
Result: [[Correct]]
Expression 1: 3245/5
Expression 2: 649
Result: [[Incorrect]]
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
Expression 1: 2/(-3)
Expression 2: -2/3
Result: [[Correct]]
(trivial simplifications are allowed)
Expression 1: 72 degrees
Expression 2: 72
Result: [[Correct]]
(give benefit of the doubt to units)
Expression 1: 64
Expression 2: 64 square feet
Result: [[Correct]]
(give benefit of the doubt to units)
---
YOUR TASK
Respond with only "Result: [[Correct]]" or "Result: [[Incorrect]]" (without quotes). Do not include a rationale.
Expression 1: {obj_gold}
Expression 2: {prediction}
""".strip()
# ------------- Inference Stage ----------------------------------------
# eval models
models = [*hf_llama3_8b_instruct_model]
# judge models
judge_models = hf_llama3_70b_instruct_model
eng_datasets = [*math_datasets]
chn_datasets = []
datasets = eng_datasets + chn_datasets
work_dir = 'outputs/obj_all/'
for d in eng_datasets:
d['eval_cfg']= dict(
evaluator=dict(
type=LMEvaluator,
# If you need to preprocess the prediction before judging,
# you can specify the pred_postprocessor function here
pred_postprocessor=dict(type=math_judement_preprocess),
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt = eng_obj_prompt
),
]),
),
),
pred_role="BOT",
)
infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=40000),
runner=dict(
type=LocalRunner,
max_num_workers=256,
task=dict(type=OpenICLInferTask)),
)
# ------------- Evaluation Configuration --------------------------------
eval = dict(
partitioner=dict(
type=SubjectiveSizePartitioner, max_task_size=80000, mode='singlescore', models=models, judge_models=judge_models,
),
runner=dict(type=LocalRunner,
max_num_workers=16, task=dict(type=SubjectiveEvalTask)),
)
summarizer = dict(
type=AllObjSummarizer
)
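To make the judge stage concrete: each prediction is first reduced to its final answer by `math_judement_preprocess`, bound to `{prediction}`, and the dataset reference is bound to `{obj_gold}` (see the `LMEvaluator` change below); the filled template is sent to the judge model, whose reply is later parsed for `[[Correct]]`/`[[Incorrect]]`. A rough sketch, assuming `eng_obj_prompt` from the config above is in scope:

# Illustration only: plain str.format stands in for the LMEvaluator machinery.
reference = 'x^2 + 2x + 1'   # reference taken from the dataset (illustrative value)
prediction = '(x+1)^2'       # output of math_judement_preprocess
request = eng_obj_prompt.format(obj_gold=reference, prediction=prediction)
print(request)
# Expected judge reply for this pair: "Result: [[Correct]]"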

View File

@@ -125,6 +125,15 @@ def normalize_final_answer(final_answer: str) -> str:
    return final_answer
+ANSWER_PATTERN = r'(?i)ANSWER\s*:\s*([^\n]+)'
+def extract_answer(response_text: str):
+    # We suggest returning an empty string rather than None when extraction fails
+    match = re.search(ANSWER_PATTERN, response_text)
+    return match.group(1) if match else ''
@LOAD_DATASET.register_module()
class MATHDataset(BaseDataset):
@@ -156,6 +165,12 @@ def math_postprocess(text: str) -> str:
    #     text.split('Final Answer: ', 1)[-1].split('\n\n')[0])
+@TEXT_POSTPROCESSORS.register_module('math_judement_preprocess')
+def math_judement_preprocess(text: str) -> str:
+    """Preprocess prediction before judgement."""
+    return extract_answer(text)
@TEXT_POSTPROCESSORS.register_module('math_postprocess_v2')
def math_postprocess_v2(text: str) -> str:
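As a quick check of the new helper, the pattern and function below are copied from the hunk above and run on two sample completions (a standalone sketch for illustration only):

import re

ANSWER_PATTERN = r'(?i)ANSWER\s*:\s*([^\n]+)'

def extract_answer(response_text: str):
    # Return '' rather than None when extraction fails, as noted above.
    match = re.search(ANSWER_PATTERN, response_text)
    return match.group(1) if match else ''

print(extract_answer('We expand and simplify.\nanswer: (x+1)^2'))  # -> '(x+1)^2'
print(extract_answer('No final line was produced'))                # -> ''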

View File

@@ -12,8 +12,6 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.registry import ICL_PROMPT_TEMPLATES
from opencompass.utils import build_dataset_from_cfg, build_model_from_cfg
from opencompass.utils.logging import get_logger
-from opencompass.utils.text_postprocessors import first_number_postprocess
-from opencompass.utils.types import get_type_from_cfg
def extract_dicts(data):
@@ -80,7 +78,7 @@ class LMEvaluator:
        dataset_cfg (ConfigDict, optional): The config of the dataset to be
            evaluated.
        pack_all_predictions (bool, optional): For multi-round evaluation, judge all rounds at once or judge each round separately.
-        postprocessor (ConfigDict): The model prediction's postprocessor
+        pred_postprocessor (ConfigDict): The model prediction's postprocessor
            config.
    """
@@ -92,7 +90,7 @@
        meta_review_prompt_template: Optional[ConfigDict] = None,
        pack_all_predictions: Optional[bool] = False,
        dataset_cfg: Optional[ConfigDict] = None,
-        postprocessor: ConfigDict = dict(type=first_number_postprocess)
+        pred_postprocessor: Optional[ConfigDict] = None,
    ) -> None:
        self.output_path = output_path
        out_dir, out_name = osp.split(output_path)
@@ -112,7 +110,6 @@
            batch_size=batch_size,
            output_json_filepath=out_dir,
            output_json_filename=out_name)
-        self.postprocessor = get_type_from_cfg(postprocessor)
        self.logger = get_logger()
        self.dataset_cfg = dataset_cfg
        self.pack_all_predictions = pack_all_predictions
@@ -163,7 +160,9 @@
        ):  # single chat for format like [['xxx', 'xxxx'], ['xxx', 'xxxx']]
            for i in range(len(predictions)):
                key = 'prediction' if i == 0 else f'prediction{i + 1}'
+                gold_key = 'obj_gold'
                pred_dict[key] = predictions[i]
+                pred_dict[gold_key] = references
            if judgements:
                for i in range(len(judgements)):
                    key = 'judgement' if i == 0 else f'judgement{i + 1}'
@@ -189,6 +188,10 @@
            if judgements:
                raise NotImplementedError(
                    'Meta-review judging is not supported on multi-round datasets')
+        else:
+            raise NotImplementedError(
+                f'Unsupported prediction {predictions[0][0]} of type {type(predictions[0][0])}; please check that the postprocessor applied to the prediction string is correct. We suggest returning an empty string rather than None.'
+            )
        if self.dataset_cfg:
            dataset = build_dataset_from_cfg(self.dataset_cfg)

View File

@@ -1,5 +1,6 @@
# flake8: noqa: F401, E501
from .alignmentbench import AlignmentBenchSummarizer
+from .all_obj import AllObjSummarizer
from .alpacaeval import AlpacaSummarizer
from .compass_arena import CompassArenaSummarizer
from .corev2 import Corev2Summarizer

View File

@@ -0,0 +1,122 @@
# flake8: noqa: E501
import csv
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime
import numpy as np
from mmengine import ConfigDict
from prettytable import from_csv
from opencompass.utils import dataset_abbr_from_cfg, model_abbr_from_cfg
from .utils import get_judgeanswer_and_reference, get_outdir
def post_process_allobj(judgement: str):
"""Input a string like below:
    xxx[[correct]]xxx, and extract the judgement.
"""
pattern = r'(?i)\[(incorrect|correct|正确|错误)\]'
matched_result = re.findall(pattern, judgement)
if matched_result:
content = matched_result[0].lower()
if content in ['correct', '正确']:
return {'score': 1}
elif content in ['incorrect', '错误']:
return {'score': 0}
else:
return None
def get_capability_results(
judged_answers,
references,
fout,
fout_flag,
model,
):
capability_ratings = defaultdict(int)
capability_counts = defaultdict(int)
for ans, ref in zip(judged_answers, references):
capability_ratings['total'] += ans['score']
capability_counts['total'] += 1
capability_avg_ratings = defaultdict(float)
for capability, total_score in capability_ratings.items():
capability_avg_ratings[
capability] = total_score / capability_counts[capability]
columns = list(capability_avg_ratings.keys())
columns.insert(0, columns.pop(columns.index('total')))
with open(fout, 'a+', newline='') as csvfile:
writer = csv.writer(csvfile)
if fout_flag == 0:
writer.writerow(['model'] + columns)
writer.writerow([model] +
[capability_avg_ratings[column] for column in columns])
class AllObjSummarizer:
    """Do the subjectivity analysis based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict, judge_type='single') -> None:
self.judge_type = judge_type
self.tasks = []
self.cfg = config
if self.judge_type == 'single':
self.eval_model_cfgs = self.cfg['eval']['partitioner']['models']
self.eval_model_abbrs = [
model_abbr_from_cfg(model) for model in self.eval_model_cfgs
]
elif self.judge_type == 'pair':
self.base_models = self.cfg['eval']['partitioner']['base_models']
self.compare_models = self.cfg['eval']['partitioner'][
'compare_models']
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_models'][0])
self.judge_map = {'single': post_process_allobj}
self.judge_function = self.judge_map[self.judge_type]
def summarize(self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
pd.DataFrame: The summary results.
"""
if self.judge_type == 'single':
dataset_cfgs = self.cfg['datasets']
judge_model = self.judge_abbr
output_dir, results_folder = get_outdir(self.cfg, time_str)
for dataset in dataset_cfgs:
dataset_abbr = dataset_abbr_from_cfg(dataset)
fout = osp.join(
output_dir,
'judged-by--' + judge_model + '-' + dataset_abbr + '.csv')
fout_flag = 0
for eval_model_abbr in self.eval_model_abbrs:
subdir = eval_model_abbr + '_judged-by--' + self.judge_abbr
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
model = eval_model_abbr
judged_answers, references = get_judgeanswer_and_reference(
dataset, subdir_path, self.judge_function)
get_capability_results(judged_answers, references,
fout, fout_flag, model)
fout_flag += 1
else:
                        print(subdir_path + ' does not exist, please check!')
with open(fout, 'r') as f:
x = from_csv(f)
print(x)
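For reference, a small self-contained sketch of how the verdict parser behaves on typical judge replies (the function body restates the logic above so the snippet runs on its own):

import re

def post_process_allobj(judgement: str):
    # Same logic as above: pull the first [[Correct]]/[[Incorrect]] verdict.
    pattern = r'(?i)\[(incorrect|correct|正确|错误)\]'
    matched_result = re.findall(pattern, judgement)
    if matched_result:
        content = matched_result[0].lower()
        return {'score': 1} if content in ['correct', '正确'] else {'score': 0}
    return None

print(post_process_allobj('Result: [[Correct]]'))          # {'score': 1}
print(post_process_allobj('Result: [[Incorrect]]'))        # {'score': 0}
print(post_process_allobj('The judge produced no verdict'))  # None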

View File

@@ -139,7 +139,8 @@ class SubjectiveEvalTask(BaseTask):
            # If no predictions are found in the predictions dir
            assert osp.exists(filename) or osp.exists(
                osp.realpath(partial_filename)
-            ), 'No predictions found for {filename}.'.format(filename=filename)
+            ), 'No predictions found for {filename} and {partial_filename}'.format(
+                filename=filename, partial_filename=partial_filename)
            # If the Naive partitioner was used in the infer stage
            if osp.exists(osp.realpath(filename)):
@@ -188,10 +189,14 @@
            if fnmatch.fnmatch(ds_abbr, pattern):
                pred_postprocessor = model_postprocessors[pattern]
                break
-        if 'pred_postprocessor' in eval_cfg or pred_postprocessor:
-            kwargs = pred_postprocessor or eval_cfg['pred_postprocessor']
+        if 'pred_postprocessor' in eval_cfg['evaluator'] or pred_postprocessor:
+            kwargs = pred_postprocessor or eval_cfg['evaluator'][
+                'pred_postprocessor']
            proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type'))
+            self.logger.info(f'Get postprocessor {proc}.')
            pred_strs = [proc(s, **kwargs) for s in pred_strs]
+        else:
+            self.logger.info('No postprocessor found.')
        return {
            'model_name': model_abbr_from_cfg(model_cfg),

View File

@@ -77,6 +77,17 @@ def get_config_from_arg(args) -> Config:
        if args.accelerator in ['vllm', 'lmdeploy']:
            config['models'] = change_accelerator(config['models'],
                                                  args.accelerator)
+            if 'eval' in config and 'partitioner' in config['eval']:
+                if 'models' in config['eval']['partitioner']:
+                    config['eval']['partitioner'][
+                        'models'] = change_accelerator(
+                            config['eval']['partitioner']['models'],
+                            args.accelerator)
+                if 'judge_models' in config['eval']['partitioner']:
+                    config['eval']['partitioner'][
+                        'judge_models'] = change_accelerator(
+                            config['eval']['partitioner']['judge_models'],
+                            args.accelerator)
        return config
    # parse dataset args
    if not args.datasets and not args.custom_dataset_path:
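The effect of this change is easiest to see on a toy config: when an accelerator is requested, the model lists under eval.partitioner (including judge_models) are now converted as well. A sketch with a hypothetical stand-in for change_accelerator, shown only to illustrate which config entries the new branch visits (the backend field is illustrative, not the real helper's output):

# Hypothetical stand-in for change_accelerator, illustration only.
def change_accelerator(models, accelerator):
    return [dict(m, backend=accelerator) for m in models]

config = {
    'models': [{'abbr': 'llama-3-8b-instruct-hf'}],
    'eval': {'partitioner': {
        'models': [{'abbr': 'llama-3-8b-instruct-hf'}],
        'judge_models': [{'abbr': 'llama-3-70b-instruct-hf'}],
    }},
}

accelerator = 'vllm'
config['models'] = change_accelerator(config['models'], accelerator)
if 'eval' in config and 'partitioner' in config['eval']:
    for key in ('models', 'judge_models'):
        if key in config['eval']['partitioner']:
            config['eval']['partitioner'][key] = change_accelerator(
                config['eval']['partitioner'][key], accelerator)

print(config['eval']['partitioner']['judge_models'])
# [{'abbr': 'llama-3-70b-instruct-hf', 'backend': 'vllm'}]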