OpenCompass/opencompass/datasets/babilong/babilong.py

# flake8: noqa: F401, E501
import json
import os

from datasets import Dataset

from opencompass.datasets.babilong.babilong_utils import compare_answers
from opencompass.datasets.babilong.prompts import (DEFAULT_PROMPTS,
                                                   DEFAULT_TEMPLATE,
                                                   get_formatted_input)
from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path


@LOAD_DATASET.register_module()
class BabiLongDataset(BaseDataset):
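    """Loader for the BABILong long-context QA benchmark.

    Each split file pairs a bAbI task ('qa1'..'qa10') with a target context
    length ('0k'..'1m'); every sample carries 'input' (context), 'question'
    and 'target' fields that are rendered into a single prompt.
    """
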
    @staticmethod
    def load(
        path,
        task,
        split_name,
        use_instruction=True,
        use_examples=True,
        use_post_prompt=True,
    ) -> Dataset:
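        """Build a prompt/answer ``Dataset`` for one (task, split_name) pair.

        ``use_instruction``, ``use_examples`` and ``use_post_prompt`` toggle
        the matching sections of the task prompt; a disabled section is
        replaced with an empty string.
        """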
        assert task in [
            'qa1', 'qa2', 'qa3', 'qa4', 'qa5',
            'qa6', 'qa7', 'qa8', 'qa9', 'qa10',
        ], "Task must be one of 'qa1' .. 'qa10'"
        assert split_name in [
            '0k', '1k', '2k', '4k', '8k', '16k',
            '32k', '64k', '128k', '256k', '512k', '1m',
        ], ("Split name must be one of '0k', '1k', '2k', '4k', '8k', '16k', "
            "'32k', '64k', '128k', '256k', '512k', '1m'")
        # configure the prompt
        prompt_cfg = {
            'instruction':
            (DEFAULT_PROMPTS[task]['instruction'] if use_instruction else ''),
            'examples':
            (DEFAULT_PROMPTS[task]['examples'] if use_examples else ''),
            'post_prompt':
            (DEFAULT_PROMPTS[task]['post_prompt'] if use_post_prompt else ''),
            'template': DEFAULT_TEMPLATE,
        }
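        # Samples are expected at <data root>/<task>/<split_name>.json.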
        path = get_data_path(path)
        file = os.path.join(path, task, f'{split_name}.json')
        with open(file, 'r') as f:
            task_data = json.load(f)
        data = []
        for sample in task_data:
            tmp_data = {'prompt': [], 'answer': []}
            target = sample['target']
            context = sample['input']
            question = sample['question']
            input_text = get_formatted_input(
                context,
                question,
                prompt_cfg['examples'],
                prompt_cfg['instruction'],
                prompt_cfg['post_prompt'],
                template=prompt_cfg['template'],
            )
            tmp_data['prompt'].append(input_text)
            tmp_data['answer'].append(target)
            data.append(tmp_data)
        return Dataset.from_list(data)
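
# Minimal usage sketch (illustrative only; the dataset root below is a
# hypothetical placeholder -- in practice OpenCompass configs supply ``path``):
#
#   ds = BabiLongDataset.load('./data/babilong', task='qa1', split_name='4k')
#   print(ds[0]['prompt'][0], ds[0]['answer'][0])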


class BabiLongEvaluator(BaseEvaluator):
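    """Scores predictions against gold answers via ``compare_answers`` and
    reports the percentage of samples judged correct (0-100)."""
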
    def score(self, predictions, gold):
        assert len(predictions) == len(gold)
        score = (sum([
            compare_answers(str(ref[0]), pred)
            for pred, ref in zip(predictions, gold)
        ]) / len(predictions) * 100)
        result = {'score': round(score, 2)}
        return result
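

# Illustrative call (hedged; ``gold`` entries are one-element answer lists, as
# produced by BabiLongDataset.load):
#   acc = BabiLongEvaluator().score(predictions=model_outputs, gold=answers)
#   # -> {'score': <percent judged correct, rounded to 2 decimals>}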