code alignment

This commit is contained in:
Dongsheng Zhu 2025-03-03 02:35:12 +00:00
parent 73c80953c6
commit 7ae59ef7f9
2 changed files with 39 additions and 10 deletions

View File

@@ -233,15 +233,27 @@ class LCBCodeGenerationEvaluator(BaseEvaluator):
num_process_evaluate, num_process_evaluate,
timeout=6, timeout=6,
release_version='release_v1', release_version='release_v1',
extractor_version='v1'): extractor_version='v1',
start_date=None,
end_date=None):
super().__init__() super().__init__()
self.num_process_evaluate = num_process_evaluate self.num_process_evaluate = num_process_evaluate
self.timeout = timeout self.timeout = timeout
self.dataset = LCBCodeGenerationDataset.load( self.dataset = LCBCodeGenerationDataset.load(
release_version=release_version)['test'] release_version=release_version,
start_date=start_date,
end_date=end_date)['test']
self.extractor_version = extractor_version self.extractor_version = extractor_version
def score(self, predictions, references): def score(self, predictions, references, test_set):
if len(predictions) != len(references):
return {
'error':
'predictions and references have different '
f'length. len(predictions): {len(predictions)}, '
f'len(references): {len(references)}'
}
if self.extractor_version == 'v1': if self.extractor_version == 'v1':
predictions = [[extract_code_generation(item)] predictions = [[extract_code_generation(item)]
for item in predictions] for item in predictions]
@@ -254,19 +266,24 @@ class LCBCodeGenerationEvaluator(BaseEvaluator):
evaluation_samples[self.dataset[idx][ evaluation_samples[self.dataset[idx][
'question_id']] = self.dataset[idx]['evaluation_sample'] 'question_id']] = self.dataset[idx]['evaluation_sample']
references = [evaluation_samples[item] for item in references] filtered_predictions = []
filtered_references = []
for idx, item in enumerate(references):
if item in self.dataset['question_id']:
filtered_predictions.append(predictions[idx])
filtered_references.append(item)
references = [{'input_output': item} for item in references] filtered_references = [evaluation_samples[item] for item in filtered_references] # noqa: E501
BaseEvaluator.is_num_equal(predictions, references) filtered_references = [{'input_output': item} for item in filtered_references] # noqa: E501
extracted_predictions = {} extracted_predictions = {}
for idx, content in enumerate(predictions): for idx, content in enumerate(filtered_predictions):
extracted_predictions[idx] = content extracted_predictions[idx] = content
metrics, eval_results, final_metadata = codegen_metrics( metrics, eval_results, final_metadata = codegen_metrics(
references, filtered_references,
predictions, filtered_predictions,
k_list=[1], k_list=[1],
num_process_evaluate=self.num_process_evaluate, num_process_evaluate=self.num_process_evaluate,
timeout=self.timeout, timeout=self.timeout,

View File

@@ -7,6 +7,7 @@ import pickle
import zlib import zlib
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from datetime import datetime
from datasets import DatasetDict, load_dataset, load_from_disk from datasets import DatasetDict, load_dataset, load_from_disk
@@ -53,7 +54,9 @@ class LCBCodeGenerationDataset(BaseDataset):
@staticmethod @staticmethod
def load(path: str = 'opencompass/code_generation_lite', def load(path: str = 'opencompass/code_generation_lite',
local_mode: bool = False, local_mode: bool = False,
release_version: str = 'release_v1'): release_version: str = 'release_v1',
start_date: str = None,
end_date: str = None):
def transform(item): def transform(item):
# Define the dataitem mapping logic # Define the dataitem mapping logic
@@ -107,6 +110,15 @@ class LCBCodeGenerationDataset(BaseDataset):
dataset = dataset.map(transform) dataset = dataset.map(transform)
if start_date is not None:
p_start_date = datetime.strptime(start_date, "%Y-%m-%d")
dataset = dataset.filter(
lambda e: p_start_date <= datetime.fromisoformat(e['contest_date'])) # noqa: E501
if end_date is not None:
p_end_date = datetime.strptime(end_date, "%Y-%m-%d")
dataset = dataset.filter(
lambda e: datetime.fromisoformat(e['contest_date']) <= p_end_date) # noqa: E501
return DatasetDict({'test': dataset, 'train': dataset}) return DatasetDict({'test': dataset, 'train': dataset})