multiple_code develop

This commit is contained in:
Dongsheng Zhu 2025-03-20 05:56:27 +00:00
parent b9de8b0e2b
commit 373a0cba9b
7 changed files with 465 additions and 0 deletions

View File

@@ -0,0 +1,56 @@
# Select the 10 most popular programming languages from MultiPL-E to compose the test set.
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import MultiplEDataset, MultiplEEvaluator

_TOP_TEN_LANGUAGE_ = ['cpp', 'cs', 'go', 'java', 'rb', 'js', 'php', 'r', 'rs', 'sh']

multiple_reader_cfg = dict(input_columns=['language', 'prompt'], output_column='tests')

multiple_infer_cfg = dict(
    prompt_template=dict(type=PromptTemplate, template='Based on the provided {language} code snippet, complete the subsequent content. The initial part of the completed code must match the provided code snippet exactly:\n{prompt}'),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=2048),
)

multiple_eval_cfg = {
    lang: dict(
        evaluator=dict(
            type=MultiplEEvaluator,
            language=lang,
            ip_address='https://dongsheng-docker-test.hf.space',  # https://opencompass-multiple-evaluator.hf.space
        ),
        pred_role='BOT',
    ) for lang in _TOP_TEN_LANGUAGE_
}

multiple_datasets = [
    dict(
        type=MultiplEDataset,
        abbr=f'humaneval-multiple-{lang}',
        language=lang,
        num_repeats=1,
        path='opencompass/multipl_e',
        tag='humaneval',
        reader_cfg=multiple_reader_cfg,
        infer_cfg=multiple_infer_cfg,
        eval_cfg=multiple_eval_cfg[lang],
    ) for lang in _TOP_TEN_LANGUAGE_
]

multiple_datasets += [
    dict(
        type=MultiplEDataset,
        abbr=f'mbpp-multiple-{lang}',
        language=lang,
        num_repeats=1,
        path='opencompass/multipl_e',
        tag='mbpp',
        reader_cfg=multiple_reader_cfg,
        infer_cfg=multiple_infer_cfg,
        eval_cfg=multiple_eval_cfg[lang],
    ) for lang in _TOP_TEN_LANGUAGE_
]
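
The two comprehensions above expand to 20 dataset entries (10 languages for each of the humaneval and mbpp splits). A quick sanity check of that expansion, assuming the config file is importable as a plain Python module (module name hypothetical):

# sanity_check.py -- illustrative only
from multiple_gen import multiple_datasets

assert len(multiple_datasets) == 20  # 10 languages x 2 source benchmarks
print(sorted(d['abbr'] for d in multiple_datasets)[:2])
# ['humaneval-multiple-cpp', 'humaneval-multiple-cs']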

View File

@@ -0,0 +1,12 @@
from opencompass.models import HuggingFacewithChatTemplate

models = [
    dict(
        type=HuggingFacewithChatTemplate,
        abbr='phi-4',
        path='microsoft/phi-4',
        max_out_len=1024,
        batch_size=8,
        run_cfg=dict(num_gpus=2),
    )
]
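
A minimal top-level config tying the two files above together might look like the sketch below. It uses the standard OpenCompass `read_base` pattern; the relative import paths are assumptions about where these two config files land under `configs/`:

# eval_multiple_phi4.py -- illustrative sketch, import paths assumed
from mmengine.config import read_base

with read_base():
    from .datasets.multipl_e.multiple_gen import multiple_datasets
    from .models.phi.hf_phi_4 import models

datasets = multiple_datasets
# then run with the usual entry point, e.g. `opencompass eval_multiple_phi4.py`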

View File

@@ -98,6 +98,7 @@ from .mmlu_cf import * # noqa: F401, F403
from .mmlu_pro import * # noqa: F401, F403
from .MMLUArabic import * # noqa: F401, F403
from .mmmlu import * # noqa: F401, F403
from .multipl_e import * # noqa: F401, F403
from .multirc import * # noqa: F401, F403
from .musr import * # noqa: F401, F403
from .narrativeqa import * # noqa: F401, F403

View File

@@ -183,6 +183,33 @@ class CustomDataset(BaseDataset):
        return Dataset.from_list(data)


@LOAD_DATASET.register_module()
class CodeCustomDataset(BaseDataset):

    @staticmethod
    def load(path, file_name=None, local_mode=False, num_repeats=1, **kwargs):
        path = get_data_path(path, local_mode=local_mode)
        if file_name is not None:
            path = os.path.join(path, file_name)
        data = []
        if path.endswith('.jsonl'):
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    data.extend(
                        [json.loads(line.strip()) for _ in range(num_repeats)])
        elif path.endswith('.csv'):
            with open(path, 'r', encoding='utf-8-sig') as f:
                reader = csv.reader(f)
                header = next(reader)
                for row in reader:
                    data.extend(
                        [dict(zip(header, row)) for _ in range(num_repeats)])
        else:
            raise ValueError(f'Unsupported file format: {path}')
        return Dataset.from_list(data)


class CircularCustomDataset(CustomDataset, metaclass=CircularDatasetMeta):
    dataset_class = CustomDataset
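
A hedged usage sketch for the new loader (directory and file name hypothetical). Each record is duplicated `num_repeats` times so that pass@k can later be computed from k independent generations per problem:

# illustrative only
ds = CodeCustomDataset.load('data/my_code_eval', file_name='test.jsonl',
                            local_mode=True, num_repeats=4)
# a file with 10 JSONL records yields len(ds) == 40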

View File

@@ -0,0 +1,91 @@
import json
import os.path as osp

from datasets import Dataset

from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path

from .base import BaseDataset

# currently supported languages
_HUMANEVAL_LANGUAGE_ = [
    'adb', 'clj', 'cpp', 'cs', 'd', 'dart', 'elixir', 'go', 'hs', 'java', 'jl',
    'js', 'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala',
    'sh', 'swift', 'ts'
]
_MBPP_LANGUAGE_ = [
    'adb', 'clj', 'cpp', 'cs', 'd', 'elixir', 'go', 'hs', 'java', 'jl', 'js',
    'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh',
    'swift', 'ts'
]


@LOAD_DATASET.register_module()
class MultiplEDataset(BaseDataset):

    @staticmethod
    def load(path: str,
             language: str,
             num_repeats: int = 1,
             tag: str = 'humaneval',
             local_mode: bool = False):
        """Load a MultiPL-E dataset for pass@k evaluation.

        Note that you can use num_repeats > 1 when your model does not support
        `num_return_sequence` in generation; otherwise use the raw dataset and
        set `num_return_sequence` in the model config to generate multiple
        responses for testing pass@k > 1.

        It is better to change your dataset abbr accordingly if you set
        num_repeats > 1, otherwise the number in `.cache/dataset_size.json`
        might be inconsistent.

        Args:
            path(str): Root path of the dataset files.
            language(str): Language of the subset to load.
            num_repeats(int): Number of repetitions of this dataset used to
                get multiple responses in special cases.
            tag(str): Source benchmark, either 'humaneval' or 'mbpp'.
            local_mode(bool): Whether to load the data from a local path
                directly.
        """
        path = get_data_path(path, local_mode=local_mode)
        assert tag in ['humaneval',
                       'mbpp'], 'tag must be in ["humaneval", "mbpp"]'
        if tag == 'humaneval':
            assert language in _HUMANEVAL_LANGUAGE_, (
                f'language must be in {_HUMANEVAL_LANGUAGE_}')
        else:
            assert language in _MBPP_LANGUAGE_, (
                f'language must be in {_MBPP_LANGUAGE_}')
        file_path = osp.join(path, f'{tag}-{language}.jsonl')
        dataset = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                dataset.extend(
                    [json.loads(line.strip()) for _ in range(num_repeats)])
        return Dataset.from_list(dataset)


class MultiplEEvaluator(CodeEvaluator):

    def _stop_at_stop_token(self, decoded_string, stop_tokens):
        """Produce the prefix of decoded_string that ends at the first
        occurrence of any stop token.

        WARNING: the decoded_string *must not* include the prompt, which may
        itself contain stop tokens.
        """
        min_stop_index = len(decoded_string)
        for stop_token in stop_tokens:
            stop_index = decoded_string.find(stop_token)
            if stop_index != -1 and stop_index < min_stop_index:
                min_stop_index = stop_index
        return decoded_string[:min_stop_index]
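
    # Illustrative behaviour (values hypothetical):
    #   _stop_at_stop_token('    return a + b\n\ndef main():', ['\ndef', '\nclass'])
    #   -> '    return a + b\n'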

    def _process_completions(self, test_case, completions):
        processed_completions = []
        for comp in completions:
            comp = self._extract_code(comp)
            post_comp = self._remove_prefix(test_case['prompt'], comp)
            post_comp = self._stop_at_stop_token(post_comp,
                                                 test_case['stop_tokens'])
            processed_completions.append(post_comp)
        return processed_completions
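
Given the loader above, the local data layout is expected to be one JSONL file per (tag, language) pair, e.g. `./data/multipl_e/humaneval-cpp.jsonl`. A hedged load call:

# illustrative only; path resolution goes through get_data_path
ds = MultiplEDataset.load(path='opencompass/multipl_e', language='cpp',
                          tag='humaneval', num_repeats=1)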

View File

@@ -0,0 +1,267 @@
# flake8: noqa: E501
import difflib
import json
import os
import re
import tempfile
import time
from typing import Any, Dict, List, Optional, Tuple, Union

from datasets import Dataset
from gradio_client import Client

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS


@ICL_EVALUATORS.register_module()
class CodeEvaluator(BaseEvaluator):
    """Evaluator for code generation tasks.

    This evaluator sends code to a remote evaluation service to test its
    functionality against provided test cases. It handles code extraction,
    processing, and result analysis.
    """

    def __init__(self,
                 language: str,
                 ip_address: str = 'localhost',
                 retry: int = 3) -> None:
        """Initialize the CodeEvaluator.

        Args:
            language (str): Programming language of the code to evaluate.
            ip_address (str, optional): IP address of the evaluation service. Defaults to 'localhost'.
            retry (int, optional): Number of retry attempts for failed connections. Defaults to 3.
        """
        self.language = language
        self.retry = retry
        self.client = Client(ip_address)
        super().__init__()

    def _extract_code(self, text: str) -> str:
        """Extract code from markdown-formatted text.

        Args:
            text (str): Text that may contain code blocks in markdown format.

        Returns:
            str: Extracted code from the first code block, or the original
                text if no code block is found.
        """
        blocks = re.findall(r'```\w*\n(.*?)```', text, re.DOTALL)
        if blocks:
            text = blocks[0]
        return text
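
    # Illustrative behaviour (input hypothetical): given a chat-style reply
    # 'Sure!\n```python\nprint(1)\n```', the regex captures the fenced body,
    # so _extract_code returns 'print(1)\n'; text with no fences is returned
    # unchanged.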

    def _code_eval_service(
        self, input_data: Union[Dict, List,
                                str]) -> Tuple[bool, Union[Dict, List, Any]]:
        """Send code to the remote evaluation service using gradio_client and
        get the results.

        Args:
            input_data: Can be one of:
                - dict: Dictionary containing code information for a single test case
                - list: List of dictionaries for batch evaluation
                - str: File path to code file

        Returns:
            tuple: (succeed, output)
                - succeed (bool): Whether the request was successful
                - output (dict/list/str): Evaluation results or error message
        """
        try:
            temp_file_path = None
            # Handle file path input
            if isinstance(input_data, str):
                with tempfile.NamedTemporaryFile(suffix=f'.{self.language}',
                                                 delete=False) as temp_file:
                    temp_file_path = temp_file.name
                    with open(input_data, 'r') as src_file:
                        content = src_file.read()
                    temp_file.write(content.encode())
                input_data = temp_file_path

            # Send to evaluation service
            result = self.client.predict(input_data, api_name='/evaluate')

            # Process the result
            if isinstance(result, (dict, list)):
                return True, result
            else:
                # Try to parse the result as JSON if it's a string
                try:
                    parsed_result = json.loads(result)
                    return True, parsed_result
                except:  # noqa: E722
                    return True, {'status': 'unknown', 'raw_result': result}
        except Exception as e:
            return False, str(e)
        finally:
            # Clean up the temporary file if it was created
            if temp_file_path and os.path.exists(temp_file_path):
                try:
                    os.unlink(temp_file_path)
                except:  # noqa: E722
                    pass
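
    # Note: score() below expects the service to return a list of dicts, one
    # per test case, each carrying at least a 'status' field ('OK' on pass);
    # the exact payload is defined by the remote /evaluate endpoint.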

    def _remove_prefix(self,
                       prompt: str,
                       completion: str,
                       threshold: float = 0.95) -> str:
        """Determine the truncation point in the completion based on the last
        line of the prompt, remove all content before that line in the
        completion, and return the completion string after removing the
        prefix. This is done to convert chatbot-style inference mode to
        completion mode.

        Args:
            prompt (str): The prompt text.
            completion (str): The completion text.
            threshold (float): Line similarity threshold.

        Returns:
            str: The completion string after removing the prefix.
        """
        prompt_lines = prompt.splitlines()
        completion_lines = completion.splitlines()

        if not prompt_lines:
            return completion

        last_prompt_line = prompt_lines[-1]
        cut_index = -1

        for i, completion_line in enumerate(completion_lines):
            similarity = difflib.SequenceMatcher(None, last_prompt_line,
                                                 completion_line).ratio()
            if similarity >= threshold:
                cut_index = i
                break

        if cut_index != -1:
            return '\n'.join(completion_lines[cut_index + 1:])
        else:
            return completion
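
    # Illustrative behaviour (values hypothetical): with
    #   prompt = 'def add(a, b):' and
    #   completion = 'def add(a, b):\n    return a + b'
    # the first completion line matches the last prompt line (ratio 1.0), so
    # _remove_prefix returns '    return a + b'.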

    def _process_completions(self, test_case: dict, completions: list) -> list:
        """Process a list of code completions, which typically involves
        extracting code, removing repetitive prefixes caused by chatbot mode,
        and other steps to ensure the model-generated code can be compiled
        successfully.

        Args:
            test_case (dict): Dictionary containing test case information,
                including the original prompt.
            completions (list): List of code completions generated by the model.

        Returns:
            list: Processed code completion list.
        """
        processed_completions = []
        for comp in completions:
            comp = self._extract_code(comp)
            post_comp = self._remove_prefix(test_case['prompt'], comp)
            processed_completions.append(post_comp)
        return processed_completions

    def _evaluate(
        self, input_data: Union[Dict, List]
    ) -> Tuple[bool, Optional[Union[Dict, List]], Optional[str]]:
        """Evaluate code with a retry mechanism.

        Args:
            input_data: Can be either:
                - dict: Dictionary containing code and test information for a single test case
                - list: List of dictionaries for batch evaluation

        Returns:
            tuple: (success, output, error_message)
                - success (bool): Whether the evaluation was successful
                - output (dict or list): Evaluation output (if successful)
                - error_message (str): Error message (if failed)
        """
        succeed, output = False, None  # defined even if retry <= 0
        num_retry = 0
        while num_retry < self.retry:
            succeed, output = self._code_eval_service(input_data)
            if not succeed:
                num_retry += 1
                time.sleep(10)
            else:
                break

        if not succeed:
            return False, None, f'code eval service connection failed: {output}'
        return True, output, None

    def score(self, predictions: List, references: List,
              test_set: Dataset) -> Dict:
        """Score code generation predictions against references.

        Args:
            predictions (list): List of model-generated code completions.
            references (list): List of reference solutions (not directly used in evaluation).
            test_set (Dataset): Dataset containing test cases and other metadata.

        Returns:
            dict: Evaluation results including:
                - pass@k: Percentage of problems solved, where k equals the
                  number of repeats per problem
                - details: Detailed results for each test case
                - error: Error message if evaluation failed
        """
        if len(predictions) != len(references):
            return {
                'error':
                'predictions and references have different '
                f'length. len(predictions): {len(predictions)}, '
                f'len(references): {len(references)}'
            }

        test_set = test_set.to_pandas()
        # Use the first column as the unique identifier
        test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0])
        num_repeats = int(len(test_set) / len(test_set_origin))

        # 1. Prepare data for all test cases
        all_test_cases = []
        for i in range(len(test_set_origin)):
            test_case = test_set_origin.iloc[i]
            completions = predictions[i * num_repeats:(i + 1) * num_repeats]

            # Process code completions
            processed_completions = self._process_completions(
                test_case, completions)

            result_dict = {
                'name': test_case['name'],
                'language': test_case['language'],
                'prompt': test_case['prompt'],
                'tests': test_case['tests'],
                'processed_completions': processed_completions,
                'completions': completions
            }
            all_test_cases.append(result_dict)

        # 2. Send all test cases to the evaluation service
        success, outputs, error_message = self._evaluate(all_test_cases)
        if not success:
            return {'error': error_message}

        # 3. Process the returned results
        details = []
        correct = 0
        for output in outputs:
            if output.get('status') == 'OK':
                output['correct'] = True
                correct += 1
            else:
                output['correct'] = False
            details.append(output)

        return {
            f'pass@{num_repeats}': 100 * correct / len(test_set_origin),
            'details': details
        }
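
An end-to-end sketch of how this evaluator is driven. The language and URL are taken from the dataset config at the top of this commit; `predictions`, `references`, `test_set`, and the result values are hypothetical:

# illustrative only
evaluator = CodeEvaluator(language='cpp',
                          ip_address='https://dongsheng-docker-test.hf.space')
result = evaluator.score(predictions, references, test_set)
# e.g. {'pass@1': 85.0, 'details': [...]} when each problem has one completion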

View File

@@ -193,6 +193,12 @@ DATASETS_MAPPING = {
        "hf_id": "",
        "local": "./data/mmlu_pro",
    },
    # MultiPL-E
    "opencompass/multipl_e": {
        "ms_id": "",
        "hf_id": "",
        "local": "./data/multipl_e",
    },
    # NQ
    "opencompass/natural_question": {
        "ms_id": "opencompass/natural_question",
@@ -627,6 +633,11 @@ DATASETS_URL = {
        "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mmlu_pro.zip",
        "md5": "e3200c7380f4cea5f13c768f2815fabb",
    },
    "multipl_e": {
        "url":
        "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/multipl_e.zip",
        "md5": "24462aac7a38a4a62f5c5e89eb614e20",
    },
    "/Longbench": {
        "url":
        "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/Longbench.zip",