diff --git a/opencompass/configs/datasets/humaneval_pro/README.md b/opencompass/configs/datasets/humaneval_pro/README.md new file mode 100644 index 00000000..853b59f2 --- /dev/null +++ b/opencompass/configs/datasets/humaneval_pro/README.md @@ -0,0 +1,17 @@ +# HumanEval pro + +## OC results + +| model | pass@1 | +|:--------------------------:|---------:| +|qwen2.5-coder-7b-instruct-hf| 65 | +| qwen2.5-14b-instruct-hf | 67 | +| deepseek-v2-lite-chat-hf | 35 | + +## CodeEval-pro results + +| model | pass@1 | +|:--------------------------:|---------:| +|qwen2.5-coder-7b-instruct-hf| 65 | +| qwen2.5-14b-instruct-hf | 65 | +| deepseek-v2-lite-chat-hf | 28 | \ No newline at end of file diff --git a/opencompass/configs/datasets/humaneval_pro/humaneval_pro.py b/opencompass/configs/datasets/humaneval_pro/humaneval_pro.py new file mode 100644 index 00000000..34c73f48 --- /dev/null +++ b/opencompass/configs/datasets/humaneval_pro/humaneval_pro.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .humaneval_pro_gen_ import humanevalpro_datasets # noqa: F401, F403 diff --git a/opencompass/configs/datasets/humaneval_pro/humaneval_pro_gen_3dc067.py b/opencompass/configs/datasets/humaneval_pro/humaneval_pro_gen_3dc067.py new file mode 100644 index 00000000..606cd8b1 --- /dev/null +++ b/opencompass/configs/datasets/humaneval_pro/humaneval_pro_gen_3dc067.py @@ -0,0 +1,56 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import HumanevalevalProDataset, HumanevalProEvaluator, humaneval_postprocess_v2 + +OFFICIAL_PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions. +@@ Instruction +Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution. +```python +{raw_problem} +{new_problem} +``` +@@ Response +Please put the two solutions to the above problems in one Python code block. +""" + +PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions. +Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution. +```python +{raw_problem} +{new_problem} +``` +Please put the two solutions within the Python code block provided below, and make sure that the block contains no other unrelated content: +```python +``` +""" + + +humanevalpro_reader_cfg = dict( + input_columns=['raw_problem', 'new_problem'], output_column='test_code') + +humanevalpro_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict( + role='HUMAN', + prompt=PROMPT_WRAPPER), + ])), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer)) + +humanevalpro_eval_cfg = dict( + evaluator=dict(type=HumanevalProEvaluator, + ip_address='https://opencompass-multiple-evaluator.hf.space') +) + +humanevalpro_datasets = [ + dict( + abbr='humaneval_pro', + type=HumanevalevalProDataset, + path='opencompass/humaneval_pro', + reader_cfg=humanevalpro_reader_cfg, + infer_cfg=humanevalpro_infer_cfg, + eval_cfg=humanevalpro_eval_cfg,) +] \ No newline at end of file diff --git a/opencompass/configs/datasets/mbpp_pro/README.md b/opencompass/configs/datasets/mbpp_pro/README.md new file mode 100644 index 00000000..d34980e1 --- /dev/null +++ b/opencompass/configs/datasets/mbpp_pro/README.md @@ -0,0 +1,17 @@ +# MBPP pro + +## OC results + +| model | pass@1 | +|:--------------------------:|---------:| +|qwen2.5-coder-7b-instruct-hf| 66 | +| qwen2.5-14b-instruct-hf | 64 | +| deepseek-v2-lite-chat-hf | 36 | + +## CodeEval-pro results + +| model | pass@1 | +|:--------------------------:|---------:| +|qwen2.5-coder-7b-instruct-hf| 65 | +| qwen2.5-14b-instruct-hf | 65 | +| deepseek-v2-lite-chat-hf | 39 | \ No newline at end of file diff --git a/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen.py b/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen.py new file mode 100644 index 00000000..bded1658 --- /dev/null +++ b/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .mbpp_pro_gen_ import mbpppro_datasets # noqa: F401, F403 diff --git a/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen_3dc067.py b/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen_3dc067.py new file mode 100644 index 00000000..0c8a882d --- /dev/null +++ b/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen_3dc067.py @@ -0,0 +1,56 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import MBPPProDataset, MBPPProEvaluator + +OFFICIAL_PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions. +@@ Instruction +Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution. +```python +{raw_problem} +{new_problem} +``` +@@ Response +Please put the two solutions to the above problems in one Python code block. +""" + +PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions. +Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution. +```python +{raw_problem} +{new_problem} +``` +Please put the two solutions within the Python code block provided below, and make sure that the block contains no other unrelated content: +```python +``` +""" + + +mbpppro_reader_cfg = dict( + input_columns=['raw_problem', 'new_problem'], output_column='test_code') + +mbpppro_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict( + role='HUMAN', + prompt=PROMPT_WRAPPER), + ])), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer)) + +mbpppro_eval_cfg = dict( + evaluator=dict(type=MBPPProEvaluator, + ip_address='https://opencompass-multiple-evaluator.hf.space'), +) + +mbpppro_datasets = [ + dict( + abbr='mbpp_pro', + type=MBPPProDataset, + path='opencompass/mbpp_pro', + reader_cfg=mbpppro_reader_cfg, + infer_cfg=mbpppro_infer_cfg, + eval_cfg=mbpppro_eval_cfg) +] \ No newline at end of file diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index 220ce030..e3c8e59e 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -63,6 +63,7 @@ from .hle import * # noqa: F401, F403 from .huggingface import * # noqa: F401, F403 from .humaneval import * # noqa: F401, F403 from .humaneval_multi import * # noqa: F401, F403 +from .humaneval_pro import * # noqa: F401, F403 from .humanevalx import * # noqa: F401, F403 from .hungarian_math import * # noqa: F401, F403 from .IFEval.ifeval import IFEvalDataset, IFEvaluator # noqa: F401, F403 @@ -95,6 +96,7 @@ from .math401 import * # noqa: F401, F403 from .math_intern import * # noqa: F401, F403 from .mathbench import * # noqa: F401, F403 from .mbpp import * # noqa: F401, F403 +from .mbpp_pro import * # noqa: F401, F403 from .medbench import * # noqa: F401, F403 from .MedXpertQA import * # noqa: F401, F403 from .mgsm import * # noqa: F401, F403 diff --git a/opencompass/datasets/humaneval_pro.py b/opencompass/datasets/humaneval_pro.py new file mode 100644 index 00000000..310b3b41 --- /dev/null +++ b/opencompass/datasets/humaneval_pro.py @@ -0,0 +1,96 @@ +import json +from typing import Dict, List + +import numpy as np +from datasets import Dataset + +from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator +from opencompass.utils import get_data_path + +from .base import BaseDataset + + +class HumanevalevalProDataset(BaseDataset): + + @staticmethod + def load(path, num_repeats=1, local_mode=False): + path = get_data_path(path, local_mode=local_mode) + dataset = [] + with open(path, encoding='utf-8') as f: + raw_data = json.load(f) + for data in raw_data: + dataset.extend([data for _ in range(num_repeats)]) + return Dataset.from_list(dataset) + + +class HumanevalProEvaluator(CodeEvaluator): + + def _process_completions(self, test_case: dict, completions: list) -> list: + processed_completions = [] + for comp in completions: + post_comp = self._extract_code(comp) + processed_completions.append(post_comp) + return processed_completions + + def score(self, predictions: List, references: List, + test_set: Dataset) -> Dict: + if len(predictions) != len(references): + return { + 'error': + 'predictions and references have different ' + f'length. len(predictions): {len(predictions)}, ' + f'len(references): {len(references)}' + } + + test_set = test_set.to_pandas() + # Use the first column as the unique identifier + test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0]) + num_repeats = int(len(test_set) / len(test_set_origin)) + + # 1. Prepare data for all test cases + all_test_cases = [] + for i in range(len(test_set_origin)): + test_case = test_set_origin.iloc[i] + completions = predictions[i * num_repeats:(i + 1) * num_repeats] + + # Process code completions + processed_completions = self._process_completions( + test_case, completions) + + sub_data_dict = { + 'name': int(test_case['id']), + 'language': self.language, + 'prompt': '', + 'tests': test_case['test_code'], + 'processed_completions': processed_completions, + 'completions': completions + } + + all_test_cases.append(sub_data_dict) + + # 2. Send all test cases to the evaluation service + success, outputs, error_message = self._evaluate(all_test_cases) + if not success: + return {'error': error_message} + + # 3. Process the returned results + details = [] + total, correct = [], [] + for output in outputs: + passed = [m['status'] == 'OK' for m in output['meta_data']] + total.append(len(passed)) + correct.append(sum(passed)) + details.append(output) + total = np.array(total) + correct = np.array(correct) + + pass_at_k = { + f'pass@{k}': + self.estimate_pass_at_k(total, correct, k).mean() * 100 + for k in self.k if (total >= k).all() + } + + return { + **pass_at_k, + 'details': details, + } diff --git a/opencompass/datasets/mbpp_pro.py b/opencompass/datasets/mbpp_pro.py new file mode 100644 index 00000000..51a086d7 --- /dev/null +++ b/opencompass/datasets/mbpp_pro.py @@ -0,0 +1,97 @@ +import json +from typing import Dict, List + +import numpy as np +from datasets import Dataset + +from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator +from opencompass.utils import get_data_path + +from .base import BaseDataset + + +class MBPPProDataset(BaseDataset): + + @staticmethod + def load(path, num_repeats=1, local_mode=False): + path = get_data_path(path, local_mode=local_mode) + print(path) + dataset = [] + with open(path, encoding='utf-8') as f: + for line in f: + dataset.extend( + [json.loads(line.strip()) for _ in range(num_repeats)]) + return Dataset.from_list(dataset) + + +class MBPPProEvaluator(CodeEvaluator): + + def _process_completions(self, test_case: dict, completions: list) -> list: + processed_completions = [] + for comp in completions: + post_comp = self._extract_code(comp) + processed_completions.append(post_comp) + return processed_completions + + def score(self, predictions: List, references: List, + test_set: Dataset) -> Dict: + if len(predictions) != len(references): + return { + 'error': + 'predictions and references have different ' + f'length. len(predictions): {len(predictions)}, ' + f'len(references): {len(references)}' + } + + test_set = test_set.to_pandas() + # Use the first column as the unique identifier + test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0]) + num_repeats = int(len(test_set) / len(test_set_origin)) + + # 1. Prepare data for all test cases + all_test_cases = [] + for i in range(len(test_set_origin)): + test_case = test_set_origin.iloc[i] + completions = predictions[i * num_repeats:(i + 1) * num_repeats] + + # Process code completions + processed_completions = self._process_completions( + test_case, completions) + + sub_data_dict = { + 'name': int(test_case['id']), + 'language': self.language, + 'prompt': '', + 'tests': test_case['test_code'], + 'processed_completions': processed_completions, + 'completions': completions + } + + all_test_cases.append(sub_data_dict) + + # 2. Send all test cases to the evaluation service + success, outputs, error_message = self._evaluate(all_test_cases) + if not success: + return {'error': error_message} + + # 3. Process the returned results + details = [] + total, correct = [], [] + for output in outputs: + passed = [m['status'] == 'OK' for m in output['meta_data']] + total.append(len(passed)) + correct.append(sum(passed)) + details.append(output) + total = np.array(total) + correct = np.array(correct) + + pass_at_k = { + f'pass@{k}': + self.estimate_pass_at_k(total, correct, k).mean() * 100 + for k in self.k if (total >= k).all() + } + + return { + **pass_at_k, + 'details': details, + } diff --git a/opencompass/openicl/icl_evaluator/code_evaluator.py b/opencompass/openicl/icl_evaluator/code_evaluator.py index d586cd6e..b79488cb 100644 --- a/opencompass/openicl/icl_evaluator/code_evaluator.py +++ b/opencompass/openicl/icl_evaluator/code_evaluator.py @@ -1,12 +1,14 @@ # flake8: noqa: E501 import difflib +import itertools import os import re import tempfile import time from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np from datasets import Dataset from gradio_client import Client @@ -24,7 +26,7 @@ class CodeEvaluator(BaseEvaluator): """ def __init__(self, - language: str, + language: str = 'py', ip_address: str = 'localhost', retry: int = 3) -> None: """Initialize the CodeEvaluator. @@ -221,19 +223,18 @@ class CodeEvaluator(BaseEvaluator): test_set = test_set.to_pandas() # Use the first column as the unique identifier test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0]) - num_repeats = int(len(test_set) / len(test_set_origin)) # 1. Prepare data for all test cases all_test_cases = [] for i in range(len(test_set_origin)): test_case = test_set_origin.iloc[i] - completions = predictions[i * num_repeats:(i + 1) * num_repeats] + completions = predictions[i] # Process code completions processed_completions = self._process_completions( test_case, completions) - result_dict = { + sub_data_dict = { 'name': test_case['name'], 'language': test_case['language'], 'prompt': test_case['prompt'], @@ -242,7 +243,7 @@ class CodeEvaluator(BaseEvaluator): 'completions': completions } - all_test_cases.append(result_dict) + all_test_cases.append(sub_data_dict) # 2. Send all test cases to the evaluation service success, outputs, error_message = self._evaluate(all_test_cases) @@ -262,6 +263,6 @@ class CodeEvaluator(BaseEvaluator): details.append(output) return { - f'pass@{num_repeats}': 100 * correct / len(test_set_origin), + f'pass@1': 100 * correct / len(test_set_origin), 'details': details - } + } \ No newline at end of file diff --git a/opencompass/utils/datasets_info.py b/opencompass/utils/datasets_info.py index 10ca4436..ce12af64 100644 --- a/opencompass/utils/datasets_info.py +++ b/opencompass/utils/datasets_info.py @@ -451,7 +451,16 @@ DATASETS_MAPPING = { "hf_id": "", "local": "./data/nejmaibench/NEJM_All_Questions_And_Answers.csv", }, - + "opencompass/humaneval_pro": { + "ms_id": "", + "hf_id": "", + "local": "./data/humaneval_pro/humaneval_pro.json", + }, + "opencompass/mbpp_pro": { + "ms_id": "", + "hf_id": "", + "local": "./data/mbpp_pro/mbpp_pro.json", + }, } DATASETS_URL = { @@ -808,6 +817,13 @@ DATASETS_URL = { "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/nejmaibench.zip", "md5": "e6082cae3596b3ebea73e23ba445b99e" - } - + }, + "humaneval_pro": { + "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/humaneval_pro.zip", + "md5": "4c6fe556e84e905e4f0902d699e46de5", + }, + "mbpp_pro": { + "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mbpp_pro.zip", + "md5": "eac330b8a0a8687f006265c9383503ce", + }, }