2023-12-23 12:00:51 +08:00
|
|
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
|
|
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
|
|
|
from opencompass.openicl.icl_inferencer import GenInferencer
|
2024-01-02 17:22:56 +08:00
|
|
|
from opencompass.datasets.cdme.cdme import CDMEDataset
|
|
|
|
from opencompass.datasets.cdme.cdme import CDMEEvaluator
|
|
|
|
from opencompass.datasets.cdme.cdme import cdme_postprocess
|
|
|
|
from opencompass.datasets.cdme.cdme import cdme_dataset_postprocess
|
2023-12-29 18:51:09 +08:00
|
|
|
import math
|
|
|
|
|
|
|
|
|
|
|
|
def logistic(x, L=100, x0=50, k=0.1):
    """Evaluate a scaled logistic (sigmoid) curve at ``x``.

    Computes ``L / (1 + exp(-k * (x - x0)))`` and rounds the result to
    three decimal places.

    Args:
        x: Point at which to evaluate the curve.
        L: Upper asymptote of the curve.
        x0: Midpoint, where the curve equals ``L / 2``.
        k: Steepness of the curve.

    Returns:
        The curve value rounded to 3 decimals.
    """
    exponent = -k * (x - x0)
    denominator = 1 + math.exp(exponent)
    return round(L / denominator, 3)
|
|
|
|
|
|
|
|
|
|
|
|
def generate_linear_space(start, end, num):
    """Return ``num`` evenly spaced values from ``start`` to ``end`` inclusive.

    Each point is computed directly as ``start + (end - start) * i / (num - 1)``
    instead of ``start + step * i`` with a pre-divided ``step``. Multiplying a
    rounded step by ``i`` amplifies its rounding error, so the old form could
    miss the endpoint (e.g. ``(0, 1, 50)`` ended at ``0.9999999999999999``).
    That matters here because the generated depths are truncated with
    ``int()``. With this form, integer endpoints are reproduced exactly.

    Args:
        start: First value of the sequence.
        end: Last value of the sequence (included).
        num: Number of values to generate; must be at least 1.

    Returns:
        A list of ``num`` values. When ``num == 1`` the list is ``[start]``.

    Raises:
        ValueError: If ``num`` is less than 1.
    """
    if num < 1:
        raise ValueError("num must be at least 1.")
    if num == 1:
        return [start]
    span = end - start
    last_index = num - 1
    # Divide last so the error does not grow with i; i == last_index yields
    # exactly ``span`` for integer inputs, hence exactly ``end``.
    return [start + span * i / last_index for i in range(num)]
|
|
|
|
|
|
|
|
|
|
|
|
def generate_depth_percents(intervals, interval_type):
    """Return ``intervals`` needle-depth percentages between 0 and 100.

    Args:
        intervals: Number of depth values to produce.
        interval_type: ``'linear'`` for evenly spaced depths over [0, 100],
            or ``'sigmoid'`` to pass those evenly spaced points through
            ``logistic`` (which clusters them toward both ends).

    Returns:
        A list of depth percentages.

    Raises:
        ValueError: If ``interval_type`` is neither ``'linear'`` nor
            ``'sigmoid'``.
    """
    # Validate first so an unknown type is reported before any other error.
    if interval_type not in ('linear', 'sigmoid'):
        raise ValueError('Unsupported interval type')
    points = generate_linear_space(0, 100, intervals)
    if interval_type == 'sigmoid':
        points = [logistic(point) for point in points]
    return points
|
|
|
|
|
2023-12-23 12:00:51 +08:00
|
|
|
|
|
|
|
# Reader config: ``prompt`` is the input column and ``answer`` is the output
# (reference) column of each CDME sample.
cdme_reader_cfg = {'input_columns': ['prompt'], 'output_column': 'answer'}
|
|
|
|
|
|
|
|
# Inference config: the template is just the ``{prompt}`` placeholder, so the
# sample's prompt column is sent as-is. ZeroRetriever is used (presumably no
# in-context examples), and generation is capped at max_out_len=512.
cdme_infer_cfg = {
    'prompt_template': {
        'type': PromptTemplate,
        'template': '{prompt}',
    },
    'retriever': {'type': ZeroRetriever},
    'inferencer': {'type': GenInferencer, 'max_out_len': 512},
}
|
|
|
|
|
|
|
|
# Evaluation config: predictions go through cdme_postprocess and references
# through cdme_dataset_postprocess before CDMEEvaluator scores them.
# pred_role='BOT' presumably selects the model's reply turn — confirm.
cdme_eval_cfg = {
    'evaluator': {'type': CDMEEvaluator},
    'pred_postprocessor': {'type': cdme_postprocess},
    'dataset_postprocessor': {'type': cdme_dataset_postprocess},
    'pred_role': 'BOT',
}
|
2023-12-23 12:00:51 +08:00
|
|
|
|
2024-01-02 17:22:56 +08:00
|
|
|
# Context lengths to sweep: 1000 through 200000 inclusive, step 1000
# (200 values). Units presumably follow ``tokenizer_model`` — confirm.
context_lengths = [multiple * 1000 for multiple in range(1, 201)]

# How many needle depths to test per length, and how to space them.
document_depth_percent_intervals = 20
document_depth_percent_interval_type = 'linear'

# Directory holding the source documents and the files drawn from it.
base_path = './data/CDME'
file_list = ['zh_finance.jsonl']
|
2023-12-23 12:00:51 +08:00
|
|
|
# Build one dataset config per (context length, needle depth) pair. All
# entries share the same reader/infer/eval configs, file list and
# needle/question text; only ``abbr``, ``length`` and ``depth`` vary.
#
# NOTE(review): depths are truncated with int(). In 'sigmoid' mode several
# depths near 0 or 100 collapse to the same integer: logistic(100/19) and
# logistic(200/19) both truncate to 1, which yields duplicate ``abbr`` values.
# The current 'linear' setting with 20 intervals does not collide.
# NOTE(review): the needle sentence and the answer template requested in
# retrieval_question differ slightly (the needle has an extra '的'); confirm
# the evaluator tolerates this.
cdme_datasets = []
for original_context_length in context_lengths:
    for depth_percent in generate_depth_percents(
            document_depth_percent_intervals,
            document_depth_percent_interval_type):
        dataset_dict = dict(
            abbr=(f'CDME_Length{original_context_length}'
                  f'Depth{int(depth_percent)}'),
            type=CDMEDataset,
            path=base_path,
            length=original_context_length,
            depth=int(depth_percent),
            tokenizer_model='gpt-4',
            file_list=file_list,
            num_repeats_per_file=10,
            length_buffer=200,
            guide=True,
            language='Chinese',
            needle='\n小明最喜欢的实习的地点就是上海人工智能实验室。\n',
            retrieval_question=('小明最喜欢的实习地点是哪里?请按照'
                                '“小明最喜欢的实习地点就是________。”的格式回答。'),
            reader_cfg=cdme_reader_cfg,
            infer_cfg=cdme_infer_cfg,
            eval_cfg=cdme_eval_cfg,
        )
        cdme_datasets.append(dataset_dict)
|