mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
code alignment
This commit is contained in:
parent
73c80953c6
commit
7ae59ef7f9
@ -233,15 +233,27 @@ class LCBCodeGenerationEvaluator(BaseEvaluator):
|
||||
num_process_evaluate,
|
||||
timeout=6,
|
||||
release_version='release_v1',
|
||||
extractor_version='v1'):
|
||||
extractor_version='v1',
|
||||
start_date=None,
|
||||
end_date=None):
|
||||
super().__init__()
|
||||
self.num_process_evaluate = num_process_evaluate
|
||||
self.timeout = timeout
|
||||
self.dataset = LCBCodeGenerationDataset.load(
|
||||
release_version=release_version)['test']
|
||||
release_version=release_version,
|
||||
start_date=start_date,
|
||||
end_date=end_date)['test']
|
||||
self.extractor_version = extractor_version
|
||||
|
||||
def score(self, predictions, references):
|
||||
def score(self, predictions, references, test_set):
|
||||
if len(predictions) != len(references):
|
||||
return {
|
||||
'error':
|
||||
'predictions and references have different '
|
||||
f'length. len(predictions): {len(predictions)}, '
|
||||
f'len(references): {len(references)}'
|
||||
}
|
||||
|
||||
if self.extractor_version == 'v1':
|
||||
predictions = [[extract_code_generation(item)]
|
||||
for item in predictions]
|
||||
@ -254,19 +266,24 @@ class LCBCodeGenerationEvaluator(BaseEvaluator):
|
||||
evaluation_samples[self.dataset[idx][
|
||||
'question_id']] = self.dataset[idx]['evaluation_sample']
|
||||
|
||||
references = [evaluation_samples[item] for item in references]
|
||||
filtered_predictions = []
|
||||
filtered_references = []
|
||||
for idx, item in enumerate(references):
|
||||
if item in self.dataset['question_id']:
|
||||
filtered_predictions.append(predictions[idx])
|
||||
filtered_references.append(item)
|
||||
|
||||
references = [{'input_output': item} for item in references]
|
||||
filtered_references = [evaluation_samples[item] for item in filtered_references] # noqa: E501
|
||||
|
||||
BaseEvaluator.is_num_equal(predictions, references)
|
||||
filtered_references = [{'input_output': item} for item in filtered_references] # noqa: E501
|
||||
|
||||
extracted_predictions = {}
|
||||
for idx, content in enumerate(predictions):
|
||||
for idx, content in enumerate(filtered_predictions):
|
||||
extracted_predictions[idx] = content
|
||||
|
||||
metrics, eval_results, final_metadata = codegen_metrics(
|
||||
references,
|
||||
predictions,
|
||||
filtered_references,
|
||||
filtered_predictions,
|
||||
k_list=[1],
|
||||
num_process_evaluate=self.num_process_evaluate,
|
||||
timeout=self.timeout,
|
||||
|
@ -7,6 +7,7 @@ import pickle
|
||||
import zlib
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
|
||||
from datasets import DatasetDict, load_dataset, load_from_disk
|
||||
|
||||
@ -53,7 +54,9 @@ class LCBCodeGenerationDataset(BaseDataset):
|
||||
@staticmethod
|
||||
def load(path: str = 'opencompass/code_generation_lite',
|
||||
local_mode: bool = False,
|
||||
release_version: str = 'release_v1'):
|
||||
release_version: str = 'release_v1',
|
||||
start_date: str = None,
|
||||
end_date: str = None):
|
||||
|
||||
def transform(item):
|
||||
# Define the dataitem mapping logic
|
||||
@ -107,6 +110,15 @@ class LCBCodeGenerationDataset(BaseDataset):
|
||||
|
||||
dataset = dataset.map(transform)
|
||||
|
||||
if start_date is not None:
|
||||
p_start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
dataset = dataset.filter(
|
||||
lambda e: p_start_date <= datetime.fromisoformat(e['contest_date'])) # noqa: E501
|
||||
if end_date is not None:
|
||||
p_end_date = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
dataset = dataset.filter(
|
||||
lambda e: datetime.fromisoformat(e['contest_date']) <= p_end_date) # noqa: E501
|
||||
|
||||
return DatasetDict({'test': dataset, 'train': dataset})
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user