diff --git a/configs/eval_korbench.py b/configs/eval_korbench.py new file mode 100644 index 00000000..91851c12 --- /dev/null +++ b/configs/eval_korbench.py @@ -0,0 +1,9 @@ +from mmengine import read_base + +with read_base(): + from opencompass.configs.datasets.korbench.korbench_single_0_shot_gen import korbench_0shot_single_datasets as zero_shot_datasets + from opencompass.configs.datasets.korbench.korbench_single_3_shot_gen import korbench_3shot_single_datasets as three_shot_datasets + from opencompass.configs.datasets.korbench.korbench_mixed_gen_d00bdd import korbench_mixed_datasets as mixed_datasets + from opencompass.configs.models.hf_internlm.hf_internlm2_5_7b import models as hf_internlm2_5_7b +datasets = zero_shot_datasets + three_shot_datasets + mixed_datasets +models = hf_internlm2_5_7b diff --git a/opencompass/configs/datasets/korbench/korbench_mixed_gen_d00bdd.py b/opencompass/configs/datasets/korbench/korbench_mixed_gen_d00bdd.py new file mode 100644 index 00000000..6447dfe3 --- /dev/null +++ b/opencompass/configs/datasets/korbench/korbench_mixed_gen_d00bdd.py @@ -0,0 +1,59 @@ +from opencompass.datasets.korbench.korbench import korbenchDataset, korbenchEvaluator +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +korbench_mixed_datasets = [] + +categories = ["Multi-Q", "Multi-R", "Multi-RQ"] # Define available modes for mixed mode + +for category in categories: + # Prompt template + prompt_template = dict( + type=PromptTemplate, + template=dict( + begin=[ + dict( + role="HUMAN", + prompt="" + ) + ], + round=[ + dict( + role="HUMAN", + prompt="{prompt}" # f-string + ) + ] + ) + ) + + # Reader configuration + reader_cfg = dict( + input_columns=["prompt"], + output_column="answer", + ) + + # Inference configuration + infer_cfg = dict( + prompt_template=prompt_template, + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=1024), + ) + + # Evaluation configuration + eval_cfg = dict( + evaluator=dict(type=korbenchEvaluator), + pred_role="BOT", + ) + + korbench_dataset = dict( + type=korbenchDataset, + abbr=f"korbench_mixed_{category}", + path="opencompass/korbench", + category=category, + mode='mixed', + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg, + ) + + korbench_mixed_datasets.append(korbench_dataset) \ No newline at end of file diff --git a/opencompass/configs/datasets/korbench/korbench_single_0_shot_gen.py b/opencompass/configs/datasets/korbench/korbench_single_0_shot_gen.py new file mode 100644 index 00000000..d04c9f60 --- /dev/null +++ b/opencompass/configs/datasets/korbench/korbench_single_0_shot_gen.py @@ -0,0 +1,60 @@ +from opencompass.datasets.korbench.korbench import korbenchDataset, korbenchEvaluator +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever + +categories = ["cipher", "counterfactual", "logic", "operation", "puzzle"] + +korbench_0shot_single_datasets = [] + +for category in categories: + # Prompt template + prompt_template = dict( + type=PromptTemplate, + template=dict( + begin=[ + dict( + role="HUMAN", + prompt="" + ) + ], + round=[ + dict( + role="HUMAN", + prompt="{prompt}" # f-string + ) + ] + ) + ) + + # Reader configuration + reader_cfg = dict( + input_columns=["prompt"], + output_column="answer", + ) + + # Inference 
configuration + infer_cfg = dict( + prompt_template=prompt_template, + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=1024), + ) + + # Evaluation configuration + eval_cfg = dict( + evaluator=dict(type=korbenchEvaluator), + pred_role="BOT", + ) + + korbench_dataset = dict( + type=korbenchDataset, + abbr=f"korbench_{category}_0shot", + path="opencompass/korbench", + mode='0_shot', + category=category, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg, + ) + + korbench_0shot_single_datasets.append(korbench_dataset) diff --git a/opencompass/configs/datasets/korbench/korbench_single_3_shot_gen.py b/opencompass/configs/datasets/korbench/korbench_single_3_shot_gen.py new file mode 100644 index 00000000..0d70f5f8 --- /dev/null +++ b/opencompass/configs/datasets/korbench/korbench_single_3_shot_gen.py @@ -0,0 +1,61 @@ +from opencompass.datasets.korbench.korbench import korbenchDataset, korbenchEvaluator + +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever + +categories = ["cipher", "counterfactual", "logic", "operation", "puzzle"] + +korbench_3shot_single_datasets = [] + +for category in categories: + # Prompt template + prompt_template = dict( + type=PromptTemplate, + template=dict( + begin=[ + dict( + role="HUMAN", + prompt="" + ) + ], + round=[ + dict( + role="HUMAN", + prompt="{prompt}" # f-string + ) + ] + ) + ) + + # Reader configuration + reader_cfg = dict( + input_columns=["prompt"], + output_column="answer", + ) + + # Inference configuration + infer_cfg = dict( + prompt_template=prompt_template, + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=1024), + ) + + # Evaluation configuration + eval_cfg = dict( + evaluator=dict(type=korbenchEvaluator), + pred_role="BOT", + ) + + korbench_dataset = dict( + type=korbenchDataset, + abbr=f"korbench_{category}_3shot", + path="opencompass/korbench", + mode='3_shot', + category=category, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg, + ) + + korbench_3shot_single_datasets.append(korbench_dataset) diff --git a/opencompass/configs/summarizers/groups/korbench.py b/opencompass/configs/summarizers/groups/korbench.py new file mode 100644 index 00000000..101fd65d --- /dev/null +++ b/opencompass/configs/summarizers/groups/korbench.py @@ -0,0 +1,5 @@ +korbench_summary_groups = [] +categories = ['cipher', 'counterfactual', 'logic', 'operation', 'puzzle'] +mixed_categories = ['Multi-Q', 'Multi-R', 'Multi-RQ'] +korbench_summary_groups.append({'name': 'korbench_single', 'subsets': [f'korbench_{c}' for c in categories]}) +korbench_summary_groups.append({'name': 'korbench_mixed', 'subsets': [f'korbench_{c}' for c in mixed_categories]}) diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index e96ffc28..ddb70b12 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -65,6 +65,7 @@ from .iwslt2017 import * # noqa: F401, F403 from .jigsawmultilingual import * # noqa: F401, F403 from .jsonl import JsonlDataset # noqa: F401, F403 from .kaoshi import KaoshiDataset, KaoshiEvaluator # noqa: F401, F403 +from .korbench import * # noqa: F401, F403 from .lambada import * # noqa: F401, F403 from .lawbench import * # noqa: F401, F403 from .LCBench import * # noqa: F401, F403 diff --git a/opencompass/datasets/korbench/korbench.py 
b/opencompass/datasets/korbench/korbench.py new file mode 100644 index 00000000..b0f649e1 --- /dev/null +++ b/opencompass/datasets/korbench/korbench.py @@ -0,0 +1,215 @@ +import os + +from datasets import Dataset + +from opencompass.datasets.korbench.korbench_utils import ( + evaluate_responses, find_file, load_json_or_jsonl, + load_json_or_jsonl_with_idx, load_yaml) +from opencompass.openicl.icl_evaluator import BaseEvaluator +from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET +from opencompass.utils import get_data_path + +from ..base import BaseDataset + + +@LOAD_DATASET.register_module() +class korbenchDataset(BaseDataset): + """Dataset loader for the task in KOR-Bench.""" + + @staticmethod + def load(path, mode, category): + """Load the dataset using shared .""" + base_path = get_data_path(path) + rule_file = None + sample_file = None + mixed_file = None + mixed_data = None + if '0_shot' in mode or '3_shot' in mode: + rule_file = find_file(base_path, os.path.join(category, 'rule')) + sample_file = find_file(base_path, + os.path.join(category, 'sample')) + elif mode == 'mixed': + mixed_file = find_file(base_path, os.path.join('mixed', category)) + mixed_data = load_json_or_jsonl(mixed_file) or [] + else: + raise ValueError(f'Unsupported mode: {mode}') + three_shot_file = None + if mode == '3_shot': + ts_path = os.path.join(category, 'three-shot') + three_shot_file = find_file(base_path, ts_path) + # Load data + if mode in ['0_shot', '3_shot']: + rules = load_json_or_jsonl(rule_file) or [] + samples = load_json_or_jsonl(sample_file) or [] + template_path = None + if mode == '0_shot': + template_path = os.path.join( + os.path.dirname(__file__), + 'korbench_dataset_config/prompt/0_shot.yaml') + elif mode == '3_shot': + template_path = os.path.join( + os.path.dirname(__file__), + 'korbench_dataset_config/prompt/3_shot.yaml') + elif mode == 'mixed': + template_path = os.path.join( + os.path.dirname(__file__), + 'korbench_dataset_config/prompt/mixed.yaml') + try: + template = load_yaml(template_path) + except FileNotFoundError: + print(f'[ERROR] Missing prompt template: {template_path}') + return Dataset.from_list([]) + + # Process data + data = [] + if mode == '0_shot': + for sample in samples: + rule_id = sample['rule_id'] + rule = next((r for r in rules if r['idx'] == rule_id), None) + if not rule: + print(f"[WARNING] Rule ID {sample['rule_id']} not found." + 'Skipping...') + continue + prompt_key = f'{category}_prompt_format' + prompt = template[prompt_key][0].format( + rule['rule_content'], sample['question']) + + # Add processed item + data.append({ + 'rule_content': rule['rule_content'], + 'question': sample['question'], + 'answer': sample['answer'], + 'prompt': prompt, + 'rule_id': rule['idx'], + 'mode': '0_shot', + 'category': category, + }) + + return Dataset.from_list(data) + + if mode == '3_shot': + data = [] + three_shot = load_json_or_jsonl(three_shot_file) or [] + for sample in samples: + rule_id = sample['rule_id'] + rule = next((r for r in rules if r['idx'] == rule_id), None) + three_shot_qa = [ + item for fs in three_shot if fs['rule_id'] == rule_id + for item in [fs['question'], fs['answer']] + ] + if not rule: + print(f"[WARNING] Rule ID {sample['rule_id']} not found." 
+ 'Skipping...') + continue + prompt_key = f'{category}_prompt_format' + prompt = template[prompt_key][0].format( + rule['rule_content'], *three_shot_qa, sample['question']) + # Add processed item + data.append({ + 'rule_content': rule['rule_content'], + 'question': sample['question'], + 'answer': sample['answer'], + 'prompt': prompt, + 'rule_id': rule['idx'], + 'mode': '3_shot', + 'category': category, + }) + + return Dataset.from_list(data) + + if mode == 'mixed': + # Process data + data = [] + for item in mixed_data: + rule_list = item['rule_list'] + question_list = item['question_list'] + rule_content_list = [] + question_content_list = [] + + # Fetch rules and questions + for rule in rule_list: + category, rule_idx = rule.rsplit('_', 1) + rule_content = load_json_or_jsonl_with_idx(base_path, + os.path.join( + category, + 'rule'), + idx=rule_idx) + rule_content_list.append(rule_content['rule_content']) + + for question in question_list: + category, question_idx = question.rsplit('_', 1) + question_content = load_json_or_jsonl_with_idx( + base_path, + os.path.join(category, 'sample'), + idx=question_idx) + question_content_list.append(question_content['question']) + + # Prepare prompt + rules_str = '\n'.join( + f'Rule {i+1}: {content}' + for i, content in enumerate(rule_content_list)) + questions_str = '\n'.join( + f'Question {i+1}: {content}' + for i, content in enumerate(question_content_list)) + prompt_format = [rules_str, questions_str] + prompt = template['prompt_format'][0].format(*prompt_format) + + # Add processed item + data.append({ + 'rule_list': rule_list, + 'question_list': question_list, + 'prompt': prompt, + 'mode': 'mixed', + 'answer': '', + 'base_path': base_path, + }) + + return Dataset.from_list(data) + + +@ICL_EVALUATORS.register_module() +class korbenchEvaluator(BaseEvaluator): + + def __init__(self): + super().__init__() + + def score(self, predictions, references, test_set): + """Evaluate predictions for a single mode in KOR-Bench.""" + if not test_set: + raise ValueError('Test set is empty.') + + mode = test_set[0]['mode'] # Determine the mode from the first entry + data = {} + + # Organize data for the given mode + for i in range(len(predictions)): + entry = { + 'prediction': predictions[i], + 'gold': references[i], + 'rule_id': test_set[i].get('rule_id', None), + 'category': test_set[i].get('category', None), + 'rule_list': test_set[i].get('rule_list', None), + 'question_list': test_set[i].get('question_list', None), + 'base_path': test_set[i].get('base_path', None), + } + data[i] = entry + + if not data: + raise ValueError(f"No data found for mode '{mode}'") + + # Evaluate based on the mode + if mode == '0_shot': + evaluation_results = evaluate_responses(data, '0_shot') + elif mode == '3_shot': + evaluation_results = evaluate_responses(data, '3_shot') + elif mode in ['Multi-Q', 'Multi-R', 'Multi-RQ', 'mixed']: + evaluation_results = evaluate_responses(data, 'mixed', + test_set[0]['base_path']) + else: + raise ValueError(f'Unsupported mode: {mode}') + # Calculate accuracy + correct_count = sum(res['is_correct'] for res in evaluation_results) + accuracy = (correct_count / len(evaluation_results)) * 100 + + # Return scores + return {'accuracy': accuracy} diff --git a/opencompass/datasets/korbench/korbench_dataset_config/config.yaml b/opencompass/datasets/korbench/korbench_dataset_config/config.yaml new file mode 100644 index 00000000..c9e8bef0 --- /dev/null +++ b/opencompass/datasets/korbench/korbench_dataset_config/config.yaml @@ -0,0 +1,15 @@ +# Necessary 
+response_key: 'response' +error_key: 'error' +id_key: + - 'idx' + - 'step' +prompt_key: 'prompt' + +# Optional +history_key: 'history' +status_key: 'status' + +save_prompt: True +max_tokens: 2000 +max_rounds: 5 diff --git a/opencompass/datasets/korbench/korbench_dataset_config/config_wrapper.py b/opencompass/datasets/korbench/korbench_dataset_config/config_wrapper.py new file mode 100644 index 00000000..13c2caa7 --- /dev/null +++ b/opencompass/datasets/korbench/korbench_dataset_config/config_wrapper.py @@ -0,0 +1,90 @@ +import yaml + + +class ConfigWrapper: + + def __init__(self, config_path): + self._config = {} + with open(config_path, 'r') as file: + self._config = yaml.safe_load(file) + for key, value in self._config.items(): + setattr(self, key, value) + + def __setattr__(self, key, value): + if key.startswith('_'): + super().__setattr__(key, value) + else: + self._config[key] = value + super().__setattr__(key, value) + + def __getattr__(self, key): + if key in self._config: + return self._config[key] + raise AttributeError( + f"'ConfigWrapper' object has no attribute '{key}'") + + def get_id(self, data): + if isinstance(self._config.get('id_key'), str): + return data.get(self._config.get('id_key'), None) + elif isinstance(self._config.get('id_key'), list): + return '_'.join([ + str(data[key]) for key in self._config.get('id_key') + if key in data + ]) + + def print_all_keys(self): + print('config keys:') + for key, value in self._config.items(): + print(f' - {key}: {value}') + + +config_wrapper = None + + +def initialize_config(config_path): + global config_wrapper + config_wrapper = ConfigWrapper(config_path) + + +def get_config_wrapper(): + global config_wrapper + if config_wrapper is None: + raise RuntimeError( + 'ConfigWrapper not initialized. Call initialize_config first.') + return config_wrapper + + +if __name__ == '__main__': + config_path = 'config/config.yaml' + initialize_config(config_path) + data = { + 'idx': + '50', + 'step': + 21, + 'question': + ('Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\n' + 'Please provide the decrypted answer, ' + 'encapsulated in double square brackets. ' + 'For example, the format should be: [[decrypted answer]].'), + 'answer': + '[[P]]', + 'category': + 'Decryption', + 'rule_id': + '23', + 'input': + 'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"', + 'steps_num': + 23, + 'description': + ('For a number c=228 in the ciphertext:\nCalculate z = c^e mod n.' + ' Here ^ means multiplication.\nz is 80.\n' + 'Based on the decimal number represented by z, ' + 'use the ascii code to find the corresponding' + ' letter as the plaintext letter p.\n' + 'Please give the letter p in [[...]] format.\n'), + 'atom': + 80 + } + print(config_wrapper.get_id(data)) diff --git a/opencompass/datasets/korbench/korbench_dataset_config/prompt/0_shot.yaml b/opencompass/datasets/korbench/korbench_dataset_config/prompt/0_shot.yaml new file mode 100644 index 00000000..1caa4144 --- /dev/null +++ b/opencompass/datasets/korbench/korbench_dataset_config/prompt/0_shot.yaml @@ -0,0 +1,94 @@ +cipher_prompt_format: + - | + You are an intelligent assistant that specializes in encryption and decryption tasks. Below are the rules for a specific cipher. When responding, please ensure that your output adheres to the specified encryption and decryption rules and format. + + ### Instructions: + + 1. Identify the relevant properties and objects specified in the rule, including the plaintext, keyword, and ciphertext. + 2. 
Follow the specified encryption or decryption operations precisely as described in the rules. + 3. Ensure your output is formatted according to the specified notation and symbols. + + ### Cipher Rule: + + {} + + ### Question: + {} + + ### Answer: + +counterfactual_prompt_format: + - | + You are an advanced assistant with expertise in storytelling and rule-based reasoning. Your task is to carefully analyze the provided story, which includes specific rules and details, and use this information to accurately answer related questions. + + ### Instructions: + + 1. Thoroughly review the story to identify and understand the relevant details and rules. + 2. Use the context provided by the story to offer precise and insightful answers. + 3. Ensure your responses align with the rules and information given in the story. + + ### Story Rule: + + {} + + ### Question: + {} + + ### Answer: + +logic_prompt_format: + - | + You are an intelligent assistant that helps with various logical reasoning tasks. Below is a custom-defined rule for a specific type of logic. When responding, please ensure that your output adheres to the specified logical rules and format. + + ### Instructions: + + 1. Identify the relevant properties and objects as specified in the rule. + 2. Apply the given logical operations or reasoning patterns. + 3. Ensure your output is formatted according to the specified notation and symbols. + + ### Logic Rule: + + {} + + ### Question: + {} + + ### Answer: + +operation_prompt_format: + - | + You are an intelligent assistant specializing in evaluating custom operations. Below is a specific rule defined for a custom operation. Your task is to apply this rule accurately to the provided question. + + ### Instructions: + + 1. Carefully read and understand the definitions of the new operations in the rule. + 2. If the question does not specifically ask for it, your answer should be a number or a group of numbers. + 3. Double-check your final answer to ensure it follows the rule accurately. + + ### Operation Rule: + + {} + + ### Question: + {} + + ### Answer: + +puzzle_prompt_format: + - | + You are an intelligent assistant specializing in solving custom puzzle problems. Below is a specific rule defined for a custom puzzle. Your task is to apply this rule accurately to the provided question. + + ### Instructions: + + 1. Thoroughly understand the rule provided. If needed, break down the rule into simpler components or steps. + 2. Apply the rule carefully to address the question presented. + 3. Verify your answer to ensure it aligns with the rule and the context of the puzzle. + + ### Puzzle Rule: + + {} + + ### Question: + {} + + ### Answer: diff --git a/opencompass/datasets/korbench/korbench_dataset_config/prompt/3_shot.yaml b/opencompass/datasets/korbench/korbench_dataset_config/prompt/3_shot.yaml new file mode 100644 index 00000000..5de1e6b5 --- /dev/null +++ b/opencompass/datasets/korbench/korbench_dataset_config/prompt/3_shot.yaml @@ -0,0 +1,184 @@ +cipher_prompt_format: + - | + You are an intelligent assistant that specializes in encryption and decryption tasks. Below are the rules for a specific cipher. When responding, please ensure that your output adheres to the specified encryption and decryption rules and format. + + ### Instructions: + + 1. Identify the relevant properties and objects specified in the rule, including the plaintext, keyword, and ciphertext. + 2. Follow the specified encryption or decryption operations precisely as described in the rules. + 3. 
Ensure your output is formatted according to the specified notation and symbols. + + ### Cipher Rule: + + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + +counterfactual_prompt_format: + - | + You are an advanced assistant with expertise in storytelling and rule-based reasoning. Your task is to carefully analyze the provided story, which includes specific rules and details, and use this information to accurately answer related questions. + + ### Instructions: + + 1. Thoroughly review the story to identify and understand the relevant details and rules. + 2. Use the context provided by the story to offer precise and insightful answers. + 3. Ensure your responses align with the rules and information given in the story. + + ### Story Rule: + + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + +logic_prompt_format: + - | + You are an intelligent assistant that helps with various logical reasoning tasks. Below is a custom-defined rule for a specific type of logic. When responding, please ensure that your output adheres to the specified logical rules and format. + + ### Instructions: + + 1. Identify the relevant properties and objects as specified in the rule. + 2. Apply the given logical operations or reasoning patterns. + 3. Ensure your output is formatted according to the specified notation and symbols. + + ### Logic Rule: + + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + +operation_prompt_format: + - | + You are an intelligent assistant specializing in evaluating custom operations. Below is a specific rule defined for a custom operation. Your task is to apply this rule accurately to the provided question. + + ### Instructions: + + 1. Carefully read and understand the definitions of the new operations in the rule. + 2. If the question does not specifically ask for it, your answer should be a number or a group of numbers. + 3. Double-check your final answer to ensure it follows the rule accurately. + + ### Operation Rule: + + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + +puzzle_prompt_format: + - | + You are an intelligent assistant specializing in solving custom puzzle problems. Below is a specific rule defined for a custom puzzle. Your task is to apply this rule accurately to the provided question. + + ### Instructions: + + 1. Thoroughly understand the rule provided. If needed, break down the rule into simpler components or steps. + 2. Apply the rule carefully to address the question presented. + 3. Verify your answer to ensure it aligns with the rule and the context of the puzzle. 
+ + ### Puzzle Rule: + + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: + {} + + ### Question: + {} + + ### Answer: diff --git a/opencompass/datasets/korbench/korbench_dataset_config/prompt/mixed.yaml b/opencompass/datasets/korbench/korbench_dataset_config/prompt/mixed.yaml new file mode 100644 index 00000000..d13cbb5b --- /dev/null +++ b/opencompass/datasets/korbench/korbench_dataset_config/prompt/mixed.yaml @@ -0,0 +1,22 @@ +prompt_format: + - | + You are an intelligent assistant capable of handling all types of reasoning and problem-solving tasks. Below is the text of a set of rules. Your task is to apply the appropriate rules to solve a series of problems. + + ### Instructions: + 1. Read each question carefully and rules to find something relevant to that question. + 2. Use the relevant rules to answer each question accurately. + 3. Provide the final answers to all questions in JSON format. + {{ + "question1": "your answer", + "question2": "your answer", + "question3": "your answer", + }} + + ### Rules: + + {} + + ### Questions: + {} + + ### Answers: diff --git a/opencompass/datasets/korbench/korbench_dataset_config/prompt/self-correction.yaml b/opencompass/datasets/korbench/korbench_dataset_config/prompt/self-correction.yaml new file mode 100644 index 00000000..36a7b0a9 --- /dev/null +++ b/opencompass/datasets/korbench/korbench_dataset_config/prompt/self-correction.yaml @@ -0,0 +1,3 @@ +prompt_format: + - | + Your answer is incorrect, please check your answer and provide a correct one. diff --git a/opencompass/datasets/korbench/korbench_dataset_config/prompt/trick.yaml b/opencompass/datasets/korbench/korbench_dataset_config/prompt/trick.yaml new file mode 100644 index 00000000..a415c916 --- /dev/null +++ b/opencompass/datasets/korbench/korbench_dataset_config/prompt/trick.yaml @@ -0,0 +1,20 @@ +prompt_format: + - | + You are an intelligent assistant specializing in solving custom puzzle problems. Below is a specific rule defined for a custom puzzle. Your task is to apply this rule accurately to the provided question. + + ### Instructions: + + 1. Thoroughly understand the rule provided. If needed, break down the rule into simpler components or steps. + 2. Apply the rule carefully to address the question presented. + 3. Verify your answer to ensure it aligns with the rule and the context of the puzzle. 
+ + ### Puzzle Rule: + + {} + + ### Question: + {} + + {} + + ### Answer: diff --git a/opencompass/datasets/korbench/korbench_utils.py b/opencompass/datasets/korbench/korbench_utils.py new file mode 100644 index 00000000..8b59766a --- /dev/null +++ b/opencompass/datasets/korbench/korbench_utils.py @@ -0,0 +1,699 @@ +import json +import os +import re + +import sympy as sp +import yaml +from sympy.parsing.latex import parse_latex + + +def load_yaml(yaml_path): + """Load a YAML file.""" + if not os.path.exists(yaml_path): + raise FileNotFoundError(f'YAML file not found: {yaml_path}') + with open(yaml_path, 'r', encoding='utf-8') as file: + return yaml.safe_load(file) + + +def load_json_or_jsonl(file_path): + """Load data from a JSON or JSONL file.""" + if not os.path.exists(file_path): + return None + with open(file_path, 'r', encoding='utf-8') as file: + if file_path.endswith('.json'): + return json.load(file) + elif file_path.endswith('.jsonl'): + return [json.loads(line) for line in file] + return None + + +def find_file(base_path, sub_path, extensions=('json', 'jsonl')): + """Find the first available file with given extensions.""" + for ext in extensions: + file_path = os.path.join(base_path, f'{sub_path}.{ext}') + if os.path.exists(file_path): + return file_path + return None + + +def load_json_or_jsonl_with_idx(data_path, split='', idx=None): + base_path = os.path.join(data_path, split) + if os.path.exists(f'{base_path}.json'): + file_path = f'{base_path}.json' + elif os.path.exists(f'{base_path}.jsonl'): + file_path = f'{base_path}.jsonl' + elif base_path.endswith('.json') or base_path.endswith('.jsonl'): + file_path = base_path + else: + raise FileNotFoundError('No JSON or JSONL file found.') + + with open(file_path, 'r', encoding='utf-8') as file: + if file_path.endswith('.json'): + data = json.load(file) + elif file_path.endswith('.jsonl'): + data = [json.loads(line) for line in file] + + if idx is not None: + try: + return next(item for item in data if item.get('idx') == idx) + except StopIteration: + raise ValueError(f'No entry found for idx {idx}') + else: + return data + + +def load_split_data(base_path, split_name): + """Load the rule and sample data for a specific split.""" + split_path = os.path.join(base_path, split_name) + rule_path = find_file(split_path, 'rule') + sample_path = find_file(split_path, 'sample') + + rules = load_json_or_jsonl(rule_path) if rule_path else [] + samples = load_json_or_jsonl(sample_path) if sample_path else [] + + return {'rules': rules, 'samples': samples} + + +def process_mixed_data(base_path, mode): + """Load and process data for the 'mixed' split and specific mode.""" + mixed_path = os.path.join(base_path, 'mixed') + file_path = find_file(mixed_path, mode) + if not file_path: + print(f'[WARNING] Missing file for mixed mode: {mode}') + return [] + + data = load_json_or_jsonl(file_path) + template_path = os.path.join(base_path, 'config/prompt/mixed.yaml') + template = load_yaml(template_path) + + processed = [] + for item in data: + rules = '\n'.join(item.get('rule_list', [])) + questions = '\n'.join(item.get('question_list', [])) + item['prompt'] = template['prompt_format'][0].format(rules, questions) + processed.append(item) + + return processed + + +class ConfigWrapper: + + def __init__(self, config_path): + self._config = {} + with open(config_path, 'r') as file: + self._config = yaml.safe_load(file) + for key, value in self._config.items(): + setattr(self, key, value) + + def __setattr__(self, key, value): + if key.startswith('_'): + 
super().__setattr__(key, value) + else: + self._config[key] = value + super().__setattr__(key, value) + + def __getattr__(self, key): + if key in self._config: + return self._config[key] + raise AttributeError( + f"'ConfigWrapper' object has no attribute '{key}'") + + def get_id(self, data): + if isinstance(self._config.get('id_key'), str): + return data.get(self._config.get('id_key'), None) + elif isinstance(self._config.get('id_key'), list): + return '_'.join([ + str(data[key]) for key in self._config.get('id_key') + if key in data + ]) + + def print_all_keys(self): + print('config keys:') + for key, value in self._config.items(): + print(f' - {key}: {value}') + + +config_wrapper = None + + +def initialize_config(config_path): + global config_wrapper + config_wrapper = ConfigWrapper(config_path) + + +def get_config_wrapper(): + global config_wrapper + if config_wrapper is None: + raise RuntimeError( + 'ConfigWrapper not initialized. Call initialize_config first.') + return config_wrapper + + +if __name__ == '__main__': + config_path = 'config/config.yaml' + initialize_config(config_path) + data = { + 'idx': + '50', + 'step': + 21, + 'question': + ('Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\n' + 'Please provide the decrypted answer, encapsulated in double ' + 'square brackets. ' + 'For example, the format should be: [[decrypted answer]].'), + 'answer': + '[[P]]', + 'category': + 'Decryption', + 'rule_id': + '23', + 'input': + 'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"', + 'steps_num': + 23, + 'description': + ('For a number c=228 in the ciphertext:\n' + 'Calculate z = c^e mod n. Here ^ means multiplication.\n' + 'z is 80.\nBased on the decimal number represented by z, ' + 'use the ascii code to find the corresponding letter ' + 'as the plaintext letter p.\n' + 'Please give the letter p in [[...]] format.\n'), + 'atom': + 80 + } + print(config_wrapper.get_id(data)) + + +def read_yaml(config='default'): + if os.path.exists(f'config/prompt/{config}.yaml'): + yaml_file = f'config/prompt/{config}.yaml' + else: + yaml_file = config + with open(yaml_file, 'r') as yaml_file: + return yaml.safe_load(yaml_file) + + +def write_jsonl_lines(file, data): + config_wrapper = get_config_wrapper() + if config_wrapper.save_prompt: + json.dump(data, file, ensure_ascii=False) + else: + data.pop(config_wrapper.prompt_key) + json.dump(data, file, ensure_ascii=False) + file.write('\n') + file.flush() + + +def print_info(info): + print('-' * 100) + print('[INFO] model_name:', info['model_name']) + print('[INFO] splits:', info['splits']) + print('[INFO] modes:', info['modes']) + print('[INFO] output_dir:', info['output_dir']) + print('[INFO] Infer Limit:', + 'No limit' if info['infer_limit'] is None else info['infer_limit']) + print('[INFO] Number of Workers:', info['num_workers']) + print('[INFO] Batch Size:', info['batch_size']) + print('[INFO] Use Accel:', info['use_accel']) + print('-' * 100) + + +def read_json_or_jsonl(data_path, split='', mapping_key=None): + base_path = os.path.join(data_path, split) + if os.path.exists(f'{base_path}.json'): + file_path = f'{base_path}.json' + elif os.path.exists(f'{base_path}.jsonl'): + file_path = f'{base_path}.jsonl' + elif base_path.endswith('.json') or base_path.endswith('.jsonl'): + file_path = base_path + else: + raise FileNotFoundError('No JSON or JSONL file found.') + + with open(file_path, 'r') as file: + if file_path.endswith('.json'): + data = json.load(file) + elif file_path.endswith('.jsonl'): + data = [json.loads(line) for line in file] + + 
if mapping_key: + return { + item[mapping_key]: item + for item in data if mapping_key in item + } + else: + return data + + +def read_json_or_jsonl_with_idx(data_path, split='', idx=None): + base_path = os.path.join(data_path, split) + if os.path.exists(f'{base_path}.json'): + file_path = f'{base_path}.json' + elif os.path.exists(f'{base_path}.jsonl'): + file_path = f'{base_path}.jsonl' + elif base_path.endswith('.json') or base_path.endswith('.jsonl'): + file_path = base_path + else: + raise FileNotFoundError('No JSON or JSONL file found.') + + with open(file_path, 'r', encoding='utf-8') as file: + if file_path.endswith('.json'): + data = json.load(file) + elif file_path.endswith('.jsonl'): + data = [json.loads(line) for line in file] + + if idx is not None: + try: + return next(item for item in data if item.get('idx') == idx) + except StopIteration: + raise ValueError(f'No entry found for idx {idx}') + else: + return data + + +idx_ranges = [ + [18], + [73, 74, 77], + [94], + [115, 116, 117], + [121, 122, 123, 125], + [131, 132, 134, 135, 136], + [141, 143, 149], + list(range(145, 148)), + list(range(151, 157)), + [160, 161, 162], + [164, 165, 166], + [170], + [206, 209], + list(range(211, 216)), + [217, 218], +] + + +def clean_json_string(json_str): + json_str = re.sub(r'[\x00-\x1F\x7F]', '', json_str) + return json_str + + +def is_in_idx_ranges(idx, idx_ranges): + for range_list in idx_ranges: + if int(idx) in range_list: + return True + return False + + +def extract_json(text): + matches = re.findall(r'{.*}', text, re.DOTALL) + if matches: + json_str = matches[-1] + json_str = clean_json_string(json_str) + try: + data = json.loads(json_str) + return data + except json.JSONDecodeError as e: + print(f'Error decoding JSON: {e}') + return 'NULL' + return 'NULL' + + +def extract_all_responses_from_json(response_json): + results = [] + for key, value in response_json.items(): + results.append(str(value)) + return results + + +def clean_latex(latex_expr): + if '=' in latex_expr: + latex_expr = latex_expr.rsplit('=', 1)[1] + latex_expr = re.sub(r'\\[()\[\]]', '', latex_expr) + latex_expr = re.sub(r'\\text\{.*?\}', '', latex_expr) + latex_expr = re.sub(r'\\(left|right|displaystyle)', '', latex_expr) + latex_expr = latex_expr.replace('\\\\', '\\') + return latex_expr + + +def extract_text_from_brackets(text, clean_level='basic'): + matches = re.findall(r'\[\[\s*(.*?)\s*\]\]', text, re.DOTALL) + if not matches: + matches = re.findall(r'\$\\boxed\{(.*?)\}\$', text, re.DOTALL) + if not matches: + matches = re.findall(r'\[\s*(.*?)\s*\]', text, re.DOTALL) + if matches: + match_str = matches[0].strip() + if clean_level == 'clean': + match_str = match_str.replace('"', '').replace('\n', '').replace( + ' ', '').replace('[', '').replace(']', '') + elif clean_level == 'logic': + match_str = match_str.replace('"', '').replace('\n', '').replace( + ' ', '').replace('.', '') + elif clean_level == 'math': + match_str = match_str.replace('"', '').replace('\n', '').replace( + '[', '').replace(']', '').replace('$', '') + return f'{clean_latex(match_str)}' + return f'[[{match_str}]]' + return 'NULL' + + +def extract_inner_text_from_brackets(text): + if not isinstance(text, str): + print(f'text type: {type(text)}, text value: {text}') + return 'NULL' + match = re.search(r'\[\[(.*?)\]\]', text, re.DOTALL) + return match.group(1) if match else 'NULL' + + +def extract_numbers(str): + numbers = re.findall(r'\d+', str) + numbers = list(map(int, numbers)) + return numbers + + +def 
extract_and_sort_inequalities(latex_expr): + pattern = r'(≥|≤)\s*([-]?\d+\.?\d*)' + matches = re.findall(pattern, latex_expr) + extracted_inequalities = [''.join(match) for match in matches] + sorted_inequalities = sorted(extracted_inequalities) + return sorted_inequalities + + +def rule5_normalize_content(content): + parts = [part for part in content.split(';')] + sorted_parts = sorted(parts) + return sorted_parts + + +def normalize_string(s): + s = re.sub(r'[^0-9]', '', s) + pairs = s.split(',') + pairs.sort() + return pairs + + +def remove_commas_and_spaces(s): + return re.sub(r'[,\s\[\]]+', '', s) + + +def remove_non_alphanumeric(s): + return re.sub(r'\W+', '', s) + + +def contains_or(answer): + return 'or' in answer + + +def compare_multi_results(response, answer): + try: + response_text = extract_text_from_brackets(response, 'clean') + response_text = re.sub(r'\\text\{or\}', 'or', response_text) + if response_text == 'NULL': + return False + answer = extract_text_from_brackets(answer, 'clean') + response_split = response_text.strip('[[]]').split('or') + answer_split = answer.strip('[[]]').split('or') + response_sorted = sorted([x.strip() for x in response_split]) + answer_sorted = sorted([x.strip() for x in answer_split]) + return response_sorted == answer_sorted + except Exception as e: + print(f'Error during comparison: {e}') + return False + + +def split_or_expression(expression): + return [part.strip() for part in expression.split('or')] + + +def compare_math_expressions(response, answer): + response_text = extract_text_from_brackets(response, 'math') + answer_text = extract_text_from_brackets(answer, 'math') + if response_text == 'NULL': + return False + if contains_or(answer_text): + response_parts = split_or_expression(response_text) + answer_parts = split_or_expression(answer_text) + try: + response_exprs = { + sp.simplify(parse_latex(part)) + for part in response_parts + } + answer_exprs = { + sp.simplify(parse_latex(part)) + for part in answer_parts + } + return response_exprs == answer_exprs + except Exception as e: + print(f'Error during simplification or parsing: {e}') + return response_text == answer_text + else: + try: + response_expr = sp.simplify(parse_latex(response_text)) + answer_expr = sp.simplify(parse_latex(answer_text)) + return response_expr == answer_expr + except Exception as e: + print(f'Error during simplification or parsing: {e}') + return response_text == answer_text + + +def method_equal(response_text, answer): + return response_text == answer + + +def method_1(response_text, answer): + cleaned_string = re.sub(r'[^A-Za-z]', '', response_text) + cleaned_string = cleaned_string.lower() + answer = re.sub(r'[^A-Za-z]', '', answer) + answer = answer.lower() + return cleaned_string == answer + + +def method_2(response_text, answer): + cleaned_string = re.sub(r'[^A-Za-z]', '', response_text) + cleaned_string = cleaned_string.lower() + answer = answer.split(',') + return cleaned_string in answer + + +def method_3(response_text, answer): + response_text = response_text.lower() + pairs1 = re.split(r'\W+', response_text) + pairs2 = answer.split(' ') + pairs1 = [word for word in pairs1 if word] + pairs1.sort() + pairs2.sort() + return pairs1 == pairs2 + + +def method_4(response_text, answer): + cleaned_string = re.sub(r'[^A-Za-z]', '', response_text) + cleaned_string = cleaned_string.lower() + return cleaned_string in answer + + +def method_5(response_text, answer): + response_text = re.sub(r'\s+', '', response_text) + response_text = response_text.split(',') + 
answer = answer.split(',') + response_text.sort() + answer.sort() + return response_text == answer + + +def method_9(response_text, answer): + response_text = response_text.replace('×', '*').replace('−', '-') + answer = answer.replace('×', '*').replace('−', '-') + + def extract_operators(s): + return re.findall(r'[+\-*/]', s) + + response_ops = extract_operators(response_text.split('=')[0]) + answer_ops = extract_operators(answer.split('=')[0]) + if response_ops != answer_ops: + return False + match = re.search(r'=\s*(-?\d+)', answer) + expected_result = int(match.group(1)) + try: + left_side = response_text.split('=')[0] + result = eval(left_side) + except Exception as e: + print(f'Error during evaluation: {e}') + return False + return result == expected_result + + +def method_10(response_text, answer): + response_text = response_text.replace('×', '*').replace('−', '-') + response_text = response_text.split('=')[0] + answer = answer.split('\n')[0].split('=')[0] + response_ops = sorted(remove_non_alphanumeric(response_text)) + answer_ops = sorted(remove_non_alphanumeric(answer)) + if response_ops != answer_ops: + return False + try: + result = eval(response_text) + except Exception as e: + print(f'Error during evaluation: {e}') + return False + return result == 24 + + +def method_18(response_text, answer): + cleaned_s1 = remove_commas_and_spaces(response_text) + cleaned_s2 = remove_commas_and_spaces(answer) + return cleaned_s1 == cleaned_s2 + + +def method_general(response_text, answer): + cleaned_s1 = remove_non_alphanumeric(response_text) + cleaned_s2 = remove_non_alphanumeric(answer) + return cleaned_s1 == cleaned_s2 + + +question_methods = { + '1': method_1, + '2': method_2, + '3': method_3, + '4': method_4, + '5': method_5, + '9': method_9, + '10': method_10, + '18': method_18, +} + + +def evaluate_response_vs_answer(response, answer, question_type, rule_id, idx): + if question_type == 'logic' and rule_id == '5': + response_text = extract_text_from_brackets(response, 'logic') + answer_text = extract_text_from_brackets(answer, 'logic') + if response_text is None: + return False + normalized_response = rule5_normalize_content(response_text) + normalized_answer = rule5_normalize_content(answer) + return normalized_response == normalized_answer + elif question_type == 'logic': + response_text = extract_text_from_brackets(response, 'logic') + answer_text = extract_text_from_brackets(answer, 'logic') + return response_text == answer_text + elif question_type == 'operation' and (idx == '178' or idx == '179'): + response_text = extract_text_from_brackets(response, 'clean') + response_text = extract_and_sort_inequalities(response_text) + answer_text = extract_and_sort_inequalities(answer) + # print(response_text, answer_text) + return response_text == answer_text + elif question_type == 'operation' and rule_id == '18': + response_text = extract_text_from_brackets(response, 'clean') + answer = extract_inner_text_from_brackets(answer) + response_text = ''.join(sorted(re.sub(r'\W+', '', response_text))) + answer = ''.join(sorted(re.sub(r'\W+', '', answer))) + return response_text == answer + elif question_type == 'operation' and rule_id in {'23', '24', '25'}: + response_text = extract_text_from_brackets(response, 'clean') + if response_text is None: + return False + response_text = extract_numbers(response_text) + answer_text = extract_numbers(answer) + return response_text == answer_text + elif question_type == 'operation' and is_in_idx_ranges(idx, idx_ranges): + return 
compare_math_expressions(response, answer) + elif question_type == 'operation' and contains_or(answer): + return compare_multi_results(response, answer) + elif question_type == 'puzzle': + response_text = extract_inner_text_from_brackets(response) + answer = extract_inner_text_from_brackets(answer) + method = question_methods.get(rule_id) + if method: + return method(response_text, answer) + return method_general(response_text, answer) + else: + response_text = extract_text_from_brackets(response, 'clean') + return response_text == answer + + +def compute_one_mixed_question_pass_rate(idx, + question_list, + response_json, + base_path=None): + if response_json == 'NULL': + result_dict = { + 'idx': idx, + 'response': response_json, + 'details': None, + 'pass_rate': 0, + 'is_correct': False + } + return result_dict + response_list = extract_all_responses_from_json(response_json) + correct_num = 0 + results = [] + for q_idx, question in enumerate(question_list): + category, question_idx = question.rsplit('_', 1) + question_content = load_json_or_jsonl_with_idx(base_path, + os.path.join( + category, 'sample'), + idx=question_idx) + answer = question_content['answer'] + if q_idx >= len(response_list): + break + response = response_list[q_idx] + response_text = extract_text_from_brackets(response) + rule_id = question_content['rule_id'] + is_correct = evaluate_response_vs_answer(response, answer, category, + rule_id, q_idx) + if is_correct: + correct_num += 1 + results.append({ + 'question': question, + 'response_text': response_text, + 'answer': answer, + 'is_correct': is_correct + }) + + pass_rate = correct_num / len(question_list) + question_correct = pass_rate == 1.0 + result_dict = { + 'idx': idx, + 'response': response_json, + 'details': results, + 'pass_rate': pass_rate, + 'is_correct': question_correct + } + return result_dict + + +def evaluate_responses(data, mode, base_path=None): + results = [] + + # Iterate over the values of the dictionary (numerical keys) + for key, record in data.items(): + idx = key # Use the dictionary key as the "idx" + response = record.get('prediction', '') + question_type = record.get('category', '') + if mode == 'mixed': + question_list = record.get('question_list') + response_json = extract_json(response) + result_dict = compute_one_mixed_question_pass_rate( + idx, question_list, response_json, base_path) + results.append(result_dict) + else: + response_text = extract_text_from_brackets(response) + answer = record.get('gold', '') + rule_id = record.get('rule_id', '') + is_correct = evaluate_response_vs_answer(response, answer, + question_type, rule_id, + idx) + result_dict = { + 'idx': idx, + 'response': response, + 'response_text': response_text, + 'answer': answer, + 'is_correct': is_correct + } + if question_type == 'counterfactual': + real_life_answer = record.get('real_life_answer', '') + is_real_life = evaluate_response_vs_answer( + response, real_life_answer, question_type, rule_id, idx) + result_dict['real_life_answer'] = real_life_answer + result_dict['is_real_life'] = is_real_life + if question_type == 'cipher' and mode == 'subquestions': + result_dict['type'] = record.get('type', '') + results.append(result_dict) + return results diff --git a/opencompass/openicl/icl_evaluator/icl_korbench_evaluator.py b/opencompass/openicl/icl_evaluator/icl_korbench_evaluator.py new file mode 100644 index 00000000..f51ca40f --- /dev/null +++ b/opencompass/openicl/icl_evaluator/icl_korbench_evaluator.py @@ -0,0 +1,267 @@ +# flake8: noqa +"""KOR-Bench Evaluator.""" 
+ +import json +import os +import re + +from .icl_base_evaluator import BaseEvaluator + + +def read_json_or_jsonl(data_path, split='', mapping_key=None): + base_path = os.path.join(data_path, split) + if os.path.exists(f'{base_path}.json'): + file_path = f'{base_path}.json' + elif os.path.exists(f'{base_path}.jsonl'): + file_path = f'{base_path}.jsonl' + elif base_path.endswith('.json') or base_path.endswith('.jsonl'): + file_path = base_path + else: + raise FileNotFoundError('No JSON or JSONL file found.') + + with open(file_path, 'r') as file: + if file_path.endswith('.json'): + data = json.load(file) + elif file_path.endswith('.jsonl'): + data = [json.loads(line) for line in file] + + if mapping_key: + return { + item[mapping_key]: item + for item in data if mapping_key in item + } + else: + return data + + +def read_json_or_jsonl_with_idx(data_path, split='', idx=None): + base_path = os.path.join(data_path, split) + if os.path.exists(f'{base_path}.json'): + file_path = f'{base_path}.json' + elif os.path.exists(f'{base_path}.jsonl'): + file_path = f'{base_path}.jsonl' + elif base_path.endswith('.json') or base_path.endswith('.jsonl'): + file_path = base_path + else: + raise FileNotFoundError('No JSON or JSONL file found.') + + with open(file_path, 'r', encoding='utf-8') as file: + if file_path.endswith('.json'): + data = json.load(file) + elif file_path.endswith('.jsonl'): + data = [json.loads(line) for line in file] + + if idx is not None: + try: + return next(item for item in data if item.get('idx') == idx) + except StopIteration: + raise ValueError(f'No entry found for idx {idx}') + else: + return data + + +class korbenchEvaluator(BaseEvaluator): + """Evaluator class for KOR-Bench tasks, inheriting from BaseEvaluator. + + This class implements the `score` method to evaluate the model's + predictions against the reference answers, using the evaluation logic + specific to KOR-Bench. + """ + + def __init__(self, question_type, mode): + """Initialize the evaluator with question type and mode. + + Args: + question_type (str): Type of questions (e.g., 'logic', 'operation', 'puzzle'). + mode (str): Evaluation mode (e.g., 'zero-shot', 'self-correction'). + """ + super().__init__() + self.question_type = question_type + self.mode = mode + + # Predefined index ranges for special evaluation cases + self.idx_ranges = [ + [18], + [73, 74, 77], + [94], + [115, 116, 117], + [121, 122, 123, 125], + [131, 132, 134, 135, 136], + [141, 143, 149], + list(range(145, 148)), + list(range(151, 157)), + [160, 161, 162], + [164, 165, 166], + [170], + [206, 209], + list(range(211, 216)), + [217, 218], + ] + + def score(self, predictions, references): + """Evaluates the model's predictions against the references. + + Args: + predictions (list): List of model predictions. + references (list): List of reference answers (each reference is a dict). + + Returns: + list: Evaluation results for each prediction. 
+ """ + if len(predictions) != len(references): + return { + 'error': 'Predictions and references have different lengths' + } + + data = [] + for idx, (prediction, + reference) in enumerate(zip(predictions, references)): + record = { + 'idx': str(idx), + 'response': prediction, + 'answer': reference.get('answer'), + 'rule_id': reference.get('rule_id'), + 'question_type': self.question_type, + # Include other necessary fields from reference if needed + } + data.append(record) + + results = self.evaluate_responses(data, self.question_type, self.mode) + return results + + def evaluate_responses(self, data, question_type, mode): + """Evaluates a list of responses. + + Args: + data (list): List of records containing responses and answers. + question_type (str): Type of questions. + mode (str): Evaluation mode. + + Returns: + list: List of evaluation results. + """ + results = [] + for record in data: + idx = record.get('idx') + response = record.get('response') + answer = record.get('answer') + rule_id = record.get('rule_id') + + response_text = self.extract_text_from_brackets(response) + is_correct = self.evaluate_response_vs_answer( + response, answer, question_type, rule_id, idx) + + result_dict = { + 'idx': idx, + 'response': response, + 'response_text': response_text, + 'answer': answer, + 'is_correct': is_correct + } + results.append(result_dict) + return results + + # Helper methods + + def extract_text_from_brackets(self, text, clean_level='basic'): + """Extracts text enclosed in double brackets [[ ]]. + + Args: + text (str): The text to extract from. + clean_level (str): The level of cleaning to perform. + + Returns: + str: The extracted text or "NULL" if not found. + """ + matches = re.findall(r'\[\[\s*(.*?)\s*\]\]', text, re.DOTALL) + if not matches: + matches = re.findall(r'\$\\boxed\{(.*?)\}\$', text, re.DOTALL) + if not matches: + matches = re.findall(r'\[\s*(.*?)\s*\]', text, re.DOTALL) + if matches: + match_str = matches[0].strip() + if clean_level == 'clean': + match_str = match_str.replace('"', '').replace( + '\n', '').replace(' ', '').replace('[', + '').replace(']', '') + elif clean_level == 'logic': + match_str = match_str.replace('"', + '').replace('\n', '').replace( + ' ', '').replace('.', '') + elif clean_level == 'math': + match_str = match_str.replace('"', '').replace( + '\n', '').replace('[', '').replace(']', + '').replace('$', '') + return f'{self.clean_latex(match_str)}' + return f'[[{match_str}]]' + return 'NULL' + + def clean_latex(self, latex_expr): + """Cleans LaTeX expressions for parsing. + + Args: + latex_expr (str): The LaTeX expression to clean. + + Returns: + str: The cleaned expression. + """ + if '=' in latex_expr: + latex_expr = latex_expr.rsplit('=', 1)[1] + latex_expr = re.sub(r'\\[()\[\]]', '', latex_expr) + latex_expr = re.sub(r'\\text\{.*?\}', '', latex_expr) + latex_expr = re.sub(r'\\(left|right|displaystyle)', '', latex_expr) + latex_expr = latex_expr.replace('\\\\', '\\') + return latex_expr + + def evaluate_response_vs_answer(self, response, answer, question_type, + rule_id, idx): + """Evaluates a single response against the answer. + + Args: + response (str): The model's response. + answer (str): The reference answer. + question_type (str): The question type. + rule_id (str): The rule ID. + idx (str): The index of the question. + + Returns: + bool: True if the response is correct, False otherwise. 
+ """ + if question_type == 'logic' and rule_id == '5': + response_text = self.extract_text_from_brackets(response, 'logic') + answer_text = self.extract_text_from_brackets(answer, 'logic') + if response_text is None: + return False + normalized_response = self.rule5_normalize_content(response_text) + normalized_answer = self.rule5_normalize_content(answer) + return normalized_response == normalized_answer + elif question_type == 'logic': + response_text = self.extract_text_from_brackets(response, 'logic') + answer_text = self.extract_text_from_brackets(answer, 'logic') + return response_text == answer_text + else: + response_text = self.extract_text_from_brackets(response, 'clean') + return response_text == answer + + def rule5_normalize_content(self, content): + """Normalizes content for rule 5. + + Args: + content (str): The content to normalize. + + Returns: + list: Sorted list of content parts. + """ + parts = [part.strip() for part in content.split(';')] + sorted_parts = sorted(parts) + return sorted_parts + + # Additional helper methods can be defined here + # For example: methods to handle mathematical expressions, logic comparisons, etc. + + # Implement other helper functions as per your evaluation logic + + +# Example usage: +# evaluator = korbenchEvaluator(question_type='logic', mode='zero-shot') +# results = evaluator.score(predictions, references) diff --git a/opencompass/utils/datasets_info.py b/opencompass/utils/datasets_info.py index 7d694ff1..8fe89971 100644 --- a/opencompass/utils/datasets_info.py +++ b/opencompass/utils/datasets_info.py @@ -151,6 +151,12 @@ DATASETS_MAPPING = { "hf_id": "opencompass/humaneval", "local": "./data/humaneval_cn/human-eval-cn-v2-20210705.jsonl", }, + #KORBENCH + "opencompass/korbench": { + "ms_id": "", + "hf_id": "", + "local": "./data/korbench", + }, # Lambada "opencompass/lambada": { "ms_id": "opencompass/lambada", @@ -544,4 +550,8 @@ DATASETS_URL = { "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/babilong.zip", "md5": "e400864c31bc58d29eaa3e199751f99b", }, + "/korbench": { + "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/korbench.zip", + "md5": "9107597d137e7362eaf7d218ddef7a6d", + }, }