add bench

This commit is contained in:
Dongsheng Zhu 2025-05-09 02:36:39 +00:00
parent a685ed7daf
commit d939e32438
11 changed files with 376 additions and 10 deletions

View File

@ -0,0 +1,17 @@
# HumanEval pro
## OC results
| model | pass@1 |
|:--------------------------:|---------:|
|qwen2.5-coder-7b-instruct-hf| 65 |
| qwen2.5-14b-instruct-hf | 67 |
| deepseek-v2-lite-chat-hf | 35 |
## CodeEval-pro results
| model | pass@1 |
|:--------------------------:|---------:|
|qwen2.5-coder-7b-instruct-hf| 65 |
| qwen2.5-14b-instruct-hf | 65 |
| deepseek-v2-lite-chat-hf | 28 |

View File

@ -0,0 +1,4 @@
from mmengine.config import read_base
with read_base():
from .humaneval_pro_gen_ import humanevalpro_datasets # noqa: F401, F403

View File

@ -0,0 +1,56 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HumanevalevalProDataset, HumanevalProEvaluator, humaneval_postprocess_v2
OFFICIAL_PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
@@ Instruction
Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
```python
{raw_problem}
{new_problem}
```
@@ Response
Please put the two solutions to the above problems in one Python code block.
"""
PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
```python
{raw_problem}
{new_problem}
```
Please put the two solutions within the Python code block provided below, and make sure that the block contains no other unrelated content:
```python
```
"""
humanevalpro_reader_cfg = dict(
input_columns=['raw_problem', 'new_problem'], output_column='test_code')
humanevalpro_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt=PROMPT_WRAPPER),
])),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer))
humanevalpro_eval_cfg = dict(
evaluator=dict(type=HumanevalProEvaluator,
ip_address='https://opencompass-multiple-evaluator.hf.space')
)
humanevalpro_datasets = [
dict(
abbr='humaneval_pro',
type=HumanevalevalProDataset,
path='opencompass/humaneval_pro',
reader_cfg=humanevalpro_reader_cfg,
infer_cfg=humanevalpro_infer_cfg,
eval_cfg=humanevalpro_eval_cfg,)
]

View File

@ -0,0 +1,17 @@
# MBPP pro
## OC results
| model | pass@1 |
|:--------------------------:|---------:|
|qwen2.5-coder-7b-instruct-hf| 66 |
| qwen2.5-14b-instruct-hf | 64 |
| deepseek-v2-lite-chat-hf | 36 |
## CodeEval-pro results
| model | pass@1 |
|:--------------------------:|---------:|
|qwen2.5-coder-7b-instruct-hf| 65 |
| qwen2.5-14b-instruct-hf | 65 |
| deepseek-v2-lite-chat-hf | 39 |

View File

@ -0,0 +1,4 @@
from mmengine.config import read_base
with read_base():
from .mbpp_pro_gen_ import mbpppro_datasets # noqa: F401, F403

View File

@ -0,0 +1,56 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import MBPPProDataset, MBPPProEvaluator
OFFICIAL_PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
@@ Instruction
Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
```python
{raw_problem}
{new_problem}
```
@@ Response
Please put the two solutions to the above problems in one Python code block.
"""
PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
```python
{raw_problem}
{new_problem}
```
Please put the two solutions within the Python code block provided below, and make sure that the block contains no other unrelated content:
```python
```
"""
mbpppro_reader_cfg = dict(
input_columns=['raw_problem', 'new_problem'], output_column='test_code')
mbpppro_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt=PROMPT_WRAPPER),
])),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer))
mbpppro_eval_cfg = dict(
evaluator=dict(type=MBPPProEvaluator,
ip_address='https://opencompass-multiple-evaluator.hf.space'),
)
mbpppro_datasets = [
dict(
abbr='mbpp_pro',
type=MBPPProDataset,
path='opencompass/mbpp_pro',
reader_cfg=mbpppro_reader_cfg,
infer_cfg=mbpppro_infer_cfg,
eval_cfg=mbpppro_eval_cfg)
]

View File

@ -63,6 +63,7 @@ from .hle import * # noqa: F401, F403
from .huggingface import * # noqa: F401, F403
from .humaneval import * # noqa: F401, F403
from .humaneval_multi import * # noqa: F401, F403
from .humaneval_pro import * # noqa: F401, F403
from .humanevalx import * # noqa: F401, F403
from .hungarian_math import * # noqa: F401, F403
from .IFEval.ifeval import IFEvalDataset, IFEvaluator # noqa: F401, F403
@ -95,6 +96,7 @@ from .math401 import * # noqa: F401, F403
from .math_intern import * # noqa: F401, F403
from .mathbench import * # noqa: F401, F403
from .mbpp import * # noqa: F401, F403
from .mbpp_pro import * # noqa: F401, F403
from .medbench import * # noqa: F401, F403
from .MedXpertQA import * # noqa: F401, F403
from .mgsm import * # noqa: F401, F403

View File

@ -0,0 +1,96 @@
import json
from typing import Dict, List
import numpy as np
from datasets import Dataset
from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
from opencompass.utils import get_data_path
from .base import BaseDataset
class HumanevalevalProDataset(BaseDataset):
@staticmethod
def load(path, num_repeats=1, local_mode=False):
path = get_data_path(path, local_mode=local_mode)
dataset = []
with open(path, encoding='utf-8') as f:
raw_data = json.load(f)
for data in raw_data:
dataset.extend([data for _ in range(num_repeats)])
return Dataset.from_list(dataset)
class HumanevalProEvaluator(CodeEvaluator):
def _process_completions(self, test_case: dict, completions: list) -> list:
processed_completions = []
for comp in completions:
post_comp = self._extract_code(comp)
processed_completions.append(post_comp)
return processed_completions
def score(self, predictions: List, references: List,
test_set: Dataset) -> Dict:
if len(predictions) != len(references):
return {
'error':
'predictions and references have different '
f'length. len(predictions): {len(predictions)}, '
f'len(references): {len(references)}'
}
test_set = test_set.to_pandas()
# Use the first column as the unique identifier
test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0])
num_repeats = int(len(test_set) / len(test_set_origin))
# 1. Prepare data for all test cases
all_test_cases = []
for i in range(len(test_set_origin)):
test_case = test_set_origin.iloc[i]
completions = predictions[i * num_repeats:(i + 1) * num_repeats]
# Process code completions
processed_completions = self._process_completions(
test_case, completions)
sub_data_dict = {
'name': int(test_case['id']),
'language': self.language,
'prompt': '',
'tests': test_case['test_code'],
'processed_completions': processed_completions,
'completions': completions
}
all_test_cases.append(sub_data_dict)
# 2. Send all test cases to the evaluation service
success, outputs, error_message = self._evaluate(all_test_cases)
if not success:
return {'error': error_message}
# 3. Process the returned results
details = []
total, correct = [], []
for output in outputs:
passed = [m['status'] == 'OK' for m in output['meta_data']]
total.append(len(passed))
correct.append(sum(passed))
details.append(output)
total = np.array(total)
correct = np.array(correct)
pass_at_k = {
f'pass@{k}':
self.estimate_pass_at_k(total, correct, k).mean() * 100
for k in self.k if (total >= k).all()
}
return {
**pass_at_k,
'details': details,
}

View File

@ -0,0 +1,97 @@
import json
from typing import Dict, List
import numpy as np
from datasets import Dataset
from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
from opencompass.utils import get_data_path
from .base import BaseDataset
class MBPPProDataset(BaseDataset):
@staticmethod
def load(path, num_repeats=1, local_mode=False):
path = get_data_path(path, local_mode=local_mode)
print(path)
dataset = []
with open(path, encoding='utf-8') as f:
for line in f:
dataset.extend(
[json.loads(line.strip()) for _ in range(num_repeats)])
return Dataset.from_list(dataset)
class MBPPProEvaluator(CodeEvaluator):
def _process_completions(self, test_case: dict, completions: list) -> list:
processed_completions = []
for comp in completions:
post_comp = self._extract_code(comp)
processed_completions.append(post_comp)
return processed_completions
def score(self, predictions: List, references: List,
test_set: Dataset) -> Dict:
if len(predictions) != len(references):
return {
'error':
'predictions and references have different '
f'length. len(predictions): {len(predictions)}, '
f'len(references): {len(references)}'
}
test_set = test_set.to_pandas()
# Use the first column as the unique identifier
test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0])
num_repeats = int(len(test_set) / len(test_set_origin))
# 1. Prepare data for all test cases
all_test_cases = []
for i in range(len(test_set_origin)):
test_case = test_set_origin.iloc[i]
completions = predictions[i * num_repeats:(i + 1) * num_repeats]
# Process code completions
processed_completions = self._process_completions(
test_case, completions)
sub_data_dict = {
'name': int(test_case['id']),
'language': self.language,
'prompt': '',
'tests': test_case['test_code'],
'processed_completions': processed_completions,
'completions': completions
}
all_test_cases.append(sub_data_dict)
# 2. Send all test cases to the evaluation service
success, outputs, error_message = self._evaluate(all_test_cases)
if not success:
return {'error': error_message}
# 3. Process the returned results
details = []
total, correct = [], []
for output in outputs:
passed = [m['status'] == 'OK' for m in output['meta_data']]
total.append(len(passed))
correct.append(sum(passed))
details.append(output)
total = np.array(total)
correct = np.array(correct)
pass_at_k = {
f'pass@{k}':
self.estimate_pass_at_k(total, correct, k).mean() * 100
for k in self.k if (total >= k).all()
}
return {
**pass_at_k,
'details': details,
}

View File

@ -1,12 +1,14 @@
# flake8: noqa: E501
import difflib
import itertools
import os
import re
import tempfile
import time
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from datasets import Dataset
from gradio_client import Client
@ -24,7 +26,7 @@ class CodeEvaluator(BaseEvaluator):
"""
def __init__(self,
language: str,
language: str = 'py',
ip_address: str = 'localhost',
retry: int = 3) -> None:
"""Initialize the CodeEvaluator.
@ -221,19 +223,18 @@ class CodeEvaluator(BaseEvaluator):
test_set = test_set.to_pandas()
# Use the first column as the unique identifier
test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0])
num_repeats = int(len(test_set) / len(test_set_origin))
# 1. Prepare data for all test cases
all_test_cases = []
for i in range(len(test_set_origin)):
test_case = test_set_origin.iloc[i]
completions = predictions[i * num_repeats:(i + 1) * num_repeats]
completions = predictions[i]
# Process code completions
processed_completions = self._process_completions(
test_case, completions)
result_dict = {
sub_data_dict = {
'name': test_case['name'],
'language': test_case['language'],
'prompt': test_case['prompt'],
@ -242,7 +243,7 @@ class CodeEvaluator(BaseEvaluator):
'completions': completions
}
all_test_cases.append(result_dict)
all_test_cases.append(sub_data_dict)
# 2. Send all test cases to the evaluation service
success, outputs, error_message = self._evaluate(all_test_cases)
@ -262,6 +263,6 @@ class CodeEvaluator(BaseEvaluator):
details.append(output)
return {
f'pass@{num_repeats}': 100 * correct / len(test_set_origin),
f'pass@1': 100 * correct / len(test_set_origin),
'details': details
}
}

View File

@ -451,7 +451,16 @@ DATASETS_MAPPING = {
"hf_id": "",
"local": "./data/nejmaibench/NEJM_All_Questions_And_Answers.csv",
},
"opencompass/humaneval_pro": {
"ms_id": "",
"hf_id": "",
"local": "./data/humaneval_pro/humaneval_pro.json",
},
"opencompass/mbpp_pro": {
"ms_id": "",
"hf_id": "",
"local": "./data/mbpp_pro/mbpp_pro.json",
},
}
DATASETS_URL = {
@ -808,6 +817,13 @@ DATASETS_URL = {
"url":
"http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/nejmaibench.zip",
"md5": "e6082cae3596b3ebea73e23ba445b99e"
}
},
"humaneval_pro": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/humaneval_pro.zip",
"md5": "4c6fe556e84e905e4f0902d699e46de5",
},
"mbpp_pro": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mbpp_pro.zip",
"md5": "eac330b8a0a8687f006265c9383503ce",
},
}