diff --git a/configs/datasets/math/math_llm_judge.py b/configs/datasets/math/math_llm_judge.py
new file mode 100644
index 00000000..c3cc2ecd
--- /dev/null
+++ b/configs/datasets/math/math_llm_judge.py
@@ -0,0 +1,35 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import MATHDataset, MATHEvaluator, math_postprocess
+
+QUERY_TEMPLATE = """
+Solve the following math problem step by step. The last line of your response should be of the form ANSWER: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
+{problem}
+Remember to put your answer on its own line after "ANSWER:", and you do not need to use a \\boxed command.
+""".strip()
+
+math_reader_cfg = dict(input_columns=['problem'], output_column='solution')
+
+math_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+
+        template=dict(round=[
+            dict(role="HUMAN", prompt=QUERY_TEMPLATE),
+        ])),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer, max_out_len=512))
+
+math_eval_cfg = dict(
+    evaluator=dict(type=MATHEvaluator), pred_postprocessor=dict(type=math_postprocess))
+
+math_datasets = [
+    dict(
+        type=MATHDataset,
+        abbr='math',
+        path='./data/math/math.json',
+        reader_cfg=math_reader_cfg,
+        infer_cfg=math_infer_cfg,
+        eval_cfg=math_eval_cfg)
+]
\ No newline at end of file
diff --git a/configs/eval_math_llm_judge.py b/configs/eval_math_llm_judge.py
new file mode 100644
index 00000000..8dac6fd0
--- /dev/null
+++ b/configs/eval_math_llm_judge.py
@@ -0,0 +1,111 @@
+# Most of the code in this file is copied from https://github.com/openai/simple-evals/blob/main/math_eval.py
+from mmengine.config import read_base
+with read_base():
+    from .models.hf_llama.hf_llama3_8b_instruct import models as hf_llama3_8b_instruct_model  # noqa: F401, F403
+    from .models.hf_internlm.hf_internlm2_chat_20b import models as hf_internlm2_chat_20b_model  # noqa: F401, F403
+    from .models.hf_llama.hf_llama3_70b_instruct import models as hf_llama3_70b_instruct_model  # noqa: F401, F403
+    from .datasets.math.math_llm_judge import math_datasets  # noqa: F401, F403
+from opencompass.models.openai_api import OpenAIAllesAPIN
+from opencompass.datasets import math_judement_preprocess
+from opencompass.partitioners import NaivePartitioner, SizePartitioner
+from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
+from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
+from opencompass.runners import LocalRunner
+from opencompass.runners import SlurmSequentialRunner
+from opencompass.tasks import OpenICLInferTask
+from opencompass.tasks.subjective_eval import SubjectiveEvalTask
+from opencompass.summarizers import AllObjSummarizer
+from opencompass.openicl.icl_evaluator import LMEvaluator
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+
+
+# ------------- Prompt Settings ----------------------------------------
+eng_obj_prompt = """
+Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
+Examples:
+    Expression 1: $2x+3$
+    Expression 2: $3+2x$
+Result: [[Correct]]
+    Expression 1: 3/2
+    Expression 2: 1.5
+Result: [[Correct]]
+    Expression 1: $x^2+2x+1$
+    Expression 2: $y^2+2y+1$
+Result: [[Incorrect]]
+    Expression 1: $x^2+2x+1$
+    Expression 2: $(x+1)^2$
+Result: [[Correct]]
+    Expression 1: 3245/5
+    Expression 2: 649
+Result: [[Incorrect]]
+(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
+    Expression 1: 2/(-3)
+    Expression 2: -2/3
+Result: [[Correct]]
+(trivial simplifications are allowed)
+    Expression 1: 72 degrees
+    Expression 2: 72
+Result: [[Correct]]
+(give benefit of the doubt to units)
+    Expression 1: 64
+    Expression 2: 64 square feet
+Result: [[Correct]]
+(give benefit of the doubt to units)
+---
+YOUR TASK
+Respond with only "Result: [[Correct]]" or "Result: [[Incorrect]]" (without quotes). Do not include a rationale.
+    Expression 1: {obj_gold}
+    Expression 2: {prediction}
+""".strip()

+# ------------- Inference Stage ----------------------------------------
+# eval models
+models = [*hf_llama3_8b_instruct_model]
+# judge models
+judge_models = hf_llama3_70b_instruct_model
+
+eng_datasets = [*math_datasets]
+chn_datasets = []
+datasets = eng_datasets + chn_datasets
+work_dir = 'outputs/obj_all/'
+
+for d in eng_datasets:
+    d['eval_cfg'] = dict(
+        evaluator=dict(
+            type=LMEvaluator,
+            # If you need to preprocess the prediction before judging,
+            # you can specify the pred_postprocessor function here
+            pred_postprocessor=dict(type=math_judement_preprocess),
+            prompt_template=dict(
+                type=PromptTemplate,
+                template=dict(round=[
+                    dict(
+                        role='HUMAN',
+                        prompt=eng_obj_prompt
+                    ),
+                ]),
+            ),
+        ),
+        pred_role="BOT",
+    )
+
+infer = dict(
+    partitioner=dict(type=SizePartitioner, max_task_size=40000),
+    runner=dict(
+        type=LocalRunner,
+        max_num_workers=256,
+        task=dict(type=OpenICLInferTask)),
+)
+
+# ------------- Evaluation Configuration --------------------------------
+eval = dict(
+    partitioner=dict(
+        type=SubjectiveSizePartitioner, max_task_size=80000, mode='singlescore', models=models, judge_models=judge_models,
+    ),
+    runner=dict(type=LocalRunner,
+                max_num_workers=16, task=dict(type=SubjectiveEvalTask)),
+)
+
+summarizer = dict(
+    type=AllObjSummarizer
+)
diff --git a/opencompass/datasets/math.py b/opencompass/datasets/math.py
index 36bcd6d7..19e38baf 100644
--- a/opencompass/datasets/math.py
+++ b/opencompass/datasets/math.py
@@ -125,6 +125,15 @@ def normalize_final_answer(final_answer: str) -> str:
     return final_answer
 
 
+ANSWER_PATTERN = r'(?i)ANSWER\s*:\s*([^\n]+)'
+
+
+def extract_answer(response_text: str):
+    # Return an empty string rather than None when extraction fails
+    match = re.search(ANSWER_PATTERN, response_text)
+    return match.group(1) if match else ''
+
+
 @LOAD_DATASET.register_module()
 class MATHDataset(BaseDataset):
 
@@ -156,6 +165,12 @@ def math_postprocess(text: str) -> str:
     #     text.split('Final Answer: ', 1)[-1].split('\n\n')[0])
 
 
+@TEXT_POSTPROCESSORS.register_module('math_judement_preprocess')
+def math_judement_preprocess(text: str) -> str:
+    """Preprocess prediction before judgement."""
+    return extract_answer(text)
+
+
 @TEXT_POSTPROCESSORS.register_module('math_postprocess_v2')
 def math_postprocess_v2(text: str) -> str:
 
diff --git a/opencompass/openicl/icl_evaluator/lm_evaluator.py b/opencompass/openicl/icl_evaluator/lm_evaluator.py
index bb3d502e..bd89533c 100644
--- a/opencompass/openicl/icl_evaluator/lm_evaluator.py
+++ b/opencompass/openicl/icl_evaluator/lm_evaluator.py
@@ -12,8 +12,6 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.registry import ICL_PROMPT_TEMPLATES
 from opencompass.utils import build_dataset_from_cfg, build_model_from_cfg
 from opencompass.utils.logging import get_logger
-from opencompass.utils.text_postprocessors import first_number_postprocess
-from opencompass.utils.types import get_type_from_cfg
 
 
 def extract_dicts(data):
@@ -80,7 +78,7 @@ class LMEvaluator:
         dataset_cfg (ConfigDict, optional): The config of the dataset to be
             evaluated.
         pack_all_predictions (bool, optional): For multiround evaluation, judge all round or judge every single round.
-        postprocessor (ConfigDict): The model prediction's postprocessor
+        pred_postprocessor (ConfigDict): The model prediction's postprocessor
             config.
     """
 
@@ -92,7 +90,7 @@ class LMEvaluator:
         meta_review_prompt_template: Optional[ConfigDict] = None,
         pack_all_predictions: Optional[bool] = False,
         dataset_cfg: Optional[ConfigDict] = None,
-        postprocessor: ConfigDict = dict(type=first_number_postprocess)
+        pred_postprocessor: Optional[ConfigDict] = None,
     ) -> None:
         self.output_path = output_path
         out_dir, out_name = osp.split(output_path)
@@ -112,7 +110,6 @@ class LMEvaluator:
             batch_size=batch_size,
             output_json_filepath=out_dir,
             output_json_filename=out_name)
-        self.postprocessor = get_type_from_cfg(postprocessor)
         self.logger = get_logger()
         self.dataset_cfg = dataset_cfg
         self.pack_all_predictions = pack_all_predictions
@@ -163,7 +160,9 @@ class LMEvaluator:
             ):  #single chat for format like [['xxx', 'xxxx'], ['xxx', 'xxxx']]
                 for i in range(len(predictions)):
                     key = 'prediction' if i == 0 else f'prediction{i + 1}'
+                    gold_key = 'obj_gold'
                     pred_dict[key] = predictions[i]
+                    pred_dict[gold_key] = references
                 if judgements:
                     for i in range(len(judgements)):
                         key = 'judgement' if i == 0 else f'judgement{i + 1}'
@@ -189,6 +188,10 @@ class LMEvaluator:
                 if judgements:
                     raise NotImplementedError(
                         'Not applied meta-reivew judge on multi-round dataset')
+            else:
+                raise NotImplementedError(
+                    f'Unsupported prediction {predictions[0][0]} with type {type(predictions[0][0])}; please check that the postprocessor applied to the prediction returns a string (we suggest returning an empty string rather than None).'
+                )
 
         if self.dataset_cfg:
             dataset = build_dataset_from_cfg(self.dataset_cfg)
diff --git a/opencompass/summarizers/subjective/__init__.py b/opencompass/summarizers/subjective/__init__.py
index 4fe21a33..4e8e1497 100644
--- a/opencompass/summarizers/subjective/__init__.py
+++ b/opencompass/summarizers/subjective/__init__.py
@@ -1,5 +1,6 @@
 # flake8: noqa: F401, E501
 from .alignmentbench import AlignmentBenchSummarizer
+from .all_obj import AllObjSummarizer
 from .alpacaeval import AlpacaSummarizer
 from .compass_arena import CompassArenaSummarizer
 from .corev2 import Corev2Summarizer
diff --git a/opencompass/summarizers/subjective/all_obj.py b/opencompass/summarizers/subjective/all_obj.py
new file mode 100644
index 00000000..c2f171d1
--- /dev/null
+++ b/opencompass/summarizers/subjective/all_obj.py
@@ -0,0 +1,122 @@
+# flake8: noqa: E501
+import csv
+import os
+import os.path as osp
+import re
+from collections import defaultdict
+from datetime import datetime
+
+import numpy as np
+from mmengine import ConfigDict
+from prettytable import from_csv
+
+from opencompass.utils import dataset_abbr_from_cfg, model_abbr_from_cfg
+
+from .utils import get_judgeanswer_and_reference, get_outdir
+
+
+def post_process_allobj(judgement: str):
+    """Input a string like below:
+
+    xxx[[correct]]xxx, and extract the judgement.
+    """
+    pattern = r'(?i)\[(incorrect|correct|正确|错误)\]'
+    matched_result = re.findall(pattern, judgement)
+    if matched_result:
+        content = matched_result[0].lower()
+        if content in ['correct', '正确']:
+            return {'score': 1}
+        elif content in ['incorrect', '错误']:
+            return {'score': 0}
+    else:
+        return None
+
+
+def get_capability_results(
+    judged_answers,
+    references,
+    fout,
+    fout_flag,
+    model,
+):
+    capability_ratings = defaultdict(int)
+    capability_counts = defaultdict(int)
+    for ans, ref in zip(judged_answers, references):
+        capability_ratings['total'] += ans['score']
+        capability_counts['total'] += 1
+
+    capability_avg_ratings = defaultdict(float)
+
+    for capability, total_score in capability_ratings.items():
+        capability_avg_ratings[
+            capability] = total_score / capability_counts[capability]
+    columns = list(capability_avg_ratings.keys())
+    columns.insert(0, columns.pop(columns.index('total')))
+    with open(fout, 'a+', newline='') as csvfile:
+        writer = csv.writer(csvfile)
+        if fout_flag == 0:
+            writer.writerow(['model'] + columns)
+        writer.writerow([model] +
+                        [capability_avg_ratings[column] for column in columns])
+
+
+class AllObjSummarizer:
+    """Do the subjectivity analysis based on evaluation results.
+
+    Args:
+        config (ConfigDict): The configuration object of the evaluation task.
+            It's expected to be filled out at runtime.
+    """
+
+    def __init__(self, config: ConfigDict, judge_type='single') -> None:
+        self.judge_type = judge_type
+        self.tasks = []
+        self.cfg = config
+        if self.judge_type == 'single':
+            self.eval_model_cfgs = self.cfg['eval']['partitioner']['models']
+            self.eval_model_abbrs = [
+                model_abbr_from_cfg(model) for model in self.eval_model_cfgs
+            ]
+        elif self.judge_type == 'pair':
+            self.base_models = self.cfg['eval']['partitioner']['base_models']
+            self.compare_models = self.cfg['eval']['partitioner'][
+                'compare_models']
+        self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_models'][0])
+        self.judge_map = {'single': post_process_allobj}
+        self.judge_function = self.judge_map[self.judge_type]
+
+    def summarize(self,
+                  time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
+        """Summarize the subjectivity analysis based on evaluation results.
+
+        Args:
+            time_str (str): Timestamp for file naming.
+
+        Returns:
+            pd.DataFrame: The summary results.
+        """
+        if self.judge_type == 'single':
+            dataset_cfgs = self.cfg['datasets']
+            judge_model = self.judge_abbr
+            output_dir, results_folder = get_outdir(self.cfg, time_str)
+            for dataset in dataset_cfgs:
+                dataset_abbr = dataset_abbr_from_cfg(dataset)
+                fout = osp.join(
+                    output_dir,
+                    'judged-by--' + judge_model + '-' + dataset_abbr + '.csv')
+                fout_flag = 0
+                for eval_model_abbr in self.eval_model_abbrs:
+                    subdir = eval_model_abbr + '_judged-by--' + self.judge_abbr
+                    subdir_path = os.path.join(results_folder, subdir)
+                    if os.path.isdir(subdir_path):
+                        model = eval_model_abbr
+                        judged_answers, references = get_judgeanswer_and_reference(
+                            dataset, subdir_path, self.judge_function)
+                        get_capability_results(judged_answers, references,
+                                               fout, fout_flag, model)
+                        fout_flag += 1
+                    else:
+                        print(subdir_path + ' does not exist! Please check!')
+            with open(fout, 'r') as f:
+                x = from_csv(f)
+                print(x)
diff --git a/opencompass/tasks/subjective_eval.py b/opencompass/tasks/subjective_eval.py
index a455f38d..a16d3141 100644
--- a/opencompass/tasks/subjective_eval.py
+++ b/opencompass/tasks/subjective_eval.py
@@ -139,7 +139,8 @@ class SubjectiveEvalTask(BaseTask):
         # If no predictions get in predictions dir
         assert osp.exists(filename) or osp.exists(
             osp.realpath(partial_filename)
-        ), 'No predictions found for {filename}.'.format(filename=filename)
+        ), 'No predictions found for {filename} or {partial_filename}.'.format(
+            filename=filename, partial_filename=partial_filename)
 
         # If use Naive partition in infer stage
         if osp.exists(osp.realpath(filename)):
@@ -188,10 +189,14 @@ class SubjectiveEvalTask(BaseTask):
             if fnmatch.fnmatch(ds_abbr, pattern):
                 pred_postprocessor = model_postprocessors[pattern]
                 break
-        if 'pred_postprocessor' in eval_cfg or pred_postprocessor:
-            kwargs = pred_postprocessor or eval_cfg['pred_postprocessor']
+        if 'pred_postprocessor' in eval_cfg['evaluator'] or pred_postprocessor:
+            kwargs = pred_postprocessor or eval_cfg['evaluator'][
+                'pred_postprocessor']
             proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type'))
+            self.logger.info(f'Got postprocessor {proc}.')
             pred_strs = [proc(s, **kwargs) for s in pred_strs]
+        else:
+            self.logger.info('No postprocessor found.')
 
         return {
             'model_name': model_abbr_from_cfg(model_cfg),
diff --git a/opencompass/utils/run.py b/opencompass/utils/run.py
index 5a53da0f..de6a8724 100644
--- a/opencompass/utils/run.py
+++ b/opencompass/utils/run.py
@@ -77,6 +77,17 @@ def get_config_from_arg(args) -> Config:
         if args.accelerator in ['vllm', 'lmdeploy']:
             config['models'] = change_accelerator(config['models'],
                                                   args.accelerator)
+        if 'eval' in config and 'partitioner' in config['eval']:
+            if 'models' in config['eval']['partitioner']:
+                config['eval']['partitioner'][
+                    'models'] = change_accelerator(
+                        config['eval']['partitioner']['models'],
+                        args.accelerator)
+            if 'judge_models' in config['eval']['partitioner']:
+                config['eval']['partitioner'][
+                    'judge_models'] = change_accelerator(
+                        config['eval']['partitioner']['judge_models'],
+                        args.accelerator)
         return config
     # parse dataset args
     if not args.datasets and not args.custom_dataset_path:
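
For reference (not part of the patch): the pieces above chain together as follows. The candidate model answers `QUERY_TEMPLATE`, `math_judement_preprocess`/`extract_answer` pulls out the text after `ANSWER:` in that response, the judge model receives `eng_obj_prompt` with the gold answer as `{obj_gold}` and the extracted answer as `{prediction}`, and `AllObjSummarizer` scores the judge's `[[Correct]]`/`[[Incorrect]]` verdict via `post_process_allobj`. Below is a minimal standalone sketch of that flow; the two regexes are copied from the patch, while the function `judge_score` (mirroring `post_process_allobj`) and all sample strings are purely illustrative.

```python
import re

# Regexes copied from the patch (opencompass/datasets/math.py and
# opencompass/summarizers/subjective/all_obj.py); everything else is made up.
ANSWER_PATTERN = r'(?i)ANSWER\s*:\s*([^\n]+)'
JUDGE_PATTERN = r'(?i)\[(incorrect|correct|正确|错误)\]'


def extract_answer(response_text: str) -> str:
    """Pull the text after 'ANSWER:' out of the candidate model's response."""
    match = re.search(ANSWER_PATTERN, response_text)
    return match.group(1) if match else ''


def judge_score(judgement: str):
    """Map the judge's [[Correct]]/[[Incorrect]] verdict to a 0/1 score."""
    matched = re.findall(JUDGE_PATTERN, judgement)
    if matched:
        content = matched[0].lower()
        if content in ('correct', '正确'):
            return {'score': 1}
        if content in ('incorrect', '错误'):
            return {'score': 0}
    return None


# Hypothetical strings, just to show the round trip end to end.
candidate_response = 'The two roots sum to 3/2.\nANSWER: 3/2'
prediction = extract_answer(candidate_response)   # fills {prediction} in eng_obj_prompt
obj_gold = '1.5'                                  # fills {obj_gold}
judge_response = 'Result: [[Correct]]'            # what the judge model might reply
print(prediction, judge_score(judge_response))    # -> 3/2 {'score': 1}
```

Assuming the standard OpenCompass entry point, the whole pipeline would be launched with something like `python run.py configs/eval_math_llm_judge.py`; the summarizer then writes one `judged-by--<judge_abbr>-math.csv` per dataset alongside the judged results.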