mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Update] Change NeedleInAHaystackDataset to dynamic dataset loading (#754)
* Add NeedleInAHaystack Test
* Apply pre-commit formatting
* Update configs/eval_hf_internlm_chat_20b_cdme.py

  Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>
* add needle in haystack test
* update needle in haystack test
* update plot function in tools_needleinahaystack.py
* optimizing needleinahaystack dataset generation strategy
* modify minor formatting issues
* add English version support
* change NeedleInAHaystackDataset to dynamic loading
* change NeedleInAHaystackDataset to dynamic loading
* fix needleinahaystack test eval bug
* fix needleinahaystack config bug

---------

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>
This commit is contained in:
parent
b69fe2343b
commit
33f8df1ca3
81
configs/datasets/cdme/cdme200k.py
Normal file
81
configs/datasets/cdme/cdme200k.py
Normal file
@ -0,0 +1,81 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets.cdme.cdme import CDMEDataset
|
||||
from opencompass.datasets.cdme.cdme import CDMEEvaluator
|
||||
from opencompass.datasets.cdme.cdme import cdme_postprocess
|
||||
from opencompass.datasets.cdme.cdme import cdme_dataset_postprocess
|
||||
import math
|
||||
|
||||
|
||||
def logistic(x, L=100, x0=50, k=0.1):
    """Evaluate a logistic (sigmoid) curve at ``x``, rounded to 3 decimals.

    Computes ``L / (1 + exp(-k * (x - x0)))``.

    Args:
        x: Point at which to evaluate the curve.
        L: Scale of the curve (the value approached for large ``x`` when
            ``k`` is positive).
        x0: Midpoint, where the curve evaluates to ``L / 2``.
        k: Steepness of the curve.

    Returns:
        The curve value rounded to three decimal places.
    """
    try:
        denominator = 1 + math.exp(-k * (x - x0))
    except OverflowError:
        # math.exp raises once its argument exceeds roughly 709. The curve
        # has fully saturated by then, so use an infinite denominator
        # instead of propagating the error.
        denominator = math.inf
    return round(L / denominator, 3)
|
||||
|
||||
|
||||
def generate_linear_space(start, end, num):
    """Return ``num`` evenly spaced values from ``start`` to ``end`` inclusive.

    Args:
        start: First value of the sequence.
        end: Last value of the sequence.
        num: Number of points to generate; must be at least 1.

    Returns:
        A list of ``num`` values. When ``num`` is 1 the list holds only
        ``start``; otherwise the values are floats and the last one is
        exactly ``end``.

    Raises:
        ValueError: If ``num`` is less than 1.
    """
    if num < 1:
        raise ValueError("num must be at least 1.")
    if num == 1:
        return [start]

    step = (end - start) / (num - 1)
    points = [start + step * i for i in range(num - 1)]
    # Pin the endpoint instead of computing start + step * (num - 1):
    # floating-point rounding could leave that value slightly below ``end``,
    # and callers that truncate with int() would then see e.g. 99, not 100.
    points.append(float(end))
    return points
|
||||
|
||||
|
||||
def generate_depth_percents(intervals, interval_type):
    """Build the list of depth percentages between 0 and 100.

    Args:
        intervals: Number of depth points to generate.
        interval_type: ``'linear'`` for evenly spaced depths, or
            ``'sigmoid'`` to pass those evenly spaced points through
            ``logistic``.

    Returns:
        A list with one depth percentage per interval.

    Raises:
        ValueError: If ``interval_type`` is neither ``'linear'`` nor
            ``'sigmoid'``.
    """
    # Reject unknown modes before doing any work.
    if interval_type not in ('linear', 'sigmoid'):
        raise ValueError('Unsupported interval type')

    evenly_spaced = generate_linear_space(0, 100, intervals)
    if interval_type == 'linear':
        return evenly_spaced
    return [logistic(point) for point in evenly_spaced]
|
||||
|
||||
|
||||
# Column mapping: 'prompt' is the model input, 'answer' the reference output.
cdme_reader_cfg = {'input_columns': ['prompt'], 'output_column': 'answer'}

# Inference: the prompt column is passed through unchanged, with no
# retrieved in-context examples, to a generation-based inferencer.
cdme_infer_cfg = {
    'prompt_template': {'type': PromptTemplate, 'template': '''{prompt}'''},
    'retriever': {'type': ZeroRetriever},
    'inferencer': {'type': GenInferencer, 'max_out_len': 512},
}

# Evaluation: predictions and references are each post-processed before
# being handed to the evaluator.
cdme_eval_cfg = {
    'evaluator': {'type': CDMEEvaluator},
    'pred_postprocessor': {'type': cdme_postprocess},
    'dataset_postprocessor': {'type': cdme_dataset_postprocess},
    'pred_role': 'BOT',
}

# Context lengths from 1000 to 200000 in steps of 1000.
context_lengths = list(range(1000, 201000, 1000))
document_depth_percent_intervals = 20
document_depth_percent_interval_type = "linear"

base_path = './data/CDME'
file_list = ['zh_finance.jsonl']
cdme_datasets = []

# One dataset entry per (context length, depth) combination.
for original_context_length in context_lengths:
    for depth_percent in generate_depth_percents(
            document_depth_percent_intervals,
            document_depth_percent_interval_type):
        dataset_dict = dict(
            abbr=(f'CDME_Length{original_context_length}'
                  f'Depth{int(depth_percent)}'),
            type=CDMEDataset,
            path=base_path,
            length=original_context_length,
            # Depth is truncated to a whole percent, matching the abbr.
            depth=int(depth_percent),
            tokenizer_model='gpt-4',
            file_list=file_list,
            num_repeats_per_file=10,
            length_buffer=200,
            guide=True,
            language='Chinese',
            needle='\n小明最喜欢的实习的地点就是上海人工智能实验室。\n',
            retrieval_question='小明最喜欢的实习地点是哪里?请按照'
            '“小明最喜欢的实习地点就是________。”的格式回答。',
            reader_cfg=cdme_reader_cfg,
            infer_cfg=cdme_infer_cfg,
            eval_cfg=cdme_eval_cfg,
        )
        cdme_datasets.append(dataset_dict)
|
81
configs/datasets/cdme/cdme32k.py
Normal file
81
configs/datasets/cdme/cdme32k.py
Normal file
@ -0,0 +1,81 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets.cdme.cdme import CDMEDataset
|
||||
from opencompass.datasets.cdme.cdme import CDMEEvaluator
|
||||
from opencompass.datasets.cdme.cdme import cdme_postprocess
|
||||
from opencompass.datasets.cdme.cdme import cdme_dataset_postprocess
|
||||
import math
|
||||
|
||||
|
||||
def logistic(x, L=100, x0=50, k=0.1):
    """Evaluate a logistic (sigmoid) curve at ``x``, rounded to 3 decimals.

    Computes ``L / (1 + exp(-k * (x - x0)))``.

    Args:
        x: Point at which to evaluate the curve.
        L: Scale of the curve (the value approached for large ``x`` when
            ``k`` is positive).
        x0: Midpoint, where the curve evaluates to ``L / 2``.
        k: Steepness of the curve.

    Returns:
        The curve value rounded to three decimal places.
    """
    try:
        denominator = 1 + math.exp(-k * (x - x0))
    except OverflowError:
        # math.exp raises once its argument exceeds roughly 709. The curve
        # has fully saturated by then, so use an infinite denominator
        # instead of propagating the error.
        denominator = math.inf
    return round(L / denominator, 3)
|
||||
|
||||
|
||||
def generate_linear_space(start, end, num):
    """Return ``num`` evenly spaced values from ``start`` to ``end`` inclusive.

    Args:
        start: First value of the sequence.
        end: Last value of the sequence.
        num: Number of points to generate; must be at least 1.

    Returns:
        A list of ``num`` values. When ``num`` is 1 the list holds only
        ``start``; otherwise the values are floats and the last one is
        exactly ``end``.

    Raises:
        ValueError: If ``num`` is less than 1.
    """
    if num < 1:
        raise ValueError("num must be at least 1.")
    if num == 1:
        return [start]

    step = (end - start) / (num - 1)
    points = [start + step * i for i in range(num - 1)]
    # Pin the endpoint instead of computing start + step * (num - 1):
    # floating-point rounding could leave that value slightly below ``end``,
    # and callers that truncate with int() would then see e.g. 99, not 100.
    points.append(float(end))
    return points
|
||||
|
||||
|
||||
def generate_depth_percents(intervals, interval_type):
    """Build the list of depth percentages between 0 and 100.

    Args:
        intervals: Number of depth points to generate.
        interval_type: ``'linear'`` for evenly spaced depths, or
            ``'sigmoid'`` to pass those evenly spaced points through
            ``logistic``.

    Returns:
        A list with one depth percentage per interval.

    Raises:
        ValueError: If ``interval_type`` is neither ``'linear'`` nor
            ``'sigmoid'``.
    """
    # Reject unknown modes before doing any work.
    if interval_type not in ('linear', 'sigmoid'):
        raise ValueError('Unsupported interval type')

    evenly_spaced = generate_linear_space(0, 100, intervals)
    if interval_type == 'linear':
        return evenly_spaced
    return [logistic(point) for point in evenly_spaced]
|
||||
|
||||
|
||||
# Column mapping: 'prompt' is the model input, 'answer' the reference output.
cdme_reader_cfg = {'input_columns': ['prompt'], 'output_column': 'answer'}

# Inference: the prompt column is passed through unchanged, with no
# retrieved in-context examples, to a generation-based inferencer.
cdme_infer_cfg = {
    'prompt_template': {'type': PromptTemplate, 'template': '''{prompt}'''},
    'retriever': {'type': ZeroRetriever},
    'inferencer': {'type': GenInferencer, 'max_out_len': 512},
}

# Evaluation: predictions and references are each post-processed before
# being handed to the evaluator.
cdme_eval_cfg = {
    'evaluator': {'type': CDMEEvaluator},
    'pred_postprocessor': {'type': cdme_postprocess},
    'dataset_postprocessor': {'type': cdme_dataset_postprocess},
    'pred_role': 'BOT',
}

# Context lengths from 1000 to 32000 in steps of 1000.
context_lengths = list(range(1000, 33000, 1000))
document_depth_percent_intervals = 20
document_depth_percent_interval_type = "linear"

base_path = './data/CDME'
file_list = ['zh_finance.jsonl']
cdme_datasets = []

# One dataset entry per (context length, depth) combination.
for original_context_length in context_lengths:
    for depth_percent in generate_depth_percents(
            document_depth_percent_intervals,
            document_depth_percent_interval_type):
        dataset_dict = dict(
            abbr=(f'CDME_Length{original_context_length}'
                  f'Depth{int(depth_percent)}'),
            type=CDMEDataset,
            path=base_path,
            length=original_context_length,
            # Depth is truncated to a whole percent, matching the abbr.
            depth=int(depth_percent),
            tokenizer_model='gpt-4',
            file_list=file_list,
            num_repeats_per_file=10,
            length_buffer=200,
            guide=True,
            language='Chinese',
            needle='\n小明最喜欢的实习的地点就是上海人工智能实验室。\n',
            retrieval_question='小明最喜欢的实习地点是哪里?请按照'
            '“小明最喜欢的实习地点就是________。”的格式回答。',
            reader_cfg=cdme_reader_cfg,
            infer_cfg=cdme_infer_cfg,
            eval_cfg=cdme_eval_cfg,
        )
        cdme_datasets.append(dataset_dict)
|
@ -1,7 +1,10 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets.cdme.cdme import CDMEDataset,CDMEEvaluator,cdme_postprocess,cdme_dataset_postprocess
|
||||
from opencompass.datasets.cdme.cdme import CDMEDataset
|
||||
from opencompass.datasets.cdme.cdme import CDMEEvaluator
|
||||
from opencompass.datasets.cdme.cdme import cdme_postprocess
|
||||
from opencompass.datasets.cdme.cdme import cdme_dataset_postprocess
|
||||
import math
|
||||
|
||||
|
||||
@ -10,6 +13,10 @@ def logistic(x, L=100, x0=50, k=0.1):
|
||||
|
||||
|
||||
def generate_linear_space(start, end, num):
    """Return ``num`` evenly spaced values from ``start`` to ``end`` inclusive.

    Args:
        start: First value of the sequence.
        end: Last value of the sequence.
        num: Number of points to generate; must be at least 1.

    Returns:
        A list of ``num`` values. When ``num`` is 1 the list holds only
        ``start``; otherwise the values are floats and the last one is
        exactly ``end``.

    Raises:
        ValueError: If ``num`` is less than 1.
    """
    if num < 1:
        raise ValueError("num must be at least 1.")
    if num == 1:
        return [start]

    step = (end - start) / (num - 1)
    points = [start + step * i for i in range(num - 1)]
    # Pin the endpoint instead of computing start + step * (num - 1):
    # floating-point rounding could leave that value slightly below ``end``,
    # and callers that truncate with int() would then see e.g. 99, not 100.
    points.append(float(end))
    return points
|
||||
|
||||
@ -40,26 +47,35 @@ cdme_eval_cfg = dict(
|
||||
pred_role='BOT')
|
||||
|
||||
context_lengths = list(range(1000, 9000, 1000))
|
||||
document_depth_percent_intervals = 35
|
||||
document_depth_percent_intervals = 20
|
||||
document_depth_percent_interval_type = "linear"
|
||||
|
||||
base_path = './data/CDME/processed'
|
||||
base_path = './data/CDME'
|
||||
file_list = ['zh_finance.jsonl']
|
||||
cdme_datasets = []
|
||||
|
||||
|
||||
for original_context_length in context_lengths:
|
||||
for depth_percent in generate_depth_percents(
|
||||
document_depth_percent_intervals,
|
||||
document_depth_percent_interval_type):
|
||||
dataset_dict = dict(
|
||||
abbr=f'CDME_Length{original_context_length}'
|
||||
'Depth{int(depth_percent)}',
|
||||
type=CDMEDataset,
|
||||
path=base_path,
|
||||
length=original_context_length,
|
||||
depth=int(depth_percent),
|
||||
reader_cfg=cdme_reader_cfg,
|
||||
infer_cfg=cdme_infer_cfg,
|
||||
eval_cfg=cdme_eval_cfg
|
||||
)
|
||||
dataset_dict = {
|
||||
'abbr': f'CDME_Length{original_context_length}'
|
||||
f'Depth{int(depth_percent)}',
|
||||
'type': CDMEDataset,
|
||||
'path': base_path,
|
||||
'length': original_context_length,
|
||||
'depth': int(depth_percent),
|
||||
'tokenizer_model': 'gpt-4',
|
||||
'file_list': file_list,
|
||||
'num_repeats_per_file': 10,
|
||||
'length_buffer': 200,
|
||||
'guide': True,
|
||||
'language': 'Chinese',
|
||||
'needle': '\n小明最喜欢的实习的地点就是上海人工智能实验室。\n',
|
||||
'retrieval_question': '小明最喜欢的实习地点是哪里?请按照'
|
||||
'“小明最喜欢的实习地点就是________。”的格式回答。',
|
||||
'reader_cfg': cdme_reader_cfg,
|
||||
'infer_cfg': cdme_infer_cfg,
|
||||
'eval_cfg': cdme_eval_cfg
|
||||
}
|
||||
cdme_datasets.append(dataset_dict)
|
@ -2,7 +2,7 @@ from opencompass.models import HuggingFaceCausalLM
|
||||
|
||||
from mmengine.config import read_base
|
||||
with read_base():
|
||||
from .datasets.cdme.cdme import cdme_datasets
|
||||
from .datasets.cdme.cdme8k import cdme_datasets
|
||||
|
||||
datasets = [*cdme_datasets]
|
||||
|
||||
|
@ -1,7 +1,9 @@
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import tiktoken
|
||||
from datasets import Dataset
|
||||
|
||||
from opencompass.datasets.base import BaseDataset
|
||||
@ -13,15 +15,110 @@ from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
|
||||
class CDMEDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, length: int, depth: int):
|
||||
def load(
|
||||
path: str,
|
||||
length: int,
|
||||
depth: int,
|
||||
tokenizer_model: str,
|
||||
file_list: list[str],
|
||||
num_repeats_per_file: int,
|
||||
length_buffer: int,
|
||||
guide: bool,
|
||||
language: str,
|
||||
needle: str,
|
||||
retrieval_question: str,
|
||||
):
|
||||
data = {'prompt': [], 'answer': []}
|
||||
for file in Path(path).glob('*.jsonl'):
|
||||
tokenizer = tiktoken.encoding_for_model(tokenizer_model)
|
||||
|
||||
def _generate_context(tokens_context, depth_percent, needle):
|
||||
tokens_needle = _get_tokens_from_context(needle)
|
||||
insertion_point = int(len(tokens_context) * (depth_percent / 100))
|
||||
tokens_context = (tokens_context[:insertion_point] +
|
||||
tokens_needle + tokens_context[insertion_point:])
|
||||
new_context = _decode_tokens(tokens_context)
|
||||
return new_context
|
||||
|
||||
def _get_tokens_from_context(context):
|
||||
return tokenizer.encode(context)
|
||||
|
||||
def _decode_tokens(tokens):
|
||||
return tokenizer.decode(tokens)
|
||||
|
||||
def _modify_retrieval_question(retrieval_question):
|
||||
if language == 'Chinese':
|
||||
parts = retrieval_question.split('请按照')
|
||||
guide_retrieval_question = (parts[0] + '在回答之前,请思考文档中与此问题'
|
||||
'最相关的内容是什么。请按照' + parts[1])
|
||||
return guide_retrieval_question
|
||||
elif language == 'English':
|
||||
parts = retrieval_question.split('Please answer in the format')
|
||||
guide_retrieval_question = (
|
||||
parts[0] + 'Before answering, please consider'
|
||||
' what in the document is most relevant to this question.'
|
||||
' Please answer in the format' + parts[1])
|
||||
return guide_retrieval_question
|
||||
else:
|
||||
raise ValueError(f"Language '{language}' is not supported.")
|
||||
|
||||
def _generate_prompt(context, retrieval_question):
|
||||
if guide:
|
||||
retrieval_question = _modify_retrieval_question(
|
||||
retrieval_question)
|
||||
|
||||
if language == 'Chinese':
|
||||
prompt = ('你是一个善于回答用户问题的智能AI助手\n'
|
||||
'请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
|
||||
',或重复你的回答\n'
|
||||
f'用户现在给你的文档是{context}\n\n'
|
||||
f'现在请问:{retrieval_question}')
|
||||
elif language == 'English':
|
||||
prompt = ('You are an intelligent AI assistant skilled in '
|
||||
'answering user questions.\n'
|
||||
'Please keep your answers concise and clear. Do not'
|
||||
' talk about irrelevant topics or repeat your '
|
||||
'answers.\n'
|
||||
f'The document given to you by the user is {context}'
|
||||
f'\n\nNow, the question is: {retrieval_question}')
|
||||
else:
|
||||
raise ValueError(f"Language '{language}' is not supported.")
|
||||
|
||||
return prompt
|
||||
|
||||
files = Path(path).glob('*.jsonl')
|
||||
for file in files:
|
||||
if file.name not in file_list:
|
||||
continue
|
||||
|
||||
with open(file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = json.loads(line.strip())
|
||||
if line['length'] == length and line['depth'] == depth:
|
||||
data['prompt'].append(line['prompt'])
|
||||
data['answer'].append(line['answer'])
|
||||
lines_bak = [json.loads(line.strip()) for line in f]
|
||||
lines = lines_bak.copy()
|
||||
for counter in range(num_repeats_per_file):
|
||||
random.seed(counter)
|
||||
random.shuffle(lines)
|
||||
|
||||
context_length = length - length_buffer
|
||||
target_length_per_record = context_length - len(
|
||||
_get_tokens_from_context(needle))
|
||||
|
||||
accumulated_tokens = []
|
||||
for line in lines:
|
||||
tokens_current_line = _get_tokens_from_context(
|
||||
line['text'])
|
||||
accumulated_tokens.extend(tokens_current_line)
|
||||
|
||||
if len(accumulated_tokens) >= target_length_per_record:
|
||||
break
|
||||
|
||||
processed_text = _generate_context(
|
||||
accumulated_tokens[:target_length_per_record], depth,
|
||||
needle)
|
||||
|
||||
processed_prompt = _generate_prompt(processed_text,
|
||||
retrieval_question)
|
||||
|
||||
data['prompt'].append(processed_prompt)
|
||||
data['answer'].append(needle)
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
'prompt': data['prompt'],
|
||||
|
@ -1,176 +1,13 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import tiktoken
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
|
||||
|
||||
class CDMEDatasetProcessor:
|
||||
|
||||
def __init__(self,
|
||||
path,
|
||||
output_path,
|
||||
tokenizer_model='gpt-4',
|
||||
num_records_per_file=10,
|
||||
length_buffer=200,
|
||||
guided=False,
|
||||
file_list=[]):
|
||||
self.path = path
|
||||
self.output_path = output_path
|
||||
self.tokenizer = tiktoken.encoding_for_model(tokenizer_model)
|
||||
self.num_records_per_file = num_records_per_file
|
||||
self.length_buffer = length_buffer
|
||||
self.guided = guided
|
||||
self.file_list = file_list
|
||||
|
||||
def process_files(self,
|
||||
context_lengths,
|
||||
needle,
|
||||
retrieval_question,
|
||||
document_depth_percent_intervals,
|
||||
document_depth_percent_interval_type='linear'):
|
||||
files = Path(self.path).glob('*.jsonl')
|
||||
for file in files:
|
||||
if os.path.basename(file) in self.file_list:
|
||||
self.process_file(file, context_lengths, needle,
|
||||
retrieval_question,
|
||||
document_depth_percent_intervals,
|
||||
document_depth_percent_interval_type)
|
||||
|
||||
def process_file(self, file, context_lengths, needle, retrieval_question,
|
||||
document_depth_percent_intervals,
|
||||
document_depth_percent_interval_type):
|
||||
with open(file, 'r', encoding='utf-8') as f:
|
||||
lines = [json.loads(line.strip()) for line in f]
|
||||
|
||||
# set output file
|
||||
output_file = Path(self.output_path) / f'{file.stem}_processed.jsonl'
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as out_f:
|
||||
for original_context_length in context_lengths:
|
||||
context_length = original_context_length - self.length_buffer
|
||||
target_length_per_record = context_length - len(
|
||||
self._get_tokens_from_context(needle))
|
||||
|
||||
for depth_percent in self._generate_depth_percents(
|
||||
document_depth_percent_intervals,
|
||||
document_depth_percent_interval_type):
|
||||
|
||||
counter = 0
|
||||
accumulated_tokens = []
|
||||
for line in lines:
|
||||
tokens_current_line = self._get_tokens_from_context(
|
||||
line['text'])
|
||||
accumulated_tokens.extend(tokens_current_line)
|
||||
|
||||
if len(accumulated_tokens) >= target_length_per_record:
|
||||
processed_text = self._generate_context(
|
||||
accumulated_tokens[:target_length_per_record],
|
||||
depth_percent, needle)
|
||||
|
||||
processed_prompt = self._generate_prompt(
|
||||
processed_text, retrieval_question)
|
||||
json.dump(
|
||||
{
|
||||
'prompt': processed_prompt,
|
||||
'answer': needle,
|
||||
'length': original_context_length,
|
||||
'depth': int(depth_percent),
|
||||
},
|
||||
out_f,
|
||||
ensure_ascii=False)
|
||||
out_f.write('\n')
|
||||
counter += 1
|
||||
if counter >= self.num_records_per_file:
|
||||
break
|
||||
# reset accumulated_tokens for next record
|
||||
accumulated_tokens = []
|
||||
|
||||
def _generate_context(self, tokens_context, depth_percent, needle):
|
||||
tokens_needle = self._get_tokens_from_context(needle)
|
||||
|
||||
# Insert the needle into the context at the specified depth percent
|
||||
insertion_point = int(len(tokens_context) * (depth_percent / 100))
|
||||
tokens_context = (tokens_context[:insertion_point] + tokens_needle +
|
||||
tokens_context[insertion_point:])
|
||||
|
||||
# Decode the tokens back to text
|
||||
new_context = self._decode_tokens(tokens_context)
|
||||
return new_context
|
||||
|
||||
def _get_tokens_from_context(self, context):
|
||||
return self.tokenizer.encode(context)
|
||||
|
||||
def _decode_tokens(self, tokens):
|
||||
return self.tokenizer.decode(tokens)
|
||||
|
||||
def _generate_prompt(self, context, retrieval_question):
|
||||
if self.guided:
|
||||
prompt = ('你是一个善于回答用户问题的智能AI助手\n'
|
||||
'请保持你的回答简洁清楚。不要说和下面文档中的无关的话,或重复你的回答\n'
|
||||
f'用户现在给你的文档是{context}\n\n'
|
||||
f'现在请问:{retrieval_question}'
|
||||
f'提示:文档中与该问题最相关的句子是_______')
|
||||
else:
|
||||
prompt = ('你是一个善于回答用户问题的智能AI助手\n'
|
||||
'请保持你的回答简洁清楚。不要说和下面文档中的无关的话,或重复你的回答\n'
|
||||
f'用户现在给你的文档是{context}\n\n'
|
||||
f'现在请问:{retrieval_question}')
|
||||
return prompt
|
||||
|
||||
def _generate_depth_percents(self, intervals, interval_type):
|
||||
if interval_type == 'linear':
|
||||
return np.linspace(0, 100, num=intervals)
|
||||
elif interval_type == 'sigmoid':
|
||||
return [self._logistic(x) for x in np.linspace(0, 100, intervals)]
|
||||
else:
|
||||
raise ValueError('Unsupported interval type')
|
||||
|
||||
@staticmethod
|
||||
def _logistic(x, L=100, x0=50, k=0.1):
|
||||
return np.round(L / (1 + np.exp(-k * (x - x0))), 3)
|
||||
|
||||
|
||||
class CDMEDataset():
|
||||
|
||||
@staticmethod
|
||||
def generate(processed_datasets_path, data_path, tokenizer_model,
|
||||
num_records_per_file, length_buffer, guided, file_list,
|
||||
context_lengths, needle, retrieval_question,
|
||||
document_depth_percent_intervals):
|
||||
# Check if the processed datasets directory exists
|
||||
if os.path.exists(processed_datasets_path):
|
||||
shutil.rmtree(processed_datasets_path)
|
||||
print('The existing processed datasets directory '
|
||||
f'{processed_datasets_path} has been '
|
||||
'removed for a fresh start.')
|
||||
else:
|
||||
print('No existing processed datasets directory found at'
|
||||
f' {processed_datasets_path}. '
|
||||
'Starting with a fresh directory.')
|
||||
|
||||
processor = CDMEDatasetProcessor(
|
||||
path=data_path,
|
||||
output_path=processed_datasets_path,
|
||||
tokenizer_model=tokenizer_model,
|
||||
num_records_per_file=num_records_per_file,
|
||||
length_buffer=length_buffer,
|
||||
guided=guided,
|
||||
file_list=file_list)
|
||||
|
||||
processor.process_files(context_lengths, needle, retrieval_question,
|
||||
document_depth_percent_intervals)
|
||||
|
||||
print('Datasets has been created.')
|
||||
|
||||
@staticmethod
|
||||
def visualize(csv_file_paths):
|
||||
for file_path in csv_file_paths:
|
||||
@ -276,31 +113,9 @@ class CDMEDataset():
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Generate CDMEDataset.')
|
||||
parser = argparse.ArgumentParser(description='Generate NeedleInAHaystack'
|
||||
'Test Plots')
|
||||
|
||||
parser.add_argument('--processed_datasets_path',
|
||||
type=str,
|
||||
default='./data/CDME/processed')
|
||||
parser.add_argument('--data_path', type=str, default='./data/CDME')
|
||||
parser.add_argument('--tokenizer_model', type=str, default='gpt-4')
|
||||
parser.add_argument('--num_records_per_file', type=int, default=10)
|
||||
parser.add_argument('--length_buffer', type=int, default=200)
|
||||
parser.add_argument('--guided', type=bool, default=False)
|
||||
parser.add_argument('--file_list', nargs='*', default=['zh_finance.jsonl'])
|
||||
parser.add_argument('--context_lengths',
|
||||
nargs='*',
|
||||
type=int,
|
||||
default=list(range(1000, 9000, 1000)))
|
||||
parser.add_argument('--needle',
|
||||
type=str,
|
||||
default='\n小明最喜欢的实习的地点就是上海人工智能实验室。\n')
|
||||
parser.add_argument('--retrieval_question',
|
||||
type=str,
|
||||
default='小明最喜欢的实习地点是哪里?'
|
||||
'你的回答格式应该为“小明最喜欢的实习地点就是________。”')
|
||||
parser.add_argument('--document_depth_percent_intervals',
|
||||
type=int,
|
||||
default=35)
|
||||
parser.add_argument('--plot',
|
||||
action='store_true',
|
||||
help='Visualize the dataset results')
|
||||
@ -317,21 +132,6 @@ def main():
|
||||
exit(1)
|
||||
CDMEDataset.visualize(args.csv_file_paths)
|
||||
|
||||
else:
|
||||
doc_depth_intervals = args.document_depth_percent_intervals
|
||||
CDMEDataset.generate(
|
||||
processed_datasets_path=args.processed_datasets_path,
|
||||
data_path=args.data_path,
|
||||
tokenizer_model=args.tokenizer_model,
|
||||
num_records_per_file=args.num_records_per_file,
|
||||
length_buffer=args.length_buffer,
|
||||
guided=args.guided,
|
||||
file_list=args.file_list,
|
||||
context_lengths=args.context_lengths,
|
||||
needle=args.needle,
|
||||
retrieval_question=args.retrieval_question,
|
||||
document_depth_percent_intervals=doc_depth_intervals)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user