mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

commit d939e32438 (parent a685ed7daf)

    add bench
opencompass/configs/datasets/humaneval_pro/README.md (new file, 17 lines)
@@ -0,0 +1,17 @@
# HumanEval Pro

## OpenCompass (OC) results

| model | pass@1 |
|:--------------------------:|---------:|
| qwen2.5-coder-7b-instruct-hf | 65 |
| qwen2.5-14b-instruct-hf | 67 |
| deepseek-v2-lite-chat-hf | 35 |

## CodeEval-Pro results

| model | pass@1 |
|:--------------------------:|---------:|
| qwen2.5-coder-7b-instruct-hf | 65 |
| qwen2.5-14b-instruct-hf | 65 |
| deepseek-v2-lite-chat-hf | 28 |
opencompass/configs/datasets/humaneval_pro/humaneval_pro_gen.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .humaneval_pro_gen_ import humanevalpro_datasets  # noqa: F401, F403
@@ -0,0 +1,56 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HumanevalProDataset, HumanevalProEvaluator

OFFICIAL_PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.

@@ Instruction
Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
```python
{raw_problem}
{new_problem}
```

@@ Response
Please put the two solutions to the above problems in one Python code block.
"""

PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
```python
{raw_problem}
{new_problem}
```
Please put the two solutions within the Python code block provided below, and make sure that the block contains no other unrelated content:
```python
```
"""

humanevalpro_reader_cfg = dict(
    input_columns=['raw_problem', 'new_problem'], output_column='test_code')

humanevalpro_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[
            dict(role='HUMAN', prompt=PROMPT_WRAPPER),
        ])),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer))

humanevalpro_eval_cfg = dict(
    evaluator=dict(
        type=HumanevalProEvaluator,
        ip_address='https://opencompass-multiple-evaluator.hf.space'))

humanevalpro_datasets = [
    dict(
        abbr='humaneval_pro',
        type=HumanevalProDataset,
        path='opencompass/humaneval_pro',
        reader_cfg=humanevalpro_reader_cfg,
        infer_cfg=humanevalpro_infer_cfg,
        eval_cfg=humanevalpro_eval_cfg),
]
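The reader feeds its two input columns straight into the template; a quick illustration of how the placeholders get filled (the two toy problems below are invented for the example, not benchmark items):

```python
# Illustrative only: substituting the two reader columns into PROMPT_WRAPPER.
raw_problem = 'def add(a: int, b: int) -> int:\n    """Return a + b."""\n'
new_problem = ('def add_three(a: int, b: int, c: int) -> int:\n'
               '    """Return a + b + c by calling add."""\n')
print(PROMPT_WRAPPER.format(raw_problem=raw_problem, new_problem=new_problem))
```

At inference time the framework performs this substitution per row; the `str.format` call here is just to show the resulting prompt.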
opencompass/configs/datasets/mbpp_pro/README.md (new file, 17 lines)
@@ -0,0 +1,17 @@
# MBPP Pro

## OpenCompass (OC) results

| model | pass@1 |
|:--------------------------:|---------:|
| qwen2.5-coder-7b-instruct-hf | 66 |
| qwen2.5-14b-instruct-hf | 64 |
| deepseek-v2-lite-chat-hf | 36 |

## CodeEval-Pro results

| model | pass@1 |
|:--------------------------:|---------:|
| qwen2.5-coder-7b-instruct-hf | 65 |
| qwen2.5-14b-instruct-hf | 65 |
| deepseek-v2-lite-chat-hf | 39 |
opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .mbpp_pro_gen_ import mbpppro_datasets  # noqa: F401, F403
opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen_3dc067.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import MBPPProDataset, MBPPProEvaluator

OFFICIAL_PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.

@@ Instruction
Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
```python
{raw_problem}
{new_problem}
```

@@ Response
Please put the two solutions to the above problems in one Python code block.
"""

PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
```python
{raw_problem}
{new_problem}
```
Please put the two solutions within the Python code block provided below, and make sure that the block contains no other unrelated content:
```python
```
"""

mbpppro_reader_cfg = dict(
    input_columns=['raw_problem', 'new_problem'], output_column='test_code')

mbpppro_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[
            dict(role='HUMAN', prompt=PROMPT_WRAPPER),
        ])),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer))

mbpppro_eval_cfg = dict(
    evaluator=dict(
        type=MBPPProEvaluator,
        ip_address='https://opencompass-multiple-evaluator.hf.space'))

mbpppro_datasets = [
    dict(
        abbr='mbpp_pro',
        type=MBPPProDataset,
        path='opencompass/mbpp_pro',
        reader_cfg=mbpppro_reader_cfg,
        infer_cfg=mbpppro_infer_cfg,
        eval_cfg=mbpppro_eval_cfg),
]
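Taken together, the two dataset configs can be pulled into a run in the usual `read_base` style. A minimal sketch (the model section is a placeholder, and the import paths assume the packaged config layout introduced by this commit):

```python
# Minimal sketch of a run config combining both new benchmarks.
from mmengine.config import read_base

with read_base():
    from opencompass.configs.datasets.humaneval_pro.humaneval_pro_gen import \
        humanevalpro_datasets
    from opencompass.configs.datasets.mbpp_pro.mbpp_pro_gen import \
        mbpppro_datasets

datasets = [*humanevalpro_datasets, *mbpppro_datasets]
models = []  # placeholder: fill in any OpenCompass model configs here
```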
opencompass/datasets/__init__.py
@@ -63,6 +63,7 @@ from .hle import *  # noqa: F401, F403
 from .huggingface import *  # noqa: F401, F403
 from .humaneval import *  # noqa: F401, F403
 from .humaneval_multi import *  # noqa: F401, F403
+from .humaneval_pro import *  # noqa: F401, F403
 from .humanevalx import *  # noqa: F401, F403
 from .hungarian_math import *  # noqa: F401, F403
 from .IFEval.ifeval import IFEvalDataset, IFEvaluator  # noqa: F401, F403
@@ -95,6 +96,7 @@ from .math401 import *  # noqa: F401, F403
 from .math_intern import *  # noqa: F401, F403
 from .mathbench import *  # noqa: F401, F403
 from .mbpp import *  # noqa: F401, F403
+from .mbpp_pro import *  # noqa: F401, F403
 from .medbench import *  # noqa: F401, F403
 from .MedXpertQA import *  # noqa: F401, F403
 from .mgsm import *  # noqa: F401, F403
opencompass/datasets/humaneval_pro.py (new file, 96 lines)
@@ -0,0 +1,96 @@
import json
from typing import Dict, List

import numpy as np
from datasets import Dataset

from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
from opencompass.utils import get_data_path

from .base import BaseDataset


class HumanevalProDataset(BaseDataset):

    @staticmethod
    def load(path, num_repeats=1, local_mode=False):
        path = get_data_path(path, local_mode=local_mode)
        dataset = []
        with open(path, encoding='utf-8') as f:
            raw_data = json.load(f)
            for data in raw_data:
                dataset.extend([data for _ in range(num_repeats)])
        return Dataset.from_list(dataset)


class HumanevalProEvaluator(CodeEvaluator):

    def _process_completions(self, test_case: dict, completions: list) -> list:
        processed_completions = []
        for comp in completions:
            post_comp = self._extract_code(comp)
            processed_completions.append(post_comp)
        return processed_completions

    def score(self, predictions: List, references: List,
              test_set: Dataset) -> Dict:
        if len(predictions) != len(references):
            return {
                'error':
                'predictions and references have different '
                f'length. len(predictions): {len(predictions)}, '
                f'len(references): {len(references)}'
            }

        test_set = test_set.to_pandas()
        # Use the first column as the unique identifier
        test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0])
        num_repeats = int(len(test_set) / len(test_set_origin))

        # 1. Prepare data for all test cases
        all_test_cases = []
        for i in range(len(test_set_origin)):
            test_case = test_set_origin.iloc[i]
            completions = predictions[i * num_repeats:(i + 1) * num_repeats]

            # Process code completions
            processed_completions = self._process_completions(
                test_case, completions)

            sub_data_dict = {
                'name': int(test_case['id']),
                'language': self.language,
                'prompt': '',
                'tests': test_case['test_code'],
                'processed_completions': processed_completions,
                'completions': completions
            }

            all_test_cases.append(sub_data_dict)

        # 2. Send all test cases to the evaluation service
        success, outputs, error_message = self._evaluate(all_test_cases)
        if not success:
            return {'error': error_message}

        # 3. Process the returned results
        details = []
        total, correct = [], []
        for output in outputs:
            passed = [m['status'] == 'OK' for m in output['meta_data']]
            total.append(len(passed))
            correct.append(sum(passed))
            details.append(output)
        total = np.array(total)
        correct = np.array(correct)

        pass_at_k = {
            f'pass@{k}':
            self.estimate_pass_at_k(total, correct, k).mean() * 100
            for k in self.k if (total >= k).all()
        }

        return {
            **pass_at_k,
            'details': details,
        }
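`estimate_pass_at_k` is inherited from `CodeEvaluator` and is not part of this diff. For reference, the standard unbiased estimator from the HumanEval paper (Chen et al., 2021), which the inherited method presumably mirrors, looks like this:

```python
# Sketch of the standard unbiased pass@k estimator; per problem it computes
# 1 - C(n - c, k) / C(n, k), i.e. the chance that a random size-k draw from
# n samples contains at least one of the c correct ones.
import itertools
from typing import List, Union

import numpy as np


def estimate_pass_at_k(num_samples: Union[int, List[int], np.ndarray],
                       num_correct: Union[List[int], np.ndarray],
                       k: int) -> np.ndarray:

    def estimator(n: int, c: int, k: int) -> float:
        if n - c < k:
            return 1.0  # every size-k draw must contain a correct sample
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    if isinstance(num_samples, int):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct)
        num_samples_it = iter(num_samples)

    return np.array([
        estimator(int(n), int(c), k)
        for n, c in zip(num_samples_it, num_correct)
    ])
```

This is why `score` guards the dict comprehension with `(total >= k).all()`: the estimator is only defined when every problem has at least k samples.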
opencompass/datasets/mbpp_pro.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import json
from typing import Dict, List

import numpy as np
from datasets import Dataset

from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
from opencompass.utils import get_data_path

from .base import BaseDataset


class MBPPProDataset(BaseDataset):

    @staticmethod
    def load(path, num_repeats=1, local_mode=False):
        path = get_data_path(path, local_mode=local_mode)
        dataset = []
        with open(path, encoding='utf-8') as f:
            for line in f:
                dataset.extend(
                    [json.loads(line.strip()) for _ in range(num_repeats)])
        return Dataset.from_list(dataset)


class MBPPProEvaluator(CodeEvaluator):

    def _process_completions(self, test_case: dict, completions: list) -> list:
        processed_completions = []
        for comp in completions:
            post_comp = self._extract_code(comp)
            processed_completions.append(post_comp)
        return processed_completions

    def score(self, predictions: List, references: List,
              test_set: Dataset) -> Dict:
        if len(predictions) != len(references):
            return {
                'error':
                'predictions and references have different '
                f'length. len(predictions): {len(predictions)}, '
                f'len(references): {len(references)}'
            }

        test_set = test_set.to_pandas()
        # Use the first column as the unique identifier
        test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0])
        num_repeats = int(len(test_set) / len(test_set_origin))

        # 1. Prepare data for all test cases
        all_test_cases = []
        for i in range(len(test_set_origin)):
            test_case = test_set_origin.iloc[i]
            completions = predictions[i * num_repeats:(i + 1) * num_repeats]

            # Process code completions
            processed_completions = self._process_completions(
                test_case, completions)

            sub_data_dict = {
                'name': int(test_case['id']),
                'language': self.language,
                'prompt': '',
                'tests': test_case['test_code'],
                'processed_completions': processed_completions,
                'completions': completions
            }

            all_test_cases.append(sub_data_dict)

        # 2. Send all test cases to the evaluation service
        success, outputs, error_message = self._evaluate(all_test_cases)
        if not success:
            return {'error': error_message}

        # 3. Process the returned results
        details = []
        total, correct = [], []
        for output in outputs:
            passed = [m['status'] == 'OK' for m in output['meta_data']]
            total.append(len(passed))
            correct.append(sum(passed))
            details.append(output)
        total = np.array(total)
        correct = np.array(correct)

        pass_at_k = {
            f'pass@{k}':
            self.estimate_pass_at_k(total, correct, k).mean() * 100
            for k in self.k if (total >= k).all()
        }

        return {
            **pass_at_k,
            'details': details,
        }
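Note the two loaders expect different serializations: `HumanevalProDataset.load` reads a single JSON array, while `MBPPProDataset.load` reads JSON Lines. Judging from the reader configs (`raw_problem`, `new_problem`, `test_code`) and the fields accessed in `score`, each record presumably carries at least the keys below; the values are invented placeholders, not real benchmark data:

```python
# Assumed record shape, inferred from reader_cfg and from score() reading
# test_case['id'] and test_case['test_code']; values are illustrative only.
record = {
    'id': 0,
    'raw_problem': 'def add(a, b):\n    """Return a + b."""\n',
    'new_problem': 'def add_three(a, b, c):\n    """Return a + b + c."""\n',
    'test_code': 'assert add_three(1, 2, 3) == 6',
}
```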
opencompass/openicl/icl_evaluator/code_evaluator.py
@@ -1,12 +1,14 @@
 # flake8: noqa: E501

+import difflib
+import itertools
 import os
 import re
 import tempfile
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
 from datasets import Dataset
 from gradio_client import Client

@@ -24,7 +26,7 @@ class CodeEvaluator(BaseEvaluator):
     """

     def __init__(self,
-                 language: str,
+                 language: str = 'py',
                  ip_address: str = 'localhost',
                  retry: int = 3) -> None:
         """Initialize the CodeEvaluator.

@@ -221,19 +223,18 @@
         test_set = test_set.to_pandas()
         # Use the first column as the unique identifier
         test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0])
-        num_repeats = int(len(test_set) / len(test_set_origin))

         # 1. Prepare data for all test cases
         all_test_cases = []
         for i in range(len(test_set_origin)):
             test_case = test_set_origin.iloc[i]
-            completions = predictions[i * num_repeats:(i + 1) * num_repeats]
+            completions = predictions[i]

             # Process code completions
             processed_completions = self._process_completions(
                 test_case, completions)

-            result_dict = {
+            sub_data_dict = {
                 'name': test_case['name'],
                 'language': test_case['language'],
                 'prompt': test_case['prompt'],
@@ -242,7 +243,7 @@
                 'completions': completions
             }

-            all_test_cases.append(result_dict)
+            all_test_cases.append(sub_data_dict)

         # 2. Send all test cases to the evaluation service
         success, outputs, error_message = self._evaluate(all_test_cases)
@@ -262,6 +263,6 @@
             details.append(output)

         return {
-            f'pass@{num_repeats}': 100 * correct / len(test_set_origin),
+            'pass@1': 100 * correct / len(test_set_origin),
             'details': details
         }
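`_evaluate`, unchanged here, talks to the configured evaluation service; despite the `ip_address` parameter name, the configs above pass a full Space URL. A rough sketch of the client side of that call, in which the endpoint name and payload shape are assumptions for illustration rather than the actual service contract:

```python
# Hypothetical sketch of the remote call inside _evaluate. The '/evaluate'
# api_name and the JSON-string payload are assumed; only Client(...) and
# Client.predict(..., api_name=...) are known gradio_client APIs.
import json

from gradio_client import Client


def evaluate_remotely(all_test_cases: list, address: str) -> list:
    client = Client(address)  # accepts a full https:// Space URL
    payload = json.dumps(all_test_cases)
    raw = client.predict(payload, api_name='/evaluate')  # assumed endpoint
    return json.loads(raw)
```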
opencompass/utils/datasets_info.py
@@ -451,7 +451,16 @@ DATASETS_MAPPING = {
         "hf_id": "",
         "local": "./data/nejmaibench/NEJM_All_Questions_And_Answers.csv",
     },
+    "opencompass/humaneval_pro": {
+        "ms_id": "",
+        "hf_id": "",
+        "local": "./data/humaneval_pro/humaneval_pro.json",
+    },
+    "opencompass/mbpp_pro": {
+        "ms_id": "",
+        "hf_id": "",
+        "local": "./data/mbpp_pro/mbpp_pro.json",
+    },
 }

 DATASETS_URL = {
@@ -808,6 +817,13 @@ DATASETS_URL = {
         "url":
         "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/nejmaibench.zip",
         "md5": "e6082cae3596b3ebea73e23ba445b99e"
     },
+    "humaneval_pro": {
+        "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/humaneval_pro.zip",
+        "md5": "4c6fe556e84e905e4f0902d699e46de5",
+    },
+    "mbpp_pro": {
+        "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mbpp_pro.zip",
+        "md5": "eac330b8a0a8687f006265c9383503ce",
+    },
 }
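With these registry entries in place, the abstract `path` values in the dataset configs resolve through `get_data_path`, mirroring the calls in the new loaders:

```python
# How the registered key is consumed; this mirrors the call in the loaders
# above. The download-and-md5-check fallback via DATASETS_URL when the local
# file is missing is an assumption about get_data_path internals, which are
# not shown in this diff.
from opencompass.utils import get_data_path

path = get_data_path('opencompass/humaneval_pro', local_mode=False)
```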