import copy
import csv
import json
import os
from typing import List

from datasets import Dataset

from opencompass.datasets.circular import (CircularDatasetMeta,
                                           CircularEvaluator)
from opencompass.openicl.icl_evaluator import AccEvaluator, BaseEvaluator
from opencompass.openicl.icl_inferencer import GenInferencer, PPLInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path

from .base import BaseDataset


class OptionSimAccEvaluator(BaseEvaluator):
    """Accuracy evaluator that maps a free-form prediction onto one of the
    given option labels (single uppercase letters such as 'A'...'D')."""

    def __init__(self, options) -> None:
        super().__init__()
        if not all((isinstance(i, str) and i.isupper() and len(i) == 1)
                   for i in options):
            raise ValueError(
                f'Each option should be a single uppercase letter, '
                f'got {options}')

        self.options = options

    def match_any_label(self, pred, test_item):
        """Parse an option label out of ``pred``, falling back from exact
        match to progressively fuzzier matching strategies."""
        from rapidfuzz.distance import Levenshtein as L

        from opencompass.utils.text_postprocessors import \
            first_option_postprocess

        pred = pred.strip()
        # 1. exact match against an option label
        if pred in self.options:
            parsed = pred
        else:
            parsed = ''
        # 2. extract the first option letter mentioned in the prediction
        if parsed == '':
            parsed = first_option_postprocess(pred,
                                              ''.join(self.options),
                                              cushion=False)
        # 3. accept if exactly one option's text occurs in the prediction
        if parsed == '':
            possible_options = []
            for opt in self.options:
                opt_str = test_item[opt]
                if opt_str is not None and opt_str.lower() in pred.lower():
                    possible_options.append(opt)
            if len(possible_options) == 1:
                parsed = possible_options[0]
        # 4. fall back to the option with the smallest edit distance
        if parsed == '':
            dists = []
            for opt in self.options:
                opt_str = test_item[opt]
                if opt_str is None:
                    continue
                cands = [opt, opt_str, opt + '. ' + opt_str]
                d = min(L.distance(pred, cand) for cand in cands)
                dists.append((d, opt))
            if len(dists) > 0:
                parsed = min(dists)[1]
        return parsed

    def score(self, predictions: List, references: List, test_set) -> dict:
        assert len(predictions) == len(references)

        num_correct, num_total = 0, 0
        details = {}
        for index in range(len(predictions)):
            pred = predictions[index]
            refr = references[index]
            parsed = self.match_any_label(pred, test_set[index])
            num_correct += 1 if parsed == refr else 0
            num_total += 1
            details[str(index)] = {}
            details[str(index)]['pred'] = pred
            details[str(index)]['parsed'] = parsed
            details[str(index)]['refr'] = refr
            details[str(index)]['correct'] = parsed == refr
        return {'accuracy': num_correct / num_total * 100, 'details': details}
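

# Illustrative usage of OptionSimAccEvaluator (a sketch; the options and
# predictions below are made up, not part of any real dataset):
#
#     evaluator = OptionSimAccEvaluator(options=['A', 'B'])
#     test_set = [{'A': 'cat', 'B': 'dog'}, {'A': 'red', 'B': 'blue'}]
#     result = evaluator.score(predictions=['A', 'it is a dog'],
#                              references=['A', 'B'],
#                              test_set=test_set)
#     # 'A' matches exactly (step 1); 'it is a dog' contains only option
#     # B's text (step 3), so result['accuracy'] == 100.0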


# TODO: DO NOT COPY YOURSELF!!!
class CircularOptionSimAccEvaluator(OptionSimAccEvaluator):
    """OptionSimAccEvaluator that additionally reports circular-evaluation
    metrics over the rotated variants of each question."""

    def __init__(self, options, circular_pattern='circular'):
        super().__init__(options)
        self.circular_pattern = circular_pattern

    def score(self, predictions, references, test_set):
        from opencompass.datasets.circular import (get_all_possible_patterns,
                                                   get_circular_patterns,
                                                   get_origin_patterns)

        circular_patterns = {}
        circular_patterns['origin'] = get_origin_patterns(
            test_set[0]['circular_pattern'])
        circular_patterns['circular'] = get_circular_patterns(
            test_set[0]['circular_pattern'])
        if self.circular_pattern == 'all_possible':
            circular_patterns['all_possible'] = get_all_possible_patterns(
                test_set[0]['circular_pattern'])

        metrics = {}
        tmp_metrics = {}
        tmp_metrics.update({f'correct_{k}': 0 for k in circular_patterns})
        tmp_metrics.update({f'count_{k}': 0 for k in circular_patterns})
        # calculate the per-item accuracy for each pattern group
        for pred, refr, origin_item in zip(predictions, references, test_set):
            parsed = self.match_any_label(pred, origin_item)
            circular_pattern = origin_item['circular_pattern']
            for k in circular_patterns:
                if tuple(circular_pattern) in circular_patterns[k]:
                    tmp_metrics[f'correct_{k}'] += 1 if parsed == refr else 0
                    tmp_metrics[f'count_{k}'] += 1

        for k in circular_patterns:
            metrics[f'acc_{k}'] = (tmp_metrics[f'correct_{k}'] /
                                   tmp_metrics[f'count_{k}'] * 100)

        # calculate the circular accuracy: group the rotations by question
        # id, then count how many rotations of each question were answered
        # correctly
        _details = {k: {} for k in circular_patterns}
        for pred, refr, origin_item in zip(predictions, references, test_set):
            index = origin_item['qid']
            parsed = self.match_any_label(pred, origin_item)
            circular_pattern = origin_item['circular_pattern']
            for k in circular_patterns:
                if tuple(circular_pattern) in circular_patterns[k]:
                    _details[k].setdefault(index, []).append(parsed == refr)
        for k in _details:
            _details[k] = {
                index: sum(_details[k][index])
                for index in _details[k]
            }
        for k in _details:
            for j in range(1, len(circular_patterns[k]) + 1):
                count = sum(_details[k][index] >= j for index in _details[k])
                total = len(_details[k])
                if j != len(circular_patterns[k]):
                    metrics[f'more_{j}_{k}'] = count / total * 100
                else:
                    metrics[f'perf_{k}'] = count / total * 100

        # make details
        details = {}
        for index in range(len(predictions)):
            parsed = self.match_any_label(predictions[index], test_set[index])
            details[str(index)] = {}
            if 'question' in test_set[index]:
                details[str(index)]['question'] = test_set[index]['question']
            details[str(index)]['pred'] = predictions[index]
            details[str(index)]['parsed'] = parsed
            details[str(index)]['refr'] = references[index]
            details[str(index)]['correct'] = parsed == references[index]
        metrics['details'] = details
        return metrics
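

# Metric keys produced by CircularOptionSimAccEvaluator.score (derived from
# the code above): 'acc_origin' / 'acc_circular' are per-item accuracies over
# the original / rotated items; 'more_{j}_circular' is the share of questions
# with at least j correct rotations, and 'perf_circular' the share with every
# rotation answered correctly.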


@LOAD_DATASET.register_module()
class CustomDataset(BaseDataset):
    """Load a local .jsonl or .csv file as a HuggingFace Dataset."""

    @staticmethod
    def load(path, file_name=None, local_mode=False):
        path = get_data_path(path, local_mode=local_mode)
        if file_name is not None:
            path = os.path.join(path, file_name)
        if path.endswith('.jsonl'):
            with open(path, 'r', encoding='utf-8-sig') as f:
                data = [json.loads(line) for line in f]
        elif path.endswith('.csv'):
            with open(path, 'r', encoding='utf-8-sig') as f:
                reader = csv.reader(f)
                header = next(reader)
                data = [dict(zip(header, row)) for row in reader]
        else:
            raise ValueError(f'Unsupported file format: {path}')

        return Dataset.from_list(data)
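

# Illustrative usage (a sketch; the path is hypothetical):
#
#     ds = CustomDataset.load('data/my_mcq.jsonl', local_mode=True)
#     # each JSON line, e.g. {"question": "...", "A": "...", "B": "...",
#     # "answer": "A"}, becomes one row of the resulting Dataset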


@LOAD_DATASET.register_module()
class CodeCustomDataset(BaseDataset):
    """CustomDataset variant that repeats every sample ``num_repeats``
    times, e.g. to draw several generations per problem."""

    @staticmethod
    def load(path, file_name=None, local_mode=False, num_repeats=1, **kwargs):
        path = get_data_path(path, local_mode=local_mode)
        if file_name is not None:
            path = os.path.join(path, file_name)
        data = []
        if path.endswith('.jsonl'):
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    data.extend(
                        [json.loads(line.strip()) for _ in range(num_repeats)])
        elif path.endswith('.csv'):
            with open(path, 'r', encoding='utf-8-sig') as f:
                reader = csv.reader(f)
                header = next(reader)
                for row in reader:
                    data.extend(
                        [dict(zip(header, row)) for _ in range(num_repeats)])
        else:
            raise ValueError(f'Unsupported file format: {path}')

        return Dataset.from_list(data)
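

# Illustrative usage (a sketch; the path is hypothetical): with
# num_repeats=4, every problem appears four times in the Dataset, so an
# inferencer can sample four generations per problem (e.g. for pass@k):
#
#     ds = CodeCustomDataset.load('data/my_code.jsonl', local_mode=True,
#                                 num_repeats=4)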


class CircularCustomDataset(CustomDataset, metaclass=CircularDatasetMeta):
    dataset_class = CustomDataset


def stringfy_types(obj):
    """Recursively replace class objects stored under 'type' keys with
    their dotted import paths, keeping the config serializable."""
    for k, v in obj.items():
        if k == 'type':
            obj[k] = f'{v.__module__}.{v.__name__}'
        elif isinstance(v, dict):
            stringfy_types(v)
    return obj
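

# For example (a sketch), stringfy_types({'type': PromptTemplate}) returns
# {'type': 'opencompass.openicl.icl_prompt_template.PromptTemplate'}.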


def make_mcq_gen_config(meta):
    """Build a generation-based MCQ dataset config from the parsed meta."""
    if meta.get('template', None) is None:
        _human_prompt = 'Question: {question}' + ''.join(
            [f'\n{item}. {{{item}}}' for item in meta['options']])
        human_prompt = meta.get('human_prompt', _human_prompt)
        _bot_prompt = f'Answer: {{{meta["output_column"]}}}'
        bot_prompt = meta.get('bot_prompt', _bot_prompt)
        template = dict(round=[
            dict(role='HUMAN', prompt=human_prompt),
            dict(role='BOT', prompt=bot_prompt),
        ])
    else:
        template = meta['template']

    reader_cfg = dict(
        input_columns=meta['input_columns'],
        output_column=meta['output_column'],
    )
    if 'test_range' in meta:
        reader_cfg['test_range'] = meta['test_range']
    infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=template,
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer),
    )

    eval_cfg = dict(
        evaluator=dict(
            type=meta.get('evaluator', OptionSimAccEvaluator),
            **meta.get('evaluator_kwargs', {'options': meta['options']}),
        ),
        pred_role='BOT',
    )

    dataset = dict(
        abbr=meta['abbr'],
        type=CustomDataset,
        path=meta['path'],
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    )
    return dataset
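

# With the default prompts and options ['A', 'B', 'C', 'D'] (an illustrative
# sketch, assuming output_column 'answer'), the generated template renders
# each item as:
#
#     Question: {question}
#     A. {A}
#     B. {B}
#     C. {C}
#     D. {D}
#     Answer: {answer}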


def make_circular_mcq_gen_config(meta):
    """Build a generation-based circular MCQ dataset config from the parsed
    meta."""
    if meta.get('template', None) is None:
        _human_prompt = 'Question: {question}' + ''.join(
            [f'\n{item}. {{{item}}}' for item in meta['options']])
        human_prompt = meta.get('human_prompt', _human_prompt)
        _bot_prompt = f'Answer: {{{meta["output_column"]}}}'
        bot_prompt = meta.get('bot_prompt', _bot_prompt)
        template = dict(round=[
            dict(role='HUMAN', prompt=human_prompt),
            dict(role='BOT', prompt=bot_prompt),
        ])
    else:
        template = meta['template']

    reader_cfg = dict(
        input_columns=meta['input_columns'],
        output_column=meta['output_column'],
    )
    if 'test_range' in meta:
        reader_cfg['test_range'] = meta['test_range']
    infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=template,
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer),
    )

    eval_cfg = dict(
        evaluator=dict(
            type=meta.get('evaluator', CircularOptionSimAccEvaluator),
            **meta.get('evaluator_kwargs', {'options': meta['options']}),
        ),
        pred_role='BOT',
    )

    dataset = dict(
        abbr=meta['abbr'],
        type=CircularCustomDataset,
        option_keys=meta['options'],
        answer_key=meta['output_column'],
        path=meta['path'],
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    )
    return dataset
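

# Note: for 'circular-mcq', CircularCustomDataset (via CircularDatasetMeta)
# expands each question into rotated variants of its options, and
# CircularOptionSimAccEvaluator aggregates accuracy across those rotations
# (the acc_*/more_*/perf_* metrics above).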


def make_qa_gen_config(meta):
    """Build a generation-based QA dataset config from the parsed meta."""
    if meta.get('template', None) is None:
        human_prompt = meta.get('human_prompt', '{question}')
        if meta['output_column'] is None:
            # no reference answer available: the template has no BOT turn
            template = dict(round=[
                dict(role='HUMAN', prompt=human_prompt),
            ])
        else:
            bot_prompt = meta.get('bot_prompt', f'{{{meta["output_column"]}}}')
            template = dict(round=[
                dict(role='HUMAN', prompt=human_prompt),
                dict(role='BOT', prompt=bot_prompt),
            ])
    else:
        template = meta['template']
    reader_cfg = dict(
        input_columns=meta['input_columns'],
        output_column=meta['output_column'],
    )
    if 'test_range' in meta:
        reader_cfg['test_range'] = meta['test_range']
    infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=template,
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer),
    )

    eval_cfg = dict(
        evaluator=dict(
            type=meta.get('evaluator', AccEvaluator),
            **meta.get('evaluator_kwargs', {}),
        ),
        pred_role='BOT',
    )

    dataset = dict(
        abbr=meta['abbr'],
        type=CustomDataset,
        path=meta['path'],
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    )
    return dataset
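

# Note: when the data file has no 'answer' column, output_column is None and
# the template carries only the HUMAN turn; the default AccEvaluator then has
# no references to compare against, so such datasets would typically supply a
# suitable evaluator through meta['evaluator'] / meta['evaluator_kwargs'].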


def make_mcq_ppl_config(meta):
    """Build a perplexity-based MCQ dataset config from the parsed meta."""
    if meta.get('template', None) is None:
        _human_prompt = 'Question: {question}' + ''.join(
            [f'\n{item}. {{{item}}}' for item in meta['options']])
        human_prompt = meta.get('human_prompt', _human_prompt)
        _bot_prompt = f'Answer: {{{meta["output_column"]}}}'
        bot_prompt = meta.get('bot_prompt', _bot_prompt)
        # one candidate dialogue per option label
        template = {
            answer: dict(round=[
                dict(role='HUMAN', prompt=human_prompt),
                dict(
                    role='BOT',
                    prompt=bot_prompt.format(
                        **{meta['output_column']: answer}),
                ),
            ], )
            for answer in meta['options']
        }
    else:
        template = meta['template']

    reader_cfg = dict(
        input_columns=meta['input_columns'],
        output_column=meta['output_column'],
    )
    if 'test_range' in meta:
        reader_cfg['test_range'] = meta['test_range']
    infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=template,
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=PPLInferencer),
    )

    eval_cfg = dict(evaluator=dict(
        type=meta.get('evaluator', AccEvaluator),
        **meta.get('evaluator_kwargs', {}),
    ))

    dataset = dict(
        abbr=meta['abbr'],
        type=CustomDataset,
        path=meta['path'],
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    )
    return dataset
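

# Illustrative sketch of the default PPL template for options ['A', 'B'] and
# output_column 'answer': each candidate label gets its own dialogue, and
# PPLInferencer scores the candidate completions by perplexity:
#
#     {
#         'A': dict(round=[dict(role='HUMAN', prompt=human_prompt),
#                          dict(role='BOT', prompt='Answer: A')]),
#         'B': dict(round=[dict(role='HUMAN', prompt=human_prompt),
#                          dict(role='BOT', prompt='Answer: B')]),
#     }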


def make_circular_mcq_ppl_config(meta):
    """Build a perplexity-based circular MCQ dataset config from the parsed
    meta."""
    if meta.get('template', None) is None:
        _human_prompt = 'Question: {question}' + ''.join(
            [f'\n{item}. {{{item}}}' for item in meta['options']])
        human_prompt = meta.get('human_prompt', _human_prompt)
        _bot_prompt = f'Answer: {{{meta["output_column"]}}}'
        bot_prompt = meta.get('bot_prompt', _bot_prompt)
        # one candidate dialogue per option label
        template = {
            answer: dict(round=[
                dict(role='HUMAN', prompt=human_prompt),
                dict(
                    role='BOT',
                    prompt=bot_prompt.format(
                        **{meta['output_column']: answer}),
                ),
            ], )
            for answer in meta['options']
        }
    else:
        template = meta['template']

    reader_cfg = dict(
        input_columns=meta['input_columns'],
        output_column=meta['output_column'],
    )
    if 'test_range' in meta:
        reader_cfg['test_range'] = meta['test_range']
    infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=template,
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=PPLInferencer),
    )

    eval_cfg = dict(evaluator=dict(
        type=meta.get('evaluator', CircularEvaluator),
        **meta.get('evaluator_kwargs', {}),
    ))

    dataset = dict(
        abbr=meta['abbr'],
        type=CircularCustomDataset,
        option_keys=meta['options'],
        answer_key=meta['output_column'],
        path=meta['path'],
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    )
    return dataset


def parse_example_dataset(config):
    """Infer dataset meta from the data file itself, then merge it with the
    optional ``.meta.json`` file and the user config (config wins)."""
    # config -> .meta.json -> parsed_results
    path = config['path']

    # load the first sample and infer parsed_meta from it
    parsed_meta = {}
    if path.endswith('.jsonl'):
        with open(path, 'r', encoding='utf-8') as f:
            data_item = json.loads(f.readline())
    elif path.endswith('.csv'):
        with open(path, 'r', encoding='utf-8') as f:
            reader = csv.reader(f)
            header = next(reader)
            row = next(reader)
            data_item = dict(zip(header, row))
    else:
        raise ValueError(f'Unsupported ext: {path}, .jsonl or .csv required')

    parsed_meta['path'] = path
    input_columns = [i for i in data_item.keys() if i != 'answer']
    parsed_meta['input_columns'] = input_columns
    output_column = 'answer' if 'answer' in data_item else None
    parsed_meta['output_column'] = output_column
    # options are consecutive single uppercase letters starting from 'A'
    options = []
    for i in range(26):
        i = chr(ord('A') + i)
        if i in data_item:
            options.append(i)
        else:
            break
    parsed_meta['options'] = options
    abbr = os.path.basename(path).split('.')[0]
    parsed_meta['abbr'] = abbr
    parsed_meta['data_type'] = 'mcq' if len(options) > 1 else 'qa'
    parsed_meta['infer_method'] = 'gen'

    # try to read the meta json
    meta_path = config.get('meta_path', path + '.meta.json')
    if os.path.exists(meta_path):
        with open(meta_path, 'r', encoding='utf-8') as f:
            read_from_file_meta = json.load(f)
    else:
        read_from_file_meta = {}

    # get config meta
    config_meta = copy.deepcopy(config)

    # merge meta: later updates take precedence
    meta = {}
    meta.update(parsed_meta)
    meta.update(read_from_file_meta)
    meta.update(config_meta)

    return meta
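

# Illustrative sketch for a hypothetical file 'data/demo.jsonl' whose first
# line is {"question": "1+1=?", "A": "1", "B": "2", "answer": "B"}:
#
#     parse_example_dataset({'path': 'data/demo.jsonl'})
#
# infers input_columns=['question', 'A', 'B'], output_column='answer',
# options=['A', 'B'], abbr='demo', data_type='mcq', infer_method='gen'.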


def make_custom_dataset_config(config):
    """Parse a user-supplied custom dataset config and dispatch to the
    matching config builder."""
    # considered as a custom dataset
    meta = parse_example_dataset(config)
    make_config_func = {
        ('mcq', 'gen'): make_mcq_gen_config,
        ('mcq', 'ppl'): make_mcq_ppl_config,
        ('qa', 'gen'): make_qa_gen_config,
        ('circular-mcq', 'gen'): make_circular_mcq_gen_config,
        ('circular-mcq', 'ppl'): make_circular_mcq_ppl_config,
    }.get((meta['data_type'], meta['infer_method']), None)
    if make_config_func is None:
        raise ValueError(f'Unsupported dataset data_type: {meta["data_type"]}'
                         f' and infer_method: {meta["infer_method"]}')
    dataset = make_config_func(meta)
    dataset = stringfy_types(dataset)
    return dataset
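

# Illustrative usage (a sketch; the path is hypothetical):
#
#     cfg = make_custom_dataset_config({
#         'path': 'data/demo.jsonl',
#         'infer_method': 'gen',  # optional; parse_example_dataset
#                                 # defaults to 'gen'
#     })
#     # cfg is a plain dict whose 'type' entries have been converted to
#     # dotted import strings by stringfy_types, ready to drop into an
#     # OpenCompass config.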