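"""Generate and visualize the CDME needle-in-a-haystack dataset.

A "needle" sentence is inserted into long Chinese contexts at varying
depths and context lengths, producing prompt/answer JSONL records for
evaluation; results can then be rendered as a score heatmap.

Example usage (the script name and paths are illustrative):

    python <this_script>.py --data_path ./data/CDME
    python <this_script>.py --plot --csv_file_paths result.csv
"""
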
import argparse
import json
import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tiktoken
from matplotlib.colors import LinearSegmentedColormap


class CDMEDatasetProcessor:
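    """Build needle-in-a-haystack records from raw JSONL corpora.

    For each requested context length and depth percent, corpus text is
    packed up to a token budget, the needle is inserted, and prompt/answer
    pairs are written as JSONL under ``output_path``.
    """
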
    def __init__(self,
                 path,
                 output_path,
                 tokenizer_model='gpt-4',
                 num_records_per_file=10,
                 length_buffer=200,
                 guided=False,
                 file_list=None):
        self.path = path
        self.output_path = output_path
        self.tokenizer = tiktoken.encoding_for_model(tokenizer_model)
        self.num_records_per_file = num_records_per_file
        self.length_buffer = length_buffer
        self.guided = guided
        # Use None as the default to avoid sharing one mutable list across
        # instances (the classic mutable-default-argument pitfall).
        self.file_list = file_list if file_list is not None else []

    def process_files(self,
                      context_lengths,
                      needle,
                      retrieval_question,
                      document_depth_percent_intervals,
                      document_depth_percent_interval_type='linear'):
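        """Process every ``*.jsonl`` file under ``self.path``.

        Only files whose base name appears in ``self.file_list`` are
        processed; everything else is skipped.
        """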
        files = Path(self.path).glob('*.jsonl')
        for file in files:
            if os.path.basename(file) in self.file_list:
                self.process_file(file, context_lengths, needle,
                                  retrieval_question,
                                  document_depth_percent_intervals,
                                  document_depth_percent_interval_type)

    def process_file(self, file, context_lengths, needle, retrieval_question,
                     document_depth_percent_intervals,
                     document_depth_percent_interval_type):
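        """Split one corpus file into records for each length/depth pair.

        ``length_buffer`` tokens are reserved out of each context length,
        and at most ``num_records_per_file`` records are written per
        output file.
        """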
        with open(file, 'r', encoding='utf-8') as f:
            lines = [json.loads(line.strip()) for line in f]

        for original_context_length in context_lengths:
            context_length = original_context_length - self.length_buffer
            target_length_per_record = context_length - len(
                self._get_tokens_from_context(needle))
            for depth_percent in self._generate_depth_percents(
                    document_depth_percent_intervals,
                    document_depth_percent_interval_type):
                output_file = (Path(self.output_path) /
                               f'Length{original_context_length}'
                               f'Depth{int(depth_percent)}' /
                               f'{file.stem}_Length{original_context_length}'
                               f'_Depth{int(depth_percent)}{file.suffix}')

                output_file.parent.mkdir(parents=True, exist_ok=True)
                with open(output_file, 'w', encoding='utf-8') as out_f:
                    counter = 0
                    accumulated_tokens = []
                    for line in lines:
                        tokens_current_line = self._get_tokens_from_context(
                            line['text'])
                        accumulated_tokens.extend(tokens_current_line)

                        if len(accumulated_tokens) >= target_length_per_record:

                            processed_text = self._generate_context(
                                accumulated_tokens[:target_length_per_record],
                                depth_percent, needle)

                            processed_prompt = self._generate_prompt(
                                processed_text, retrieval_question)
                            json.dump(
                                {
                                    'prompt': processed_prompt,
                                    'answer': needle
                                },
                                out_f,
                                ensure_ascii=False)
                            out_f.write('\n')
                            counter += 1
                            if counter >= self.num_records_per_file:
                                break
                            # Reset the accumulated tokens for the next record
                            accumulated_tokens = []

    def _generate_context(self, tokens_context, depth_percent, needle):
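        """Insert the needle into the context at ``depth_percent``%."""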
        tokens_needle = self._get_tokens_from_context(needle)

        # Insert the needle into the context at the specified depth percent
        insertion_point = int(len(tokens_context) * (depth_percent / 100))
        tokens_context = (tokens_context[:insertion_point] + tokens_needle +
                          tokens_context[insertion_point:])

        # Decode the tokens back to text
        new_context = self._decode_tokens(tokens_context)
        return new_context

    def _get_tokens_from_context(self, context):
        return self.tokenizer.encode(context)

    def _decode_tokens(self, tokens):
        return self.tokenizer.decode(tokens)

    def _generate_prompt(self, context, retrieval_question):
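        """Wrap the context and question in a Chinese instruction prompt.

        The prompt roughly reads: "You are an AI assistant good at
        answering questions. Keep your answer concise; do not digress
        from the document below or repeat yourself. The document is
        {context}. Question: {retrieval_question}". In guided mode a hint
        is appended: "The most relevant sentence in the document is ___".
        """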
        if self.guided:
            prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                      '请保持你的回答简洁清楚。不要说和下面文档中的无关的话,或重复你的回答\n'
                      f'用户现在给你的文档是{context}\n\n'
                      f'现在请问:{retrieval_question}'
                      f'提示:文档中与该问题最相关的句子是_______')
        else:
            prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                      '请保持你的回答简洁清楚。不要说和下面文档中的无关的话,或重复你的回答\n'
                      f'用户现在给你的文档是{context}\n\n'
                      f'现在请问:{retrieval_question}')
        return prompt

    def _generate_depth_percents(self, intervals, interval_type):
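        """Return ``intervals`` depth percentages spanning 0-100.

        'linear' spaces them evenly; 'sigmoid' maps the linear points
        through a logistic curve.
        """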
        if interval_type == 'linear':
            return np.linspace(0, 100, num=intervals)
        elif interval_type == 'sigmoid':
            return [self._logistic(x) for x in np.linspace(0, 100, intervals)]
        else:
            raise ValueError('Unsupported interval type')

    @staticmethod
    def _logistic(x, L=100, x0=50, k=0.1):
        return np.round(L / (1 + np.exp(-k * (x - x0))), 3)


class CDMEDataset:
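    """Static helpers to generate the CDME dataset and plot results."""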

    @staticmethod
    def generate(processed_datasets_path, data_path, tokenizer_model,
                 num_records_per_file, length_buffer, guided, file_list,
                 context_lengths, needle, retrieval_question,
                 document_depth_percent_intervals):
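        """Rebuild the processed datasets from the raw corpora.

        Any existing ``processed_datasets_path`` is deleted first so every
        run starts from a clean output tree.
        """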
        # Check if the processed datasets directory exists
        if os.path.exists(processed_datasets_path):
            shutil.rmtree(processed_datasets_path)
            print('The existing processed datasets directory '
                  f'{processed_datasets_path} has been '
                  'removed for a fresh start.')
        else:
            print('No existing processed datasets directory found at'
                  f' {processed_datasets_path}. '
                  'Starting with a fresh directory.')

        processor = CDMEDatasetProcessor(
            path=data_path,
            output_path=processed_datasets_path,
            tokenizer_model=tokenizer_model,
            num_records_per_file=num_records_per_file,
            length_buffer=length_buffer,
            guided=guided,
            file_list=file_list)

        processor.process_files(context_lengths, needle, retrieval_question,
                                document_depth_percent_intervals)

        print('Datasets have been created.')

    @staticmethod
    def visualize(csv_file_paths):
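        """Plot a depth-by-length score heatmap for each CSV file.

        Each CSV is expected to hold dataset names such as
        ``CDME_Length1000Depth15`` and the model's scores in its fifth
        column; the heatmap is saved next to the CSV as a PNG.
        """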
        for file_path in csv_file_paths:
            df = pd.read_csv(file_path)
            model_name = df.columns[4]
            # Process the data
            df['Context Length'] = df['dataset'].apply(lambda x: int(
                x.replace('CDME_', '').split('Depth')[0].replace('Length', ''))
            )
            df['Document Depth'] = df['dataset'].apply(
                lambda x: float(x.replace('CDME_', '').split('Depth')[1]))
            df = df[['Document Depth', 'Context Length', model_name]]\
                .rename(columns={model_name: 'Score'})

            # Create pivot table
            pivot_table = pd.pivot_table(df,
                                         values='Score',
                                         index=['Document Depth'],
                                         columns=['Context Length'],
                                         aggfunc='mean')

            # Create a heatmap for visualization
            cmap = LinearSegmentedColormap.from_list(
                'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F'])
            plt.figure(figsize=(17.5, 8))
            sns.heatmap(pivot_table, cmap=cmap, cbar_kws={'label': 'Score'})
            plt.title(f'{model_name} 8K Context Performance\n'
                      'Fact Retrieval Across '
                      'Context Lengths ("Needle In A Haystack")')
            plt.xlabel('Token Limit')
            plt.ylabel('Depth Percent')
            plt.xticks(rotation=45)
            plt.yticks(rotation=0)
            plt.tight_layout()

            # Save the heatmap as a PNG file
            png_file_path = file_path.replace('.csv', '.png')
            plt.savefig(png_file_path)
            plt.close()  # Close the plot to prevent memory leaks
            # Print the path to the saved PNG file
            print(f'Heatmap saved as: {png_file_path}')


def main():
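    """Command-line entry point: generate the dataset, or plot results."""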
    parser = argparse.ArgumentParser(description='Generate CDMEDataset.')

    parser.add_argument('--processed_datasets_path',
                        type=str,
                        default='./data/CDME/processed')
    parser.add_argument('--data_path', type=str, default='./data/CDME')
    parser.add_argument('--tokenizer_model', type=str, default='gpt-4')
    parser.add_argument('--num_records_per_file', type=int, default=10)
    parser.add_argument('--length_buffer', type=int, default=200)
    # `type=bool` would treat any non-empty string (even 'False') as True,
    # so expose an explicit on/off flag pair instead (Python 3.9+).
    parser.add_argument('--guided',
                        action=argparse.BooleanOptionalAction,
                        default=True)
    parser.add_argument('--file_list', nargs='*', default=['zh_finance.jsonl'])
    parser.add_argument('--context_lengths',
                        nargs='*',
                        type=int,
                        default=list(range(1000, 9000, 1000)))
    parser.add_argument('--needle',
                        type=str,
                        default='\n小明最喜欢的实习的地点就是上海人工智能实验室。\n')
    parser.add_argument('--retrieval_question',
                        type=str,
                        default='小明最喜欢的实习地点是哪里?'
                        '你的回答格式应该为“小明最喜欢的实习地点就是________。”')
    parser.add_argument('--document_depth_percent_intervals',
                        type=int,
                        default=35)
    parser.add_argument('--plot',
                        action='store_true',
                        help='Visualize the dataset results')
    parser.add_argument('--csv_file_paths',
                        nargs='*',
                        default=['path/to/your/result.csv'],
                        help='Paths to CSV files for visualization')

    args = parser.parse_args()

    if args.plot:
        if not args.csv_file_paths:
            print("Error: '--csv_file_paths' is required for visualization.")
            raise SystemExit(1)
        CDMEDataset.visualize(args.csv_file_paths)

    else:
        doc_depth_intervals = args.document_depth_percent_intervals
        CDMEDataset.generate(
            processed_datasets_path=args.processed_datasets_path,
            data_path=args.data_path,
            tokenizer_model=args.tokenizer_model,
            num_records_per_file=args.num_records_per_file,
            length_buffer=args.length_buffer,
            guided=args.guided,
            file_list=args.file_list,
            context_lengths=args.context_lengths,
            needle=args.needle,
            retrieval_question=args.retrieval_question,
            document_depth_percent_intervals=doc_depth_intervals)


if __name__ == '__main__':
    main()