mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
code alignment
This commit is contained in:
parent
73c80953c6
commit
7ae59ef7f9
@ -233,15 +233,27 @@ class LCBCodeGenerationEvaluator(BaseEvaluator):
|
|||||||
num_process_evaluate,
|
num_process_evaluate,
|
||||||
timeout=6,
|
timeout=6,
|
||||||
release_version='release_v1',
|
release_version='release_v1',
|
||||||
extractor_version='v1'):
|
extractor_version='v1',
|
||||||
|
start_date=None,
|
||||||
|
end_date=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_process_evaluate = num_process_evaluate
|
self.num_process_evaluate = num_process_evaluate
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.dataset = LCBCodeGenerationDataset.load(
|
self.dataset = LCBCodeGenerationDataset.load(
|
||||||
release_version=release_version)['test']
|
release_version=release_version,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date)['test']
|
||||||
self.extractor_version = extractor_version
|
self.extractor_version = extractor_version
|
||||||
|
|
||||||
def score(self, predictions, references):
|
def score(self, predictions, references, test_set):
|
||||||
|
if len(predictions) != len(references):
|
||||||
|
return {
|
||||||
|
'error':
|
||||||
|
'predictions and references have different '
|
||||||
|
f'length. len(predictions): {len(predictions)}, '
|
||||||
|
f'len(references): {len(references)}'
|
||||||
|
}
|
||||||
|
|
||||||
if self.extractor_version == 'v1':
|
if self.extractor_version == 'v1':
|
||||||
predictions = [[extract_code_generation(item)]
|
predictions = [[extract_code_generation(item)]
|
||||||
for item in predictions]
|
for item in predictions]
|
||||||
@ -254,19 +266,24 @@ class LCBCodeGenerationEvaluator(BaseEvaluator):
|
|||||||
evaluation_samples[self.dataset[idx][
|
evaluation_samples[self.dataset[idx][
|
||||||
'question_id']] = self.dataset[idx]['evaluation_sample']
|
'question_id']] = self.dataset[idx]['evaluation_sample']
|
||||||
|
|
||||||
references = [evaluation_samples[item] for item in references]
|
filtered_predictions = []
|
||||||
|
filtered_references = []
|
||||||
|
for idx, item in enumerate(references):
|
||||||
|
if item in self.dataset['question_id']:
|
||||||
|
filtered_predictions.append(predictions[idx])
|
||||||
|
filtered_references.append(item)
|
||||||
|
|
||||||
references = [{'input_output': item} for item in references]
|
filtered_references = [evaluation_samples[item] for item in filtered_references] # noqa: E501
|
||||||
|
|
||||||
BaseEvaluator.is_num_equal(predictions, references)
|
filtered_references = [{'input_output': item} for item in filtered_references] # noqa: E501
|
||||||
|
|
||||||
extracted_predictions = {}
|
extracted_predictions = {}
|
||||||
for idx, content in enumerate(predictions):
|
for idx, content in enumerate(filtered_predictions):
|
||||||
extracted_predictions[idx] = content
|
extracted_predictions[idx] = content
|
||||||
|
|
||||||
metrics, eval_results, final_metadata = codegen_metrics(
|
metrics, eval_results, final_metadata = codegen_metrics(
|
||||||
references,
|
filtered_references,
|
||||||
predictions,
|
filtered_predictions,
|
||||||
k_list=[1],
|
k_list=[1],
|
||||||
num_process_evaluate=self.num_process_evaluate,
|
num_process_evaluate=self.num_process_evaluate,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
|
@ -7,6 +7,7 @@ import pickle
|
|||||||
import zlib
|
import zlib
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from datasets import DatasetDict, load_dataset, load_from_disk
|
from datasets import DatasetDict, load_dataset, load_from_disk
|
||||||
|
|
||||||
@ -53,7 +54,9 @@ class LCBCodeGenerationDataset(BaseDataset):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def load(path: str = 'opencompass/code_generation_lite',
|
def load(path: str = 'opencompass/code_generation_lite',
|
||||||
local_mode: bool = False,
|
local_mode: bool = False,
|
||||||
release_version: str = 'release_v1'):
|
release_version: str = 'release_v1',
|
||||||
|
start_date: str = None,
|
||||||
|
end_date: str = None):
|
||||||
|
|
||||||
def transform(item):
|
def transform(item):
|
||||||
# Define the dataitem mapping logic
|
# Define the dataitem mapping logic
|
||||||
@ -107,6 +110,15 @@ class LCBCodeGenerationDataset(BaseDataset):
|
|||||||
|
|
||||||
dataset = dataset.map(transform)
|
dataset = dataset.map(transform)
|
||||||
|
|
||||||
|
if start_date is not None:
|
||||||
|
p_start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
||||||
|
dataset = dataset.filter(
|
||||||
|
lambda e: p_start_date <= datetime.fromisoformat(e['contest_date'])) # noqa: E501
|
||||||
|
if end_date is not None:
|
||||||
|
p_end_date = datetime.strptime(end_date, "%Y-%m-%d")
|
||||||
|
dataset = dataset.filter(
|
||||||
|
lambda e: datetime.fromisoformat(e['contest_date']) <= p_end_date) # noqa: E501
|
||||||
|
|
||||||
return DatasetDict({'test': dataset, 'train': dataset})
|
return DatasetDict({'test': dataset, 'train': dataset})
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user