OpenCompass/opencompass/openicl/icl_evaluator/code_evaluator.py


# flake8: noqa: E501
import difflib
import os
import re
import tempfile
import time
from typing import Any, Dict, List, Optional, Tuple, Union

from datasets import Dataset
from gradio_client import Client

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS


@ICL_EVALUATORS.register_module()
class CodeEvaluator(BaseEvaluator):
"""Evaluator for code generation tasks.
This evaluator sends code to a remote evaluation service to test its
functionality against provided test cases. It handles code extraction,
processing, and result analysis.
"""

    def __init__(self,
                 language: str,
                 ip_address: str = 'localhost',
                 retry: int = 3) -> None:
        """Initialize the CodeEvaluator.

        Args:
            language (str): Programming language of the code to evaluate.
            ip_address (str, optional): IP address of the evaluation service. Defaults to 'localhost'.
            retry (int, optional): Number of retry attempts for failed connections. Defaults to 3.
        """
        self.language = language
        self.retry = retry
        self.client = Client(ip_address)
        super().__init__()

    def _extract_code(self, text: str) -> str:
        """Extract code from markdown-formatted text.

        Args:
            text (str): Text that may contain code blocks in markdown format.

        Returns:
            str: Extracted code from the first code block, or the original text if no code blocks are found.
        """
        blocks = re.findall(r'```\w*\n(.*?)```', text, re.DOTALL)
        if len(blocks) >= 1:
            text = blocks[0]
        return text
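
    # Illustrative example (hypothetical input, not part of the original file):
    # for text == "Here is the fix:\n```python\nprint('hi')\n```", the regex
    # above captures "print('hi')\n" from the first fenced block; text without
    # any fenced block is returned unchanged.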

    def _code_eval_service(
        self, input_data: Union[Dict, List,
                                str]) -> Tuple[bool, Union[Dict, List, Any]]:
        """Send code to the remote evaluation service using gradio_client and
        get the results.

        Args:
            input_data: Can be one of:
                - dict: Dictionary containing code information for a single test case
                - list: List of dictionaries for batch evaluation
                - str: File path to a code file

        Returns:
            tuple: (succeed, output)
                - succeed (bool): Whether the request was successful
                - output (dict/list/str): Evaluation results or error message
        """
        try:
            temp_file_path = None
            # Handle file path input
            if isinstance(input_data, str):
                with tempfile.NamedTemporaryFile(suffix=f'.{self.language}',
                                                 delete=False) as temp_file:
                    temp_file_path = temp_file.name
                    with open(input_data, 'r') as src_file:
                        content = src_file.read()
                    temp_file.write(content.encode())
                input_data = temp_file_path

            # Send to evaluation service
            result = self.client.predict(input_data, api_name='/evaluate')

            # Process the result
            if isinstance(result, (dict, list)):
                return True, result
            else:
                # Try to parse the result as JSON if it's a string
                try:
                    import json
                    parsed_result = json.loads(result)
                    return True, parsed_result
                except:  # noqa: E722
                    return True, {'status': 'unknown', 'raw_result': result}
        except Exception as e:
            return False, str(e)
        finally:
            # Clean up temporary file if it was created
            if temp_file_path and os.path.exists(temp_file_path):
                try:
                    os.unlink(temp_file_path)
                except:  # noqa: E722
                    pass

    def _remove_prefix(self,
                       prompt: str,
                       completion: str,
                       threshold: float = 0.95) -> str:
        """Determine the truncation point in the completion based on the last
        line of the prompt, remove all content before that line in the
        completion, and return the completion string after removing the prefix.

        This is done to convert chatbot-style inference mode to completion
        mode.

        Args:
            prompt (str): The prompt text.
            completion (str): The completion text.
            threshold (float): Line similarity threshold.

        Returns:
            str: The completion string after removing the prefix.
        """
        prompt_lines = prompt.splitlines()
        completion_lines = completion.splitlines()
        if not prompt_lines:
            return completion

        last_prompt_line = prompt_lines[-1]
        cut_index = -1
        for i, completion_line in enumerate(completion_lines):
            similarity = difflib.SequenceMatcher(None, last_prompt_line,
                                                 completion_line).ratio()
            if similarity >= threshold:
                cut_index = i
                break

        if cut_index != -1:
            return '\n'.join(completion_lines[cut_index + 1:])
        else:
            return completion
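
    # Illustrative example (hypothetical values, not part of the original
    # file): if the prompt ends with the line "def add(a, b):" and the model
    # echoes it, e.g. completion == "def add(a, b):\n    return a + b", the
    # echoed line matches the last prompt line with similarity >= 0.95, so
    # _remove_prefix keeps only "    return a + b".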

    def _process_completions(self, test_case: dict, completions: list) -> list:
        """Process code completion list, which typically involves extracting
        code, removing repetitive prefixes caused by chatbot mode, and other
        steps to ensure the model-generated code can be compiled successfully.

        Args:
            test_case (dict): Dictionary containing test case information, including the 'prompt' used to generate the completions.
            completions (list): List of code completions generated by the model.

        Returns:
            list: Processed code completion list.
        """
        processed_completions = []
        for comp in completions:
            comp = self._extract_code(comp)
            post_comp = self._remove_prefix(test_case['prompt'], comp)
            processed_completions.append(post_comp)
        return processed_completions

    def _evaluate(
        self, input_data: Union[Dict, List]
    ) -> Tuple[bool, Optional[Union[Dict, List]], Optional[str]]:
        """Evaluate code with retry mechanism.

        Args:
            input_data: Can be either:
                - dict: Dictionary containing code and test information for a single test case
                - list: List of dictionaries for batch evaluation

        Returns:
            tuple: (success, output, error_message)
                - success (bool): Whether the evaluation was successful
                - output (dict or list): Evaluation output (if successful)
                - error_message (str): Error message (if failed)
        """
        num_retry = 0
        while num_retry < self.retry:
            succeed, output = self._code_eval_service(input_data)
            if not succeed:
                num_retry += 1
                time.sleep(10)
            else:
                break

        if not succeed:
            return False, None, f'code eval service connection failed: {output}'
        return True, output, None

    def score(self, predictions: List, references: List,
              test_set: Dataset) -> Dict:
        """Score code generation predictions against references.

        Args:
            predictions (list): List of model-generated code completions.
            references (list): List of reference solutions (not directly used in evaluation).
            test_set (Dataset): Dataset containing test cases and other metadata.

        Returns:
            dict: Evaluation results including:
                - pass@k: Percentage of correctly solved problems, where k is the number of completions generated per problem
                - details: Detailed results for each test case
                - error: Error message if evaluation failed
        """
        if len(predictions) != len(references):
            return {
                'error':
                'predictions and references have different '
                f'length. len(predictions): {len(predictions)}, '
                f'len(references): {len(references)}'
            }

        test_set = test_set.to_pandas()
        # Use the first column as the unique identifier
        test_set_origin = test_set.drop_duplicates(subset=test_set.columns[0])
        num_repeats = int(len(test_set) / len(test_set_origin))

        # 1. Prepare data for all test cases
        all_test_cases = []
        for i in range(len(test_set_origin)):
            test_case = test_set_origin.iloc[i]
            completions = predictions[i * num_repeats:(i + 1) * num_repeats]

            # Process code completions
            processed_completions = self._process_completions(
                test_case, completions)

            result_dict = {
                'name': test_case['name'],
                'language': test_case['language'],
                'prompt': test_case['prompt'],
                'tests': test_case['tests'],
                'processed_completions': processed_completions,
                'completions': completions
            }
            all_test_cases.append(result_dict)

        # 2. Send all test cases to the evaluation service
        success, outputs, error_message = self._evaluate(all_test_cases)
        if not success:
            return {'error': error_message}

        # 3. Process the returned results
        details = []
        correct = 0
        for output in outputs:
            if output.get('status') == 'OK':
                output['correct'] = True
                correct += 1
            else:
                output['correct'] = False
            details.append(output)

        return {
            f'pass@{num_repeats}': 100 * correct / len(test_set_origin),
            'details': details
        }
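

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a gradio-based evaluation service exposing an '/evaluate'
# endpoint is already running at `service_addr`, and that the dataset
# provides the 'name', 'language', 'prompt' and 'tests' columns that
# score() reads above. All concrete values below are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    service_addr = 'http://localhost:5000'  # hypothetical service address
    evaluator = CodeEvaluator(language='py', ip_address=service_addr)

    test_set = Dataset.from_list([{
        'name': 'add_two_numbers',
        'language': 'py',
        'prompt': 'def add(a, b):',
        'tests': 'assert add(1, 2) == 3',
    }])
    # One completion per problem -> the reported metric key is 'pass@1'.
    predictions = ['def add(a, b):\n    return a + b']
    references = ['']  # not used directly by score(), but lengths must match

    print(evaluator.score(predictions, references, test_set))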