diff --git a/opencompass/configs/datasets/humaneval_pro/humaneval_pro.py b/opencompass/configs/datasets/humaneval_pro/humaneval_pro.py
deleted file mode 100644
index 34c73f48..00000000
--- a/opencompass/configs/datasets/humaneval_pro/humaneval_pro.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from mmengine.config import read_base
-
-with read_base():
-    from .humaneval_pro_gen_ import humanevalpro_datasets  # noqa: F401, F403
diff --git a/opencompass/configs/datasets/humaneval_pro/humaneval_pro_gen.py b/opencompass/configs/datasets/humaneval_pro/humaneval_pro_gen.py
new file mode 100644
index 00000000..9bccdd66
--- /dev/null
+++ b/opencompass/configs/datasets/humaneval_pro/humaneval_pro_gen.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .humaneval_pro_gen_3dc067 import humanevalpro_datasets  # noqa: F401, F403
diff --git a/opencompass/configs/datasets/humaneval_pro/humaneval_pro_gen_3dc067.py b/opencompass/configs/datasets/humaneval_pro/humaneval_pro_gen_3dc067.py
index 606cd8b1..e3ed8349 100644
--- a/opencompass/configs/datasets/humaneval_pro/humaneval_pro_gen_3dc067.py
+++ b/opencompass/configs/datasets/humaneval_pro/humaneval_pro_gen_3dc067.py
@@ -3,16 +3,6 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
 from opencompass.datasets import HumanevalevalProDataset, HumanevalProEvaluator, humaneval_postprocess_v2
 
-OFFICIAL_PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
-@@ Instruction
-Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
-```python
-{raw_problem}
-{new_problem}
-```
-@@ Response
-Please put the two solutions to the above problems in one Python code block.
-"""
 
 PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
 Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
diff --git a/opencompass/configs/datasets/humaneval_pro/humaneval_pro_repeat_gen_3dc067.py b/opencompass/configs/datasets/humaneval_pro/humaneval_pro_repeat_gen_3dc067.py
new file mode 100644
index 00000000..98320f78
--- /dev/null
+++ b/opencompass/configs/datasets/humaneval_pro/humaneval_pro_repeat_gen_3dc067.py
@@ -0,0 +1,48 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import HumanevalevalProDataset, HumanevalProEvaluator, humaneval_postprocess_v2
+
+
+PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
+Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
+```python
+{raw_problem}
+{new_problem}
+```
+Please put the two solutions within the Python code block provided below, and make sure that the block contains no other unrelated content:
+```python
+```
+"""
+
+
+humanevalpro_reader_cfg = dict(
+    input_columns=['raw_problem', 'new_problem'], output_column='test_code')
+
+humanevalpro_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(round=[
+            dict(
+                role='HUMAN',
+                prompt=PROMPT_WRAPPER),
+        ])),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer))
+
+humanevalpro_eval_cfg = dict(
+    evaluator=dict(type=HumanevalProEvaluator,
+                   ip_address='https://opencompass-multiple-evaluator.hf.space')
+)
+
+humanevalpro_datasets = [
+    dict(
+        abbr='humaneval_pro',
+        type=HumanevalevalProDataset,
+        path='opencompass/humaneval_pro',
+        reader_cfg=humanevalpro_reader_cfg,
+        infer_cfg=humanevalpro_infer_cfg,
+        eval_cfg=humanevalpro_eval_cfg,
+        n=5,
+        k=3)
+]
\ No newline at end of file
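The repeat config above draws n=5 samples per problem and reports results up to k=3. For reference, the sketch below shows the standard unbiased pass@k estimator that n-sample repeat evaluation of this kind is normally paired with (the removed evaluator code referenced it via `estimate_pass_at_k`). It is illustrative only and not part of the patch; the exact aggregation used by OpenCompass may differ.

```python
# Illustrative sketch, not part of the patch: standard unbiased pass@k estimator.
import numpy as np


def estimate_pass_at_k(num_samples: int, num_correct: int, k: int) -> float:
    """Probability that at least one of k drawn samples is correct, given
    num_correct of num_samples generations pass the tests."""
    if num_samples - num_correct < k:
        return 1.0
    return float(1.0 - np.prod(
        1.0 - k / np.arange(num_samples - num_correct + 1, num_samples + 1)))


# With n=5 samples per problem and 2 of them passing, pass@3 for that problem is:
print(estimate_pass_at_k(5, 2, 3))  # 0.9
```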
diff --git a/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen.py b/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen.py
index bded1658..84d45d83 100644
--- a/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen.py
+++ b/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen.py
@@ -1,4 +1,4 @@
 from mmengine.config import read_base
 
 with read_base():
-    from .mbpp_pro_gen_ import mbpppro_datasets  # noqa: F401, F403
+    from .mbpp_pro_gen_3dc067 import mbpppro_datasets  # noqa: F401, F403
diff --git a/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen_3dc067.py b/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen_3dc067.py
index 0c8a882d..c14b05cb 100644
--- a/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen_3dc067.py
+++ b/opencompass/configs/datasets/mbpp_pro/mbpp_pro_gen_3dc067.py
@@ -3,16 +3,6 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
 from opencompass.datasets import MBPPProDataset, MBPPProEvaluator
 
-OFFICIAL_PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
-@@ Instruction
-Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
-```python
-{raw_problem}
-{new_problem}
-```
-@@ Response
-Please put the two solutions to the above problems in one Python code block.
-"""
 
 PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
 Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
diff --git a/opencompass/configs/datasets/mbpp_pro/mbpp_pro_repeat_gen_3dc067.py b/opencompass/configs/datasets/mbpp_pro/mbpp_pro_repeat_gen_3dc067.py
new file mode 100644
index 00000000..631713b8
--- /dev/null
+++ b/opencompass/configs/datasets/mbpp_pro/mbpp_pro_repeat_gen_3dc067.py
@@ -0,0 +1,48 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import MBPPProDataset, MBPPProEvaluator
+
+
+PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
+Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
+```python
+{raw_problem}
+{new_problem}
+```
+Please put the two solutions within the Python code block provided below, and make sure that the block contains no other unrelated content:
+```python
+```
+"""
+
+
+mbpppro_reader_cfg = dict(
+    input_columns=['raw_problem', 'new_problem'], output_column='test_code')
+
+mbpppro_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(round=[
+            dict(
+                role='HUMAN',
+                prompt=PROMPT_WRAPPER),
+        ])),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer))
+
+mbpppro_eval_cfg = dict(
+    evaluator=dict(type=MBPPProEvaluator,
+                   ip_address='https://opencompass-multiple-evaluator.hf.space'),
+)
+
+mbpppro_datasets = [
+    dict(
+        abbr='mbpp_pro',
+        type=MBPPProDataset,
+        path='opencompass/mbpp_pro',
+        reader_cfg=mbpppro_reader_cfg,
+        infer_cfg=mbpppro_infer_cfg,
+        eval_cfg=mbpppro_eval_cfg,
+        n=5,
+        k=3)
+]
\ No newline at end of file
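For context, the `*_gen.py` wrappers above are what user-facing configs import through `read_base`. A minimal, hypothetical top-level config that runs the new repeat-style MBPP Pro dataset could look like the sketch below; the model import path is a placeholder and is not part of this diff.

```python
# Hypothetical eval config: pulls in the repeat-style dataset added above.
# The model config path is a placeholder; substitute whichever model config you use.
from mmengine.config import read_base

with read_base():
    from opencompass.configs.datasets.mbpp_pro.mbpp_pro_repeat_gen_3dc067 import \
        mbpppro_datasets  # noqa: F401, F403
    from opencompass.configs.models.hf_internlm.hf_internlm2_5_7b_chat import \
        models  # placeholder model config

datasets = mbpppro_datasets
```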
diff --git a/opencompass/configs/datasets/multipl_e/multiple_gen.py b/opencompass/configs/datasets/multipl_e/multiple_gen.py
new file mode 100644
index 00000000..b32af567
--- /dev/null
+++ b/opencompass/configs/datasets/multipl_e/multiple_gen.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .multiple_top_ten_gen_f44aaf import multiple_datasets  # noqa: F401, F403
diff --git a/opencompass/configs/datasets/multipl_e/multiple_top_ten_gen.py b/opencompass/configs/datasets/multipl_e/multiple_top_ten_gen_f44aaf.py
similarity index 97%
rename from opencompass/configs/datasets/multipl_e/multiple_top_ten_gen.py
rename to opencompass/configs/datasets/multipl_e/multiple_top_ten_gen_f44aaf.py
index 93ab2962..040c5ba5 100644
--- a/opencompass/configs/datasets/multipl_e/multiple_top_ten_gen.py
+++ b/opencompass/configs/datasets/multipl_e/multiple_top_ten_gen_f44aaf.py
@@ -32,7 +32,6 @@ multiple_datasets = [
         type=MultiplEDataset,
         abbr=f'humaneval-multiple-{lang}',
         language=lang,
-        num_repeats=1,
         path='opencompass/multipl_e',
         tag='humaneval',
         reader_cfg=multiple_reader_cfg,
@@ -46,7 +45,6 @@ multiple_datasets += [
         type=MultiplEDataset,
         abbr=f'mbpp-multiple-{lang}',
         language=lang,
-        num_repeats=1,
         path='opencompass/multipl_e',
         tag='mbpp',
         reader_cfg=multiple_reader_cfg,
diff --git a/opencompass/configs/datasets/multipl_e/multiple_top_ten_repeat_gen_f44aaf.py b/opencompass/configs/datasets/multipl_e/multiple_top_ten_repeat_gen_f44aaf.py
new file mode 100644
index 00000000..248fdf54
--- /dev/null
+++ b/opencompass/configs/datasets/multipl_e/multiple_top_ten_repeat_gen_f44aaf.py
@@ -0,0 +1,56 @@
+# Select the 10 most popular programming languages from MultiPL-E to compose the test set.
+
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import MultiplEDataset, MultiplEEvaluator
+
+
+_TOP_TEN_LANGUAGE_ = ['cpp', 'cs', 'go', 'java', 'rb', 'js', 'php', 'r', 'rs', 'sh']
+
+multiple_reader_cfg = dict(input_columns=['language', 'prompt'], output_column='tests')
+
+multiple_infer_cfg = dict(
+    prompt_template=dict(type=PromptTemplate, template='Based on the provided {language} code snippet, complete the subsequent content. The initial part of the completed code must match the provided code snippet exactly:\n{prompt}'),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer),
+)
+
+multiple_eval_cfg = {
+    lang: dict(
+        evaluator=dict(
+            type=MultiplEEvaluator,
+            language=lang,
+            ip_address='https://opencompass-multiple-evaluator.hf.space',
+        ),
+        pred_role='BOT',
+    ) for lang in _TOP_TEN_LANGUAGE_
+}
+
+multiple_datasets = [
+    dict(
+        type=MultiplEDataset,
+        abbr=f'humaneval-multiple-{lang}',
+        language=lang,
+        path='opencompass/multipl_e',
+        tag='humaneval',
+        reader_cfg=multiple_reader_cfg,
+        infer_cfg=multiple_infer_cfg,
+        eval_cfg=multiple_eval_cfg[lang],
+    ) for lang in _TOP_TEN_LANGUAGE_
+]
+
+multiple_datasets += [
+    dict(
+        type=MultiplEDataset,
+        abbr=f'mbpp-multiple-{lang}',
+        language=lang,
+        path='opencompass/multipl_e',
+        tag='mbpp',
+        reader_cfg=multiple_reader_cfg,
+        infer_cfg=multiple_infer_cfg,
+        eval_cfg=multiple_eval_cfg[lang],
+        n=5,
+        k=3
+    ) for lang in _TOP_TEN_LANGUAGE_
+]
diff --git a/opencompass/datasets/humaneval_pro.py b/opencompass/datasets/humaneval_pro.py
index 310b3b41..871b468f 100644
--- a/opencompass/datasets/humaneval_pro.py
+++ b/opencompass/datasets/humaneval_pro.py
@@ -1,7 +1,8 @@
+# flake8: noqa: E501
+
 import json
 from typing import Dict, List
 
-import numpy as np
 from datasets import Dataset
 
 from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
@@ -9,29 +10,33 @@
 from opencompass.utils import get_data_path
 
 from .base import BaseDataset
 
+PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
+Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
+```python
+{raw_problem}
+{new_problem}
+```
+Please put the two solutions within the Python code block provided below, and make sure that the block contains no other unrelated content:
+```python
+```
+"""
+
 
 class HumanevalevalProDataset(BaseDataset):
 
     @staticmethod
-    def load(path, num_repeats=1, local_mode=False):
+    def load(path, local_mode=False):
         path = get_data_path(path, local_mode=local_mode)
         dataset = []
         with open(path, encoding='utf-8') as f:
             raw_data = json.load(f)
             for data in raw_data:
-                dataset.extend([data for _ in range(num_repeats)])
+                dataset.append(data)
         return Dataset.from_list(dataset)
 
 
 class HumanevalProEvaluator(CodeEvaluator):
 
-    def _process_completions(self, test_case: dict, completions: list) -> list:
-        processed_completions = []
-        for comp in completions:
-            post_comp = self._extract_code(comp)
-            processed_completions.append(post_comp)
-        return processed_completions
-
     def score(self, predictions: List, references: List,
               test_set: Dataset) -> Dict:
         if len(predictions) != len(references):
@@ -45,52 +50,32 @@ class HumanevalProEvaluator(CodeEvaluator):
         test_set = test_set.to_pandas()
         # Use the first column as the unique identifier
         test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0])
-        num_repeats = int(len(test_set) / len(test_set_origin))
 
         # 1. Prepare data for all test cases
-        all_test_cases = []
+        all_test_cases, prompts = [], []
         for i in range(len(test_set_origin)):
             test_case = test_set_origin.iloc[i]
-            completions = predictions[i * num_repeats:(i + 1) * num_repeats]
+            completion = predictions[i]
 
             # Process code completions
-            processed_completions = self._process_completions(
-                test_case, completions)
-
+            processed_completion = self._process_completions(completion)
+            code = processed_completion + '\n' + test_case['test_code']
             sub_data_dict = {
                 'name': int(test_case['id']),
                 'language': self.language,
-                'prompt': '',
-                'tests': test_case['test_code'],
-                'processed_completions': processed_completions,
-                'completions': completions
+                'code': code,
             }
-
             all_test_cases.append(sub_data_dict)
+            prompt = PROMPT_WRAPPER.format(
+                raw_problem=test_case['raw_problem'],
+                new_problem=test_case['new_problem'])
+            prompts.append(prompt)
+
         # 2. Send all test cases to the evaluation service
         success, outputs, error_message = self._evaluate(all_test_cases)
         if not success:
             return {'error': error_message}
 
         # 3. Process the returned results
-        details = []
-        total, correct = [], []
-        for output in outputs:
-            passed = [m['status'] == 'OK' for m in output['meta_data']]
-            total.append(len(passed))
-            correct.append(sum(passed))
-            details.append(output)
-        total = np.array(total)
-        correct = np.array(correct)
-
-        pass_at_k = {
-            f'pass@{k}':
-            self.estimate_pass_at_k(total, correct, k).mean() * 100
-            for k in self.k if (total >= k).all()
-        }
-
-        return {
-            **pass_at_k,
-            'details': details,
-        }
+        return self._process_results(outputs, prompts, len(test_set_origin))
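To make the new evaluator flow concrete: for every row, `score` now formats one prompt with `PROMPT_WRAPPER` and sends a single `code` payload (the extracted completion plus the test code) to the evaluation service. The snippet below is a toy walk-through with made-up problem and prediction strings; only `PROMPT_WRAPPER` comes from the module above, and the inline split is a stand-in for `_extract_code`.

```python
# Toy illustration (made-up values); mirrors what HumanevalProEvaluator.score
# now assembles for each test case.
from opencompass.datasets.humaneval_pro import PROMPT_WRAPPER

raw_problem = 'def add(a, b):\n    """Return a + b."""\n'
new_problem = 'def add_three(a, b, c):\n    """Return a + b + c, reusing add."""\n'
test_code = 'assert add_three(1, 2, 3) == 6'
prediction = ('```python\n'
              'def add(a, b):\n    return a + b\n\n'
              'def add_three(a, b, c):\n    return add(add(a, b), c)\n'
              '```')

prompt = PROMPT_WRAPPER.format(raw_problem=raw_problem, new_problem=new_problem)
# Stand-in for self._extract_code(prediction): keep the fenced block's body.
completion = prediction.split('```python\n')[-1].split('```')[0]
code = completion + '\n' + test_code  # the 'code' field sent to the eval service
```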
diff --git a/opencompass/datasets/mbpp_pro.py b/opencompass/datasets/mbpp_pro.py
index 51a086d7..fe7d01a4 100644
--- a/opencompass/datasets/mbpp_pro.py
+++ b/opencompass/datasets/mbpp_pro.py
@@ -1,7 +1,8 @@
+# flake8: noqa: E501
+
 import json
 from typing import Dict, List
 
-import numpy as np
 from datasets import Dataset
 
 from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
@@ -9,30 +10,33 @@
 from opencompass.utils import get_data_path
 
 from .base import BaseDataset
 
+PROMPT_WRAPPER = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
+Write a solution of python file to the following problems, the solution of the second problem requires single or multiple calls to the first solution.
+```python
+{raw_problem}
+{new_problem}
+```
+Please put the two solutions within the Python code block provided below, and make sure that the block contains no other unrelated content:
+```python
+```
+"""
+
 
 class MBPPProDataset(BaseDataset):
 
     @staticmethod
-    def load(path, num_repeats=1, local_mode=False):
+    def load(path, local_mode=False):
         path = get_data_path(path, local_mode=local_mode)
         print(path)
         dataset = []
         with open(path, encoding='utf-8') as f:
             for line in f:
-                dataset.extend(
-                    [json.loads(line.strip()) for _ in range(num_repeats)])
+                dataset.append(json.loads(line.strip()))
         return Dataset.from_list(dataset)
 
 
 class MBPPProEvaluator(CodeEvaluator):
 
-    def _process_completions(self, test_case: dict, completions: list) -> list:
-        processed_completions = []
-        for comp in completions:
-            post_comp = self._extract_code(comp)
-            processed_completions.append(post_comp)
-        return processed_completions
-
     def score(self, predictions: List, references: List,
              test_set: Dataset) -> Dict:
         if len(predictions) != len(references):
@@ -46,52 +50,32 @@
         test_set = test_set.to_pandas()
         # Use the first column as the unique identifier
         test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0])
-        num_repeats = int(len(test_set) / len(test_set_origin))
 
         # 1. Prepare data for all test cases
-        all_test_cases = []
+        all_test_cases, prompts = [], []
         for i in range(len(test_set_origin)):
             test_case = test_set_origin.iloc[i]
-            completions = predictions[i * num_repeats:(i + 1) * num_repeats]
+            completion = predictions[i]
 
             # Process code completions
-            processed_completions = self._process_completions(
-                test_case, completions)
-
+            processed_completion = self._process_completions(completion)
+            code = processed_completion + '\n' + test_case['test_code']
             sub_data_dict = {
                 'name': int(test_case['id']),
                 'language': self.language,
-                'prompt': '',
-                'tests': test_case['test_code'],
-                'processed_completions': processed_completions,
-                'completions': completions
+                'code': code,
             }
-
             all_test_cases.append(sub_data_dict)
+            prompt = PROMPT_WRAPPER.format(
+                raw_problem=test_case['raw_problem'],
+                new_problem=test_case['new_problem'])
+            prompts.append(prompt)
+
         # 2. Send all test cases to the evaluation service
         success, outputs, error_message = self._evaluate(all_test_cases)
         if not success:
             return {'error': error_message}
 
         # 3. Process the returned results
-        details = []
-        total, correct = [], []
-        for output in outputs:
-            passed = [m['status'] == 'OK' for m in output['meta_data']]
-            total.append(len(passed))
-            correct.append(sum(passed))
-            details.append(output)
-        total = np.array(total)
-        correct = np.array(correct)
-
-        pass_at_k = {
-            f'pass@{k}':
-            self.estimate_pass_at_k(total, correct, k).mean() * 100
-            for k in self.k if (total >= k).all()
-        }
-
-        return {
-            **pass_at_k,
-            'details': details,
-        }
+        return self._process_results(outputs, prompts, len(test_set_origin))
diff --git a/opencompass/datasets/multipl_e.py b/opencompass/datasets/multipl_e.py
index 657b52de..893d3911 100644
--- a/opencompass/datasets/multipl_e.py
+++ b/opencompass/datasets/multipl_e.py
@@ -1,3 +1,4 @@
+import difflib
 import json
 import os.path as osp
 
@@ -28,7 +29,6 @@ class MultiplEDataset(BaseDataset):
     @staticmethod
     def load(path: str,
              language: str,
-             num_repeats: int = 1,
             tag: str = 'humaneval',
             local_mode: bool = False):
         """Load dataset for pass k mode.
@@ -56,8 +56,7 @@ class MultiplEDataset(BaseDataset):
         dataset = []
         with open(file_path, 'r', encoding='utf-8') as f:
             for line in f:
-                dataset.extend(
-                    [json.loads(line.strip()) for _ in range(num_repeats)])
+                dataset.append(json.loads(line.strip()))
         return Dataset.from_list(dataset)
 
 
@@ -84,20 +83,56 @@ class MultiplEEvaluator(CodeEvaluator):
             min_stop_index = stop_index
         return decoded_string[:min_stop_index]
 
-    def _process_completions(self, test_case, completions):
+    def _remove_prefix(self,
+                       prompt: str,
+                       completion: str,
+                       threshold: float = 0.95) -> str:
+        """Determine the truncation point in the completion based on the last
+        line of the prompt, remove all content before that line in the
+        completion, and return the completion string after removing the prefix.
+        This is done to convert chatbot-style inference mode to completion
+        mode.
+
+        Args:
+            prompt (str): The prompt text.
+            completion (str): The completion text.
+            threshold (float): Line similarity threshold.
+
+        Returns:
+            str: The completion string after removing the prefix.
+        """
+        prompt_lines = prompt.splitlines()
+        completion_lines = completion.splitlines()
+
+        if not prompt_lines:
+            return completion
+
+        last_prompt_line = prompt_lines[-1]
+        cut_index = -1
+
+        for i, completion_line in enumerate(completion_lines):
+            similarity = difflib.SequenceMatcher(None, last_prompt_line,
+                                                 completion_line).ratio()
+            if similarity >= threshold:
+                cut_index = i
+                break
+
+        if cut_index != -1:
+            return '\n'.join(completion_lines[cut_index + 1:])
+        else:
+            return completion
+
+    def _process_completions(self, test_case, completion):
         """Process completions with a test case.
 
         Args:
-            test_case: A test case.
-            completions: A list of completions.
+            test_case (dict): A test case containing prompt and stop tokens.
+            completion (str): The generated code completion.
         Returns:
-            A list of processed completions.
+            str: Processed code completion.
         """
-        processed_completions = []
-        for comp in completions:
-            comp = self._extract_code(comp)
-            post_comp = self._remove_prefix(test_case['prompt'], comp)
-            post_comp = self._stop_at_stop_token(post_comp,
-                                                 test_case['stop_tokens'])
-            processed_completions.append(post_comp)
-        return processed_completions
+        post_comp = self._extract_code(completion)
+        post_comp = self._stop_at_stop_token(post_comp,
+                                             test_case['stop_tokens'])
+        post_comp = self._remove_prefix(test_case['prompt'], post_comp)
+        return post_comp
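Note that `_remove_prefix` has moved from the base `CodeEvaluator` into `MultiplEEvaluator`, and it is now applied after `_stop_at_stop_token` rather than before. Below is a self-contained sketch of the trimming logic (the same algorithm, simplified into a standalone function) with a toy prompt/completion pair; it is for illustration only.

```python
# Standalone sketch of MultiplEEvaluator._remove_prefix: drop everything in the
# completion up to and including the line most similar to the prompt's last
# line, so chat-style echoes of the prompt do not end up in the final program.
import difflib


def remove_prefix(prompt: str, completion: str, threshold: float = 0.95) -> str:
    prompt_lines = prompt.splitlines()
    completion_lines = completion.splitlines()
    if not prompt_lines:
        return completion
    last_prompt_line = prompt_lines[-1]
    for i, line in enumerate(completion_lines):
        if difflib.SequenceMatcher(None, last_prompt_line,
                                   line).ratio() >= threshold:
            return '\n'.join(completion_lines[i + 1:])
    return completion


prompt = 'def add(a, b):\n    """Return a + b."""'
completion = 'def add(a, b):\n    """Return a + b."""\n    return a + b'
print(remove_prefix(prompt, completion))  # '    return a + b'
```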
""" - processed_completions = [] - for comp in completions: - comp = self._extract_code(comp) - post_comp = self._remove_prefix(test_case['prompt'], comp) - post_comp = self._stop_at_stop_token(post_comp, - test_case['stop_tokens']) - processed_completions.append(post_comp) - return processed_completions + post_comp = self._extract_code(completion) + post_comp = self._stop_at_stop_token(post_comp, + test_case['stop_tokens']) + post_comp = self._remove_prefix(test_case['prompt'], post_comp) + return post_comp diff --git a/opencompass/openicl/icl_evaluator/code_evaluator.py b/opencompass/openicl/icl_evaluator/code_evaluator.py index b79488cb..70502d1e 100644 --- a/opencompass/openicl/icl_evaluator/code_evaluator.py +++ b/opencompass/openicl/icl_evaluator/code_evaluator.py @@ -1,7 +1,5 @@ # flake8: noqa: E501 -import difflib -import itertools import os import re import tempfile @@ -73,6 +71,7 @@ class CodeEvaluator(BaseEvaluator): - output (dict/list/str): Evaluation results or error message """ try: + import requests temp_file_path = None # Handle file path input if isinstance(input_data, str): @@ -85,7 +84,15 @@ class CodeEvaluator(BaseEvaluator): input_data = temp_file_path # Send to evaluation service - result = self.client.predict(input_data, api_name='/evaluate') + try: + result = self.client.predict(input_data, api_name='/evaluate') + except Exception as e: + # Catch timeout and other exceptions + if 'timed out' in str(e).lower() or 'timeout' in str( + e).lower(): + return False, f'Request to code eval service timed out: {e}' + else: + raise # Process the result if isinstance(result, (dict, list)): @@ -109,63 +116,16 @@ class CodeEvaluator(BaseEvaluator): except: # noqa: E722 pass - def _remove_prefix(self, - prompt: str, - completion: str, - threshold: float = 0.95) -> str: - """Determine the truncation point in the completion based on the last - line of the prompt, remove all content before that line in the - completion, and return the completion string after removing the prefix. - This is done to convert chatbot-style inference mode to completion - mode. + def _process_completions(self, completion: str) -> list: + """Process code completions to extract the relevant code. Args: - prompt (str): The prompt text. - completion (str): The completion text. - threshold (float): Line similarity threshold. - + completion (str): Code completion string. Returns: - str: The completion string after removing the prefix. + list: List of processed code completions. """ - prompt_lines = prompt.splitlines() - completion_lines = completion.splitlines() - - if not prompt_lines: - return completion - - last_prompt_line = prompt_lines[-1] - cut_index = -1 - - for i, completion_line in enumerate(completion_lines): - similarity = difflib.SequenceMatcher(None, last_prompt_line, - completion_line).ratio() - if similarity >= threshold: - cut_index = i - break - - if cut_index != -1: - return '\n'.join(completion_lines[cut_index + 1:]) - else: - return completion - - def _process_completions(self, test_case: dict, completions: list) -> list: - """Process code completion list, which typically involves extracting - code, removing repetitive prefixes caused by chatbot mode, and other - steps to ensure the model-generated code can be compiled successfully. - - Args: - test_case (dict): Dictionary containing test case information including: - completions (list): List of code completions generated by the model. - - Returns: - list: Processed code completion list. 
- """ - processed_completions = [] - for comp in completions: - comp = self._extract_code(comp) - post_comp = self._remove_prefix(test_case['prompt'], comp) - processed_completions.append(post_comp) - return processed_completions + post_comp = self._extract_code(completion) + return post_comp def _evaluate( self, input_data: Union[Dict, List] @@ -197,6 +157,31 @@ class CodeEvaluator(BaseEvaluator): return True, output, None + def _process_results(self, outputs: List, prompts: List, + total_count: int) -> Dict: + """Process the evaluation results. + Args: + outputs (list): List of evaluation results for each test case. + prompts (list): List of prompts used for each test case. + total_count (int): Total number of test cases. + Returns: + dict: Processed results including: + - pass@1: Percentage of test cases passed + - details: Detailed results for each test case + """ + details = [] + correct = 0 + for output, prompt in zip(outputs, prompts): + output['prompt'] = prompt + if output.get('status') == 'OK': + output['correct'] = True + correct += 1 + else: + output['correct'] = False + details.append(output) + + return {f'pass@1': 100 * correct / total_count, 'details': details} + def score(self, predictions: List, references: List, test_set: Dataset) -> Dict: """Score code generation predictions against references. @@ -225,25 +210,23 @@ class CodeEvaluator(BaseEvaluator): test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0]) # 1. Prepare data for all test cases - all_test_cases = [] + all_test_cases, prompts = [], [] for i in range(len(test_set_origin)): test_case = test_set_origin.iloc[i] - completions = predictions[i] + completion = predictions[i] # Process code completions - processed_completions = self._process_completions( - test_case, completions) - + processed_completion = self._process_completions( + test_case, completion) + code = test_case[ + 'prompt'] + processed_completion + '\n' + test_case['tests'] sub_data_dict = { 'name': test_case['name'], 'language': test_case['language'], - 'prompt': test_case['prompt'], - 'tests': test_case['tests'], - 'processed_completions': processed_completions, - 'completions': completions + 'code': code } - all_test_cases.append(sub_data_dict) + prompts.append(test_case['prompt']) # 2. Send all test cases to the evaluation service success, outputs, error_message = self._evaluate(all_test_cases) @@ -251,18 +234,4 @@ class CodeEvaluator(BaseEvaluator): return {'error': error_message} # 3. Process the returned results - details = [] - correct = 0 - for output in outputs: - if output.get('status') == 'OK': - output['correct'] = True - correct += 1 - else: - output['correct'] = False - - details.append(output) - - return { - f'pass@1': 100 * correct / len(test_set_origin), - 'details': details - } \ No newline at end of file + return self._process_results(outputs, prompts, len(test_set_origin))
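The shared result handling now lives in `_process_results`, which assumes each service output carries a `status` field (`'OK'` on success) and attaches the originating prompt to the per-case detail. Below is a toy run of that contract; the output dicts are simplified and the real evaluation service may return additional fields.

```python
# Toy illustration of the _process_results contract added above; the exact
# payload returned by the evaluation service may contain more fields.
outputs = [
    {'name': 0, 'status': 'OK'},
    {'name': 1, 'status': 'EXECUTION ERROR'},
]
prompts = ['prompt for problem 0', 'prompt for problem 1']

details, correct = [], 0
for output, prompt in zip(outputs, prompts):
    output['prompt'] = prompt
    output['correct'] = output.get('status') == 'OK'
    correct += int(output['correct'])
    details.append(output)

result = {'pass@1': 100 * correct / len(outputs), 'details': details}
print(result['pass@1'])  # 50.0
```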