OpenCompass/opencompass/datasets/last_letters.py
2025-02-19 09:50:46 -05:00

37 lines
1.4 KiB
Python

import os
import re
import json
from datasets import Dataset, DatasetDict
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from .base import BaseDataset
from opencompass.utils.datasets import DEFAULT_DATA_FOLDER
from opencompass.utils.fileio import download_url
@LOAD_DATASET.register_module()
class LastLettersDataset(BaseDataset):
    """Loader for the ``last_letters`` JSON benchmark.

    The JSON file is expected to contain a top-level ``"examples"`` list;
    each entry becomes one row of the ``test`` split. When the file is not
    present locally it is downloaded from the Chain-of-Knowledge repository.
    """

    @staticmethod
    def load(path: str):
        """Load the benchmark as a ``DatasetDict`` with a single ``test`` split.

        Args:
            path: Path to a local ``last_letters.json``. It is used directly
                when it names an existing file. Otherwise the file is looked
                up under ``DEFAULT_DATA_FOLDER`` joined with the
                ``COMPASS_DATA_CACHE`` environment variable, and downloaded
                there if missing.

        Returns:
            ``DatasetDict({'test': Dataset})`` built from the ``"examples"``
            list of the JSON file.
        """
        if path and os.path.isfile(path):
            # Honour an explicitly supplied data file (previously ``path``
            # was ignored entirely).
            data_path = path
        else:
            cache_dir = os.environ.get('COMPASS_DATA_CACHE', '')
            local_path = './data/last_letters/last_letters.json'
            # NOTE: os.path.join drops DEFAULT_DATA_FOLDER when cache_dir is
            # an absolute path.
            data_path = os.path.join(DEFAULT_DATA_FOLDER, cache_dir, local_path)
            if not os.path.exists(data_path):
                dataset_url = (
                    'https://raw.githubusercontent.com/wjn1996/'
                    'Chain-of-Knowledge/refs/heads/main/tasks/Letter/'
                    'dataset/last_letters.json')
                # Downloads into the directory of data_path; presumably the
                # saved file name is the URL basename (last_letters.json),
                # matching local_path — confirm against download_url.
                download_url(dataset_url, os.path.dirname(data_path))
        with open(data_path, 'r', encoding='utf-8') as f:
            examples = list(json.load(f)["examples"])
        return DatasetDict({'test': Dataset.from_list(examples)})
@TEXT_POSTPROCESSORS.register_module('last_letters')
def last_letters_pred_postprocess(text: str) -> str:
    """Extract the predicted letter string from a model completion.

    Keeps only the text after the last occurrence of ``'answer is '`` (or
    the whole text when that phrase is absent), removes every double quote,
    single quote, period and whitespace character, then trims any remaining
    non-letter characters from both ends.

    Args:
        text: Raw model output.

    Returns:
        The cleaned prediction, e.g. ``'The answer is "nk".'`` -> ``'nk'``.
    """
    # Everything after the final "answer is " is treated as the answer.
    text = text.split('answer is ')[-1]
    # Raw string: the former non-raw pattern contained ``\.`` and ``\s``,
    # invalid string escapes that warn on Python 3.12+. ``\s`` already
    # matches ``\n``, so this character class is equivalent to the old
    # alternation ``"|'|\n|\.|\s``.
    text = re.sub(r"[\"'.\s]", "", text)
    # Strip leading/trailing characters that are not ASCII letters.
    text = re.sub(r"^[^a-zA-Z]+|[^a-zA-Z]+$", "", text)
    return text