Mirror of https://github.com/open-compass/opencompass.git

commit 373a0cba9b (parent b9de8b0e2b), branch develop: multiple_code
@@ -0,0 +1,56 @@
# Select the 10 most popular programming languages from MultiPL-E to compose the test set.

from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import MultiplEDataset, MultiplEEvaluator

_TOP_TEN_LANGUAGE_ = ['cpp', 'cs', 'go', 'java', 'rb', 'js', 'php', 'r', 'rs', 'sh']

multiple_reader_cfg = dict(input_columns=['language', 'prompt'], output_column='tests')

multiple_infer_cfg = dict(
    prompt_template=dict(type=PromptTemplate, template='Based on the provided {language} code snippet, complete the subsequent content. The initial part of the completed code must match the provided code snippet exactly:\n{prompt}'),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=2048),
)

multiple_eval_cfg = {
    lang: dict(
        evaluator=dict(
            type=MultiplEEvaluator,
            language=lang,
            ip_address='https://dongsheng-docker-test.hf.space',  # https://opencompass-multiple-evaluator.hf.space
        ),
        pred_role='BOT',
    ) for lang in _TOP_TEN_LANGUAGE_
}

multiple_datasets = [
    dict(
        type=MultiplEDataset,
        abbr=f'humaneval-multiple-{lang}',
        language=lang,
        num_repeats=1,
        path='opencompass/multipl_e',
        tag='humaneval',
        reader_cfg=multiple_reader_cfg,
        infer_cfg=multiple_infer_cfg,
        eval_cfg=multiple_eval_cfg[lang],
    ) for lang in _TOP_TEN_LANGUAGE_
]

multiple_datasets += [
    dict(
        type=MultiplEDataset,
        abbr=f'mbpp-multiple-{lang}',
        language=lang,
        num_repeats=1,
        path='opencompass/multipl_e',
        tag='mbpp',
        reader_cfg=multiple_reader_cfg,
        infer_cfg=multiple_infer_cfg,
        eval_cfg=multiple_eval_cfg[lang],
    ) for lang in _TOP_TEN_LANGUAGE_
]
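For a concrete sense of what the model is asked, here is the template above rendered for one invented Ruby item (the prompt value is a placeholder, not real MultiPL-E data):

template = ('Based on the provided {language} code snippet, complete the '
            'subsequent content. The initial part of the completed code must '
            'match the provided code snippet exactly:\n{prompt}')
print(template.format(language='rb', prompt='def add(a, b)\n'))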
opencompass/configs/models/phi/hf_phi_4.py (new file)
@@ -0,0 +1,12 @@
from opencompass.models import HuggingFacewithChatTemplate

models = [
    dict(
        type=HuggingFacewithChatTemplate,
        abbr='phi-4',
        path='microsoft/phi-4',
        max_out_len=1024,
        batch_size=8,
        run_cfg=dict(num_gpus=2),
    )
]
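A sketch of a top-level config tying this model file to the MultiPL-E datasets above; the dataset module path is an assumption, since the commit does not show where that config file is saved:

from mmengine.config import read_base

with read_base():
    from opencompass.configs.models.phi.hf_phi_4 import models
    # hypothetical module path for the MultiPL-E dataset config in this commit
    from opencompass.configs.datasets.multipl_e.multiple_top_ten_gen import \
        multiple_datasets as datasets

# then launched as usual, e.g.: python run.py <this_config>.py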
@@ -98,6 +98,7 @@ from .mmlu_cf import *  # noqa: F401, F403
from .mmlu_pro import *  # noqa: F401, F403
from .MMLUArabic import *  # noqa: F401, F403
from .mmmlu import *  # noqa: F401, F403
from .multipl_e import *  # noqa: F401, F403
from .multirc import *  # noqa: F401, F403
from .musr import *  # noqa: F401, F403
from .narrativeqa import *  # noqa: F401, F403
@@ -183,6 +183,33 @@ class CustomDataset(BaseDataset):
        return Dataset.from_list(data)


@LOAD_DATASET.register_module()
class CodeCustomDataset(BaseDataset):

    @staticmethod
    def load(path, file_name=None, local_mode=False, num_repeats=1, **kwargs):
        path = get_data_path(path, local_mode=local_mode)
        if file_name is not None:
            path = os.path.join(path, file_name)
        data = []
        if path.endswith('.jsonl'):
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    data.extend(
                        [json.loads(line.strip()) for _ in range(num_repeats)])
        elif path.endswith('.csv'):
            with open(path, 'r', encoding='utf-8-sig') as f:
                reader = csv.reader(f)
                header = next(reader)
                for row in reader:
                    data.extend(
                        [dict(zip(header, row)) for _ in range(num_repeats)])
        else:
            raise ValueError(f'Unsupported file format: {path}')

        return Dataset.from_list(data)


class CircularCustomDataset(CustomDataset, metaclass=CircularDatasetMeta):
    dataset_class = CustomDataset
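The num_repeats logic is worth seeing in isolation: each row is duplicated in place, so repeats of the same problem stay contiguous, a layout the pass@k scoring further down relies on. A minimal, self-contained illustration with invented rows:

import json

num_repeats = 3
lines = ['{"name": "t1"}', '{"name": "t2"}']
data = []
for line in lines:
    # same duplication pattern as CodeCustomDataset.load above
    data.extend([json.loads(line.strip()) for _ in range(num_repeats)])
print([d['name'] for d in data])
# -> ['t1', 't1', 't1', 't2', 't2', 't2']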
opencompass/datasets/multipl_e.py (new file)
@@ -0,0 +1,91 @@
import json
import os.path as osp

from datasets import Dataset

from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path

from .base import BaseDataset

# currently supported languages
_HUMANEVAL_LANGUAGE_ = [
    'adb', 'clj', 'cpp', 'cs', 'd', 'dart', 'elixir', 'go', 'hs', 'java', 'jl',
    'js', 'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala',
    'sh', 'swift', 'ts'
]
_MBPP_LANGUAGE_ = [
    'adb', 'clj', 'cpp', 'cs', 'd', 'elixir', 'go', 'hs', 'java', 'jl', 'js',
    'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh',
    'swift', 'ts'
]


@LOAD_DATASET.register_module()
class MultiplEDataset(BaseDataset):

    @staticmethod
    def load(path: str,
             language: str,
             num_repeats: int = 1,
             tag: str = 'humaneval',
             local_mode: bool = False):
        """Load a MultiPL-E dataset (HumanEval or MBPP translations) for
        pass@k evaluation.

        Note that you can use num_repeats > 1 when your model does not support
        `num_return_sequences` in generation; otherwise use the raw dataset
        and set `num_return_sequences` in the model config to generate
        multiple responses for testing pass@k > 1.

        It is better to change your dataset abbr accordingly if you set
        num_repeats > 1, otherwise the number in `.cache/dataset_size.json`
        might be inconsistent.

        Args:
            num_repeats (int): Number of repetitions of this dataset, used to
                obtain multiple responses in special cases.
        """
        path = get_data_path(path, local_mode=local_mode)
        assert tag in ['humaneval',
                       'mbpp'], 'tag must be in ["humaneval", "mbpp"]'
        if tag == 'humaneval':
            assert language in _HUMANEVAL_LANGUAGE_, (
                f'language must be in {_HUMANEVAL_LANGUAGE_}')
        else:
            assert language in _MBPP_LANGUAGE_, (
                f'language must be in {_MBPP_LANGUAGE_}')
        file_path = osp.join(path, f'{tag}-{language}.jsonl')
        dataset = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # duplicate each problem num_repeats times, keeping repeats
                # contiguous so the evaluator can regroup them later
                dataset.extend(
                    [json.loads(line.strip()) for _ in range(num_repeats)])
        return Dataset.from_list(dataset)


class MultiplEEvaluator(CodeEvaluator):

    def _stop_at_stop_token(self, decoded_string, stop_tokens):
        """Produce the prefix of decoded_string that ends at the first
        occurrence of any stop token.

        WARNING: the decoded_string *must not* include the prompt, which may
        itself contain stop tokens.
        """
        min_stop_index = len(decoded_string)
        for stop_token in stop_tokens:
            stop_index = decoded_string.find(stop_token)
            if stop_index != -1 and stop_index < min_stop_index:
                min_stop_index = stop_index
        return decoded_string[:min_stop_index]

    def _process_completions(self, test_case, completions):
        processed_completions = []
        for comp in completions:
            comp = self._extract_code(comp)
            post_comp = self._remove_prefix(test_case['prompt'], comp)
            post_comp = self._stop_at_stop_token(post_comp,
                                                 test_case['stop_tokens'])
            processed_completions.append(post_comp)
        return processed_completions
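Since `_stop_at_stop_token` does the heavy lifting of MultiPL-E post-processing, here is the same logic exercised standalone on invented inputs; it also shows why the docstring insists the input must not include the prompt, whose own `def` line would match a stop token:

def stop_at_stop_token(decoded_string, stop_tokens):
    # identical logic to MultiplEEvaluator._stop_at_stop_token above
    min_stop_index = len(decoded_string)
    for stop_token in stop_tokens:
        stop_index = decoded_string.find(stop_token)
        if stop_index != -1 and stop_index < min_stop_index:
            min_stop_index = stop_index
    return decoded_string[:min_stop_index]

body = '    return a + b\nend\n\ndef unrelated()\n'
print(stop_at_stop_token(body, ['\ndef ', '\nclass ']))
# prints '    return a + b\nend\n' -- cut at the earliest stop token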
opencompass/openicl/icl_evaluator/code_evaluator.py (new file)
@@ -0,0 +1,267 @@
# flake8: noqa: E501

import difflib
import os
import re
import tempfile
import time
from typing import Any, Dict, List, Optional, Tuple, Union

from datasets import Dataset
from gradio_client import Client

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS


@ICL_EVALUATORS.register_module()
class CodeEvaluator(BaseEvaluator):
    """Evaluator for code generation tasks.

    This evaluator sends code to a remote evaluation service to test its
    functionality against the provided test cases. It handles code
    extraction, post-processing, and result analysis.
    """

    def __init__(self,
                 language: str,
                 ip_address: str = 'localhost',
                 retry: int = 3) -> None:
        """Initialize the CodeEvaluator.

        Args:
            language (str): Programming language of the code to evaluate.
            ip_address (str, optional): Address of the evaluation service.
                Defaults to 'localhost'.
            retry (int, optional): Number of retry attempts for failed
                connections. Defaults to 3.
        """
        self.language = language
        self.retry = retry
        self.client = Client(ip_address)
        super().__init__()

    def _extract_code(self, text: str) -> str:
        """Extract code from markdown-formatted text.

        Args:
            text (str): Text that may contain code blocks in markdown format.

        Returns:
            str: Code extracted from the first fenced code block, or the
                original text if no code block is found.
        """
        blocks = re.findall(r'```\w*\n(.*?)```', text, re.DOTALL)
        if len(blocks) >= 1:
            text = blocks[0]
        return text

    def _code_eval_service(
        self, input_data: Union[Dict, List, str]
    ) -> Tuple[bool, Union[Dict, List, Any]]:
        """Send code to the remote evaluation service via gradio_client and
        return the results.

        Args:
            input_data: Can be one of:
                - dict: code information for a single test case
                - list: list of dictionaries for batch evaluation
                - str: path to a code file

        Returns:
            tuple: (succeed, output)
                - succeed (bool): whether the request was successful
                - output (dict/list/str): evaluation results or error message
        """
        try:
            temp_file_path = None
            # Handle file path input: copy the file so it carries the
            # language-specific suffix the service expects
            if isinstance(input_data, str):
                with tempfile.NamedTemporaryFile(suffix=f'.{self.language}',
                                                 delete=False) as temp_file:
                    temp_file_path = temp_file.name
                    with open(input_data, 'r') as src_file:
                        content = src_file.read()
                    temp_file.write(content.encode())
                input_data = temp_file_path

            # Send to the evaluation service
            result = self.client.predict(input_data, api_name='/evaluate')

            # Process the result
            if isinstance(result, (dict, list)):
                return True, result
            else:
                # Try to parse the result as JSON if it is a string
                try:
                    import json
                    parsed_result = json.loads(result)
                    return True, parsed_result
                except:  # noqa: E722
                    return True, {'status': 'unknown', 'raw_result': result}

        except Exception as e:
            return False, str(e)
        finally:
            # Clean up the temporary file if one was created
            if temp_file_path and os.path.exists(temp_file_path):
                try:
                    os.unlink(temp_file_path)
                except:  # noqa: E722
                    pass

    def _remove_prefix(self,
                       prompt: str,
                       completion: str,
                       threshold: float = 0.95) -> str:
        """Determine the truncation point in the completion based on the last
        line of the prompt, remove everything up to and including that line,
        and return the remainder. This converts chatbot-style output, which
        tends to echo the prompt, back into completion-style output.

        Args:
            prompt (str): The prompt text.
            completion (str): The completion text.
            threshold (float): Line-similarity threshold.

        Returns:
            str: The completion string after removing the prefix.
        """
        prompt_lines = prompt.splitlines()
        completion_lines = completion.splitlines()

        if not prompt_lines:
            return completion

        last_prompt_line = prompt_lines[-1]
        cut_index = -1

        for i, completion_line in enumerate(completion_lines):
            similarity = difflib.SequenceMatcher(None, last_prompt_line,
                                                 completion_line).ratio()
            if similarity >= threshold:
                cut_index = i
                break

        if cut_index != -1:
            return '\n'.join(completion_lines[cut_index + 1:])
        else:
            return completion

    def _process_completions(self, test_case: dict, completions: list) -> list:
        """Post-process a list of code completions: extract the code, remove
        the prompt echo produced by chatbot-style inference, and apply any
        other steps needed so the generated code can be compiled.

        Args:
            test_case (dict): Test case information, including the original
                prompt.
            completions (list): Code completions generated by the model.

        Returns:
            list: Processed code completions.
        """
        processed_completions = []
        for comp in completions:
            comp = self._extract_code(comp)
            post_comp = self._remove_prefix(test_case['prompt'], comp)
            processed_completions.append(post_comp)
        return processed_completions

    def _evaluate(
        self, input_data: Union[Dict, List]
    ) -> Tuple[bool, Optional[Union[Dict, List]], Optional[str]]:
        """Evaluate code with a retry mechanism.

        Args:
            input_data: Can be either:
                - dict: code and test information for a single test case
                - list: list of dictionaries for batch evaluation

        Returns:
            tuple: (success, output, error_message)
                - success (bool): whether the evaluation succeeded
                - output (dict or list): evaluation output, if successful
                - error_message (str): error message, if failed
        """
        num_retry = 0
        while num_retry < self.retry:
            succeed, output = self._code_eval_service(input_data)
            if not succeed:
                num_retry += 1
                time.sleep(10)
            else:
                break

        if not succeed:
            return False, None, f'code eval service connection failed: {output}'

        return True, output, None

    def score(self, predictions: List, references: List,
              test_set: Dataset) -> Dict:
        """Score code generation predictions against references.

        Args:
            predictions (list): Model-generated code completions.
            references (list): Reference solutions (not directly used in
                evaluation).
            test_set (Dataset): Dataset containing test cases and metadata.

        Returns:
            dict: Evaluation results including:
                - pass@{num_repeats}: percentage of problems solved
                - details: detailed results for each test case
                - error: error message, if evaluation failed
        """
        if len(predictions) != len(references):
            return {
                'error':
                'predictions and references have different '
                f'length. len(predictions): {len(predictions)}, '
                f'len(references): {len(references)}'
            }

        test_set = test_set.to_pandas()
        # Use the first column as the unique identifier
        test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0])
        num_repeats = int(len(test_set) / len(test_set_origin))

        # 1. Prepare data for all test cases; repeats of the same problem
        #    are contiguous in both test_set and predictions
        all_test_cases = []
        for i in range(len(test_set_origin)):
            test_case = test_set_origin.iloc[i]
            completions = predictions[i * num_repeats:(i + 1) * num_repeats]

            # Process code completions
            processed_completions = self._process_completions(
                test_case, completions)

            result_dict = {
                'name': test_case['name'],
                'language': test_case['language'],
                'prompt': test_case['prompt'],
                'tests': test_case['tests'],
                'processed_completions': processed_completions,
                'completions': completions
            }

            all_test_cases.append(result_dict)

        # 2. Send all test cases to the evaluation service
        success, outputs, error_message = self._evaluate(all_test_cases)
        if not success:
            return {'error': error_message}

        # 3. Process the returned results
        details = []
        correct = 0
        for output in outputs:
            if output.get('status') == 'OK':
                output['correct'] = True
                correct += 1
            else:
                output['correct'] = False

            details.append(output)

        return {
            f'pass@{num_repeats}': 100 * correct / len(test_set_origin),
            'details': details
        }
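The `_remove_prefix` heuristic above deserves a standalone look: chat models usually echo the prompt before continuing, so the completion is cut just after the line most similar to the prompt's final line. A self-contained illustration with invented strings:

import difflib

prompt = 'def add(a, b):'
completion = 'def add(a, b):\n    return a + b'

last_prompt_line = prompt.splitlines()[-1]
for i, line in enumerate(completion.splitlines()):
    # same similarity test as _remove_prefix, threshold 0.95
    if difflib.SequenceMatcher(None, last_prompt_line, line).ratio() >= 0.95:
        print('\n'.join(completion.splitlines()[i + 1:]))  # '    return a + b'
        break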
@@ -193,6 +193,12 @@ DATASETS_MAPPING = {
        "hf_id": "",
        "local": "./data/mmlu_pro",
    },
    # MultiPL-E
    "opencompass/multipl_e": {
        "ms_id": "",
        "hf_id": "",
        "local": "./data/multipl_e",
    },
    # NQ
    "opencompass/natural_question": {
        "ms_id": "opencompass/natural_question",

@@ -627,6 +633,11 @@ DATASETS_URL = {
        "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mmlu_pro.zip",
        "md5": "e3200c7380f4cea5f13c768f2815fabb",
    },
    "multipl_e": {
        "url":
        "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/multipl_e.zip",
        "md5": "24462aac7a38a4a62f5c5e89eb614e20",
    },
    "/Longbench": {
        "url":
        "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/Longbench.zip",
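For orientation, here is a rough, self-contained sketch of how entries like these are typically consumed; this is an illustration only, not the actual `get_data_path` implementation:

import os

DATASETS_MAPPING = {'opencompass/multipl_e': {'local': './data/multipl_e'}}
DATASETS_URL = {
    'multipl_e': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/multipl_e.zip',
        'md5': '24462aac7a38a4a62f5c5e89eb614e20',
    },
}


def resolve(key: str) -> str:
    """Map a dataset key to a local directory, downloading on first use."""
    local = DATASETS_MAPPING[key]['local']
    if not os.path.exists(local):
        # a download-and-md5-check step would go here, keyed on DATASETS_URL
        pass
    return local


print(resolve('opencompass/multipl_e'))  # -> ./data/multipl_e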