2023-12-23 12:00:51 +08:00
|
|
|
import json
|
|
|
|
import re
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
from datasets import Dataset
|
|
|
|
|
|
|
|
from opencompass.datasets.base import BaseDataset
|
|
|
|
from opencompass.openicl import BaseEvaluator
|
|
|
|
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
|
|
|
|
|
|
|
|
|
|
|
|
@LOAD_DATASET.register_module()
class CDMEDataset(BaseDataset):
    """Dataset loader that reads prompt/answer records from JSON Lines files.

    Registered with ``LOAD_DATASET`` so that it can be referenced from
    dataset configs.
    """

    @staticmethod
    def load(path: str, length: int, depth: int):
        """Collect the records whose ``length`` and ``depth`` fields match.

        Args:
            path: Directory scanned (non-recursively) for ``*.jsonl`` files.
            length: Value a record's ``'length'`` field must equal.
            depth: Value a record's ``'depth'`` field must equal.

        Returns:
            A ``Dataset`` with two columns, ``prompt`` and ``answer``, holding
            the matching records in file order.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record lacks ``length`` or ``depth``, or a matching
                record lacks ``prompt`` or ``answer``.
        """
        prompts = []
        answers = []

        # ``Path.glob`` does not promise any particular order; sorting keeps
        # the row order of the resulting dataset reproducible across runs.
        for file in sorted(Path(path).glob('*.jsonl')):
            with open(file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # Blank (e.g. trailing) lines are not JSON documents
                        # and would make ``json.loads`` raise.
                        continue
                    record = json.loads(line)
                    if (record['length'] == length
                            and record['depth'] == depth):
                        prompts.append(record['prompt'])
                        answers.append(record['answer'])

        return Dataset.from_dict({
            'prompt': prompts,
            'answer': answers,
        })
|
|
|
|
|
|
|
|
|
|
|
|
class CDMEEvaluator(BaseEvaluator):
    """Evaluator that scores predictions by edit distance to the reference."""

    def levenshtein_distance(self, s1, s2):
        """Return the Levenshtein (edit) distance between ``s1`` and ``s2``.

        Only two rows of the dynamic-programming table are kept at a time.
        """
        # Keep the longer sequence in ``s1`` so the rows have the length of
        # the shorter one.
        if len(s1) < len(s2):
            s1, s2 = s2, s1

        if len(s2) == 0:
            return len(s1)

        above = list(range(len(s2) + 1))
        for row_no, ch_long in enumerate(s1, start=1):
            # First cell: cost of turning the first ``row_no`` items of
            # ``s1`` into an empty sequence.
            row = [row_no]
            for col, ch_short in enumerate(s2):
                insert_cost = above[col + 1] + 1
                delete_cost = row[col] + 1
                replace_cost = above[col] + (ch_long != ch_short)
                row.append(min(insert_cost, delete_cost, replace_cost))
            above = row

        return above[-1]

    def score(self, predictions, references):
        """Score each prediction against its reference on a 0-100 scale.

        All whitespace is removed from both strings before comparison. A
        pair scores ``100 * (1 - edit_distance / longer_length)``; a pair
        that is empty on both sides scores 100.

        Args:
            predictions: Predicted strings.
            references: Reference strings, one per prediction.

        Returns:
            ``{'score': <mean score>, 'details': [<per-pair dict>, ...]}``,
            or a dict with a single ``'error'`` key when the two inputs
            differ in length. The mean is 0 when there are no predictions.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different lengths'
            }

        details = []
        running_total = 0
        for raw_pred, raw_ref in zip(predictions, references):
            pred = re.sub(r'\s+', '', raw_pred)
            ref = re.sub(r'\s+', '', raw_ref)

            distance = self.levenshtein_distance(pred, ref)
            longest = max(len(pred), len(ref))
            if longest != 0:
                pair_score = 100 * (1 - distance / longest)
            else:
                # Both strings are empty after stripping: perfect match.
                pair_score = 100

            running_total += pair_score
            details.append({
                'pred': pred,
                'answer': ref,
                'edit_distance': distance,
                'score': pair_score
            })

        if predictions:
            mean_score = running_total / len(predictions)
        else:
            mean_score = 0

        return {'score': mean_score, 'details': details}
|
|
|
|
|
|
|
|
|
|
|
|
@TEXT_POSTPROCESSORS.register_module('cdme')
def cdme_postprocess(text: str) -> str:
    """Return ``text`` unchanged.

    Registered in ``TEXT_POSTPROCESSORS`` under the name ``'cdme'``; it
    applies no transformation to its input.
    """
    return text
|
|
|
|
|
|
|
|
|
|
|
|
@TEXT_POSTPROCESSORS.register_module('cdme_dataset')
def cdme_dataset_postprocess(text: str) -> str:
    """Return ``text`` unchanged.

    Registered in ``TEXT_POSTPROCESSORS`` under the name ``'cdme_dataset'``;
    it applies no transformation to its input.
    """
    return text
|