# OpenCompass/tools/tools_needleinahaystack.py
# (338 lines, 14 KiB, Python)

import argparse
import json
import os
import shutil
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tiktoken
from matplotlib.colors import LinearSegmentedColormap
class CDMEDatasetProcessor:
    """Turn JSONL text corpora into "needle in a haystack" retrieval prompts.

    The ``text`` field of consecutive source lines is tokenized and
    concatenated until a token budget is reached.  A "needle" sentence is
    then spliced in at a given depth and the result is wrapped in a question
    prompt.  One JSON record per prompt is written to
    ``<output_path>/<source stem>_processed.jsonl``.
    """

    def __init__(self,
                 path,
                 output_path,
                 tokenizer_model='gpt-4',
                 num_records_per_file=10,
                 length_buffer=200,
                 guided=False,
                 file_list=None):
        """Store the configuration and load the tokenizer.

        Args:
            path: Directory that is searched for ``*.jsonl`` source files.
            output_path: Directory that receives the processed files.
            tokenizer_model: Model name handed to
                ``tiktoken.encoding_for_model``.
            num_records_per_file: Upper bound on the records written for each
                (context length, depth) combination.
            length_buffer: Number of tokens subtracted from every requested
                context length before the haystack is sized.
            guided: When true, a hint sentence is appended to each prompt.
            file_list: Base names of the files to process.  ``None`` (the
                default) means "no files", exactly like an empty list.
        """
        self.path = path
        self.output_path = output_path
        self.tokenizer = tiktoken.encoding_for_model(tokenizer_model)
        self.num_records_per_file = num_records_per_file
        self.length_buffer = length_buffer
        self.guided = guided
        # A literal ``[]`` default would be one list shared by every instance.
        self.file_list = [] if file_list is None else file_list

    def process_files(self,
                      context_lengths,
                      needle,
                      retrieval_question,
                      document_depth_percent_intervals,
                      document_depth_percent_interval_type='linear'):
        """Process every ``*.jsonl`` file in ``self.path`` named in
        ``self.file_list``; see :meth:`process_file` for the arguments."""
        for file in Path(self.path).glob('*.jsonl'):
            if file.name in self.file_list:
                self.process_file(file, context_lengths, needle,
                                  retrieval_question,
                                  document_depth_percent_intervals,
                                  document_depth_percent_interval_type)

    def process_file(self, file, context_lengths, needle, retrieval_question,
                     document_depth_percent_intervals,
                     document_depth_percent_interval_type):
        """Write the processed records for a single source file.

        Args:
            file: Path of a JSONL file whose records carry a ``text`` field.
            context_lengths: Iterable of total context lengths, in tokens.
            needle: Sentence to hide inside the haystack.
            retrieval_question: Question appended to every prompt.
            document_depth_percent_intervals: Number of depth steps.
            document_depth_percent_interval_type: ``'linear'`` or
                ``'sigmoid'``.

        Raises:
            ValueError: If the interval type is unsupported, or a context
                length leaves no room for haystack text once the length
                buffer and the needle are subtracted.
        """
        with open(file, 'r', encoding='utf-8') as f:
            # Blank lines (for instance a trailing newline at the end of the
            # file) are not JSON documents, so skip them instead of crashing.
            lines = [json.loads(line.strip()) for line in f if line.strip()]

        # Everything below is validated / computed before the output file is
        # opened, so bad arguments cannot leave a truncated file behind.
        depth_percents = list(
            self._generate_depth_percents(
                document_depth_percent_intervals,
                document_depth_percent_interval_type))
        needle_token_count = len(self._get_tokens_from_context(needle))
        length_targets = []
        for original_context_length in context_lengths:
            target_length_per_record = (original_context_length -
                                        self.length_buffer -
                                        needle_token_count)
            if target_length_per_record <= 0:
                # A non-positive slice bound would silently yield records of
                # the wrong size.
                raise ValueError(
                    f'Context length {original_context_length} is too small: '
                    f'the length buffer ({self.length_buffer}) plus the '
                    f'needle ({needle_token_count} tokens) leave no room for '
                    'haystack text.')
            length_targets.append(
                (original_context_length, target_length_per_record))

        # Tokens of each source line, filled lazily: the same lines are
        # revisited for every (length, depth) combination.
        line_tokens = {}

        # set output file
        output_file = (Path(self.output_path) /
                       f'{Path(file).stem}_processed.jsonl')
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, 'w', encoding='utf-8') as out_f:
            for original_context_length, target_length in length_targets:
                for depth_percent in depth_percents:
                    counter = 0
                    accumulated_tokens = []
                    for index, line in enumerate(lines):
                        if index not in line_tokens:
                            line_tokens[index] = (
                                self._get_tokens_from_context(line['text']))
                        accumulated_tokens.extend(line_tokens[index])
                        if len(accumulated_tokens) < target_length:
                            continue
                        processed_text = self._generate_context(
                            accumulated_tokens[:target_length],
                            depth_percent, needle)
                        processed_prompt = self._generate_prompt(
                            processed_text, retrieval_question)
                        json.dump(
                            {
                                'prompt': processed_prompt,
                                'answer': needle,
                                'length': original_context_length,
                                'depth': int(depth_percent),
                            },
                            out_f,
                            ensure_ascii=False)
                        out_f.write('\n')
                        counter += 1
                        if counter >= self.num_records_per_file:
                            break
                        # reset accumulated_tokens for next record
                        accumulated_tokens = []

    def _generate_context(self, tokens_context, depth_percent, needle):
        """Return the decoded haystack with the needle inserted at
        ``depth_percent`` percent of its token length."""
        tokens_needle = self._get_tokens_from_context(needle)
        # Insert the needle into the context at the specified depth percent
        insertion_point = int(len(tokens_context) * (depth_percent / 100))
        tokens_context = (tokens_context[:insertion_point] + tokens_needle +
                          tokens_context[insertion_point:])
        # Decode the tokens back to text
        return self._decode_tokens(tokens_context)

    def _get_tokens_from_context(self, context):
        """Encode ``context`` with the configured tokenizer."""
        return self.tokenizer.encode(context)

    def _decode_tokens(self, tokens):
        """Decode ``tokens`` back into text with the configured tokenizer."""
        return self.tokenizer.decode(tokens)

    def _generate_prompt(self, context, retrieval_question):
        """Wrap ``context`` and ``retrieval_question`` in the prompt template.

        In guided mode a trailing hint sentence is appended.
        """
        prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                  '请保持你的回答简洁清楚。不要说和下面文档中的无关的话,或重复你的回答\n'
                  f'用户现在给你的文档是{context}\n\n'
                  f'现在请问:{retrieval_question}')
        if self.guided:
            prompt += '提示文档中与该问题最相关的句子是_______'
        return prompt

    def _generate_depth_percents(self, intervals, interval_type):
        """Return ``intervals`` depth percentages between 0 and 100.

        ``'linear'`` spaces them evenly; ``'sigmoid'`` maps the evenly spaced
        values through :meth:`_logistic`.

        Raises:
            ValueError: If ``interval_type`` is neither of the above.
        """
        if interval_type == 'linear':
            return np.linspace(0, 100, num=intervals)
        elif interval_type == 'sigmoid':
            return [self._logistic(x) for x in np.linspace(0, 100, intervals)]
        else:
            raise ValueError('Unsupported interval type')

    @staticmethod
    def _logistic(x, L=100, x0=50, k=0.1):
        """Logistic curve ``L / (1 + exp(-k * (x - x0)))`` rounded to three
        decimals."""
        return np.round(L / (1 + np.exp(-k * (x - x0))), 3)
class CDMEDataset():
    """Entry points for generating the processed datasets and for plotting
    evaluation results as heatmaps."""

    @staticmethod
    def generate(processed_datasets_path, data_path, tokenizer_model,
                 num_records_per_file, length_buffer, guided, file_list,
                 context_lengths, needle, retrieval_question,
                 document_depth_percent_intervals):
        """Rebuild the processed datasets from scratch.

        Any existing ``processed_datasets_path`` directory is deleted first,
        then a ``CDMEDatasetProcessor`` is run over ``data_path`` with the
        remaining arguments.

        Raises:
            ValueError: If ``data_path`` is ``processed_datasets_path`` itself
                or lies inside it; wiping the output directory would then
                delete the source data as well.
        """
        processed_abs = os.path.abspath(processed_datasets_path)
        data_abs = os.path.abspath(data_path)
        # ``os.path.join(p, '')`` appends a trailing separator, so a sibling
        # such as ``/out_data`` is not mistaken for a child of ``/out``.
        if data_abs == processed_abs or data_abs.startswith(
                os.path.join(processed_abs, '')):
            raise ValueError(
                f'Refusing to clear {processed_datasets_path}: the source '
                f'data directory {data_path} is the same directory or lies '
                'inside it.')

        # Check if the processed datasets directory exists
        if os.path.exists(processed_datasets_path):
            shutil.rmtree(processed_datasets_path)
            print('The existing processed datasets directory '
                  f'{processed_datasets_path} has been '
                  'removed for a fresh start.')
        else:
            print('No existing processed datasets directory found at'
                  f' {processed_datasets_path}. '
                  'Starting with a fresh directory.')

        processor = CDMEDatasetProcessor(
            path=data_path,
            output_path=processed_datasets_path,
            tokenizer_model=tokenizer_model,
            num_records_per_file=num_records_per_file,
            length_buffer=length_buffer,
            guided=guided,
            file_list=file_list)
        processor.process_files(context_lengths, needle, retrieval_question,
                                document_depth_percent_intervals)
        print('Datasets has been created.')

    @staticmethod
    def _png_path(csv_path, model_name):
        """Return ``<csv path without extension>_<model_name>.png``.

        Only the final extension is swapped.  A plain ``str.replace('.csv',
        ...)`` would also rewrite ``.csv`` inside directory names and would
        leave the path untouched when the extension is spelled differently.
        """
        root, _ = os.path.splitext(os.fspath(csv_path))
        return f'{root}_{model_name}.png'

    @staticmethod
    def visualize(csv_file_paths):
        """Draw and save one heatmap per model column of each result CSV.

        The ``dataset`` column must contain names of the form
        ``...Length<int>Depth<float>_...``.  Each PNG is written next to its
        CSV as ``<csv stem>_<model>.png`` and is also shown on screen.
        """
        for file_path in csv_file_paths:
            df = pd.read_csv(file_path)

            # Split 'dataset' column to
            # get 'Context Length' and 'Document Depth'
            df['Context Length'] = df['dataset'].apply(
                lambda x: int(x.split('Length')[1].split('Depth')[0]))
            df['Document Depth'] = df['dataset'].apply(
                lambda x: float(x.split('Depth')[1].split('_')[0]))

            # Exclude 'Context Length' and 'Document Depth' columns
            model_columns = [
                col for col in df.columns
                if col not in ['Context Length', 'Document Depth']
            ]

            # The first four columns are skipped; presumably they hold
            # metadata rather than model scores (verify against the CSV
            # writer).
            for model_name in model_columns[4:]:
                model_df = df[['Document Depth', 'Context Length',
                               model_name]].copy()
                model_df.rename(columns={model_name: 'Score'}, inplace=True)

                # Create pivot table
                pivot_table = pd.pivot_table(model_df,
                                             values='Score',
                                             index=['Document Depth'],
                                             columns=['Context Length'],
                                             aggfunc='mean')

                # Calculate mean scores (one per context length)
                mean_scores = pivot_table.mean().values

                # Calculate overall score
                overall_score = mean_scores.mean()

                # Create heatmap and line plot
                plt.figure(figsize=(17.5, 8))
                ax = plt.gca()
                cmap = LinearSegmentedColormap.from_list(
                    'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F'])

                # Draw heatmap
                sns.heatmap(pivot_table,
                            cmap=cmap,
                            ax=ax,
                            cbar_kws={'label': 'Score'},
                            vmin=0,
                            vmax=100)

                # Set line plot data; +0.5 centres each point on its cell
                x_data = [i + 0.5 for i in range(len(mean_scores))]
                y_data = mean_scores

                # Create twin axis for line plot
                ax2 = ax.twinx()
                # Draw line plot
                ax2.plot(x_data,
                         y_data,
                         color='white',
                         marker='o',
                         linestyle='-',
                         linewidth=2,
                         markersize=8,
                         label='Average Depth Score')
                # Set y-axis range
                ax2.set_ylim(0, 100)

                # Hide original y-axis ticks and labels
                ax2.set_yticklabels([])
                ax2.set_yticks([])

                # Add legend
                ax2.legend(loc='upper left')

                # Set chart title and labels
                # NOTE: "8K" is hard-coded in the title regardless of the
                # context lengths actually present in the data.
                ax.set_title(f'{model_name} 8K Context Performance\n' +
                             'Fact Retrieval Across Context Lengths ' +
                             '("Needle In A Haystack")')
                ax.set_xlabel('Token Limit')
                ax.set_ylabel('Depth Percent')
                ax.set_xticklabels(pivot_table.columns.values, rotation=45)
                ax.set_yticklabels(pivot_table.index.values, rotation=0)
                # Add overall score as a subtitle
                plt.text(0.5,
                         -0.13, f'Overall Score for {model_name}: '
                         f'{overall_score:.2f}',
                         ha='center',
                         va='center',
                         transform=ax.transAxes,
                         fontsize=13)

                # Save heatmap as PNG
                png_file_path = CDMEDataset._png_path(file_path, model_name)
                # plt.tight_layout()
                plt.savefig(png_file_path)
                plt.show()

                plt.close()  # Close figure to prevent memory leaks

                # Print saved PNG file path
                print(f'Heatmap for {model_name} saved as: {png_file_path}')
def _parse_bool_flag(value):
    """Convert a command-line string such as ``'False'`` into a real bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value, including ``'False'``, becomes ``True``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value (true/false), got {value!r}')


def main(argv=None):
    """Command-line entry point: generate the datasets, or plot results.

    Args:
        argv: Argument list to parse.  ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Raises:
        SystemExit: With status 1 when ``--plot`` is given without any CSV
            path (argparse itself exits with status 2 on malformed input).
    """
    parser = argparse.ArgumentParser(description='Generate CDMEDataset.')
    parser.add_argument('--processed_datasets_path',
                        type=str,
                        default='./data/CDME/processed')
    parser.add_argument('--data_path', type=str, default='./data/CDME')
    parser.add_argument('--tokenizer_model', type=str, default='gpt-4')
    parser.add_argument('--num_records_per_file', type=int, default=10)
    parser.add_argument('--length_buffer', type=int, default=200)
    # Accepts ``--guided``, ``--guided True`` and ``--guided False``.
    parser.add_argument('--guided',
                        type=_parse_bool_flag,
                        nargs='?',
                        const=True,
                        default=False)
    parser.add_argument('--file_list', nargs='*', default=['zh_finance.jsonl'])
    parser.add_argument('--context_lengths',
                        nargs='*',
                        type=int,
                        default=list(range(1000, 9000, 1000)))
    parser.add_argument('--needle',
                        type=str,
                        default='\n小明最喜欢的实习的地点就是上海人工智能实验室。\n')
    parser.add_argument('--retrieval_question',
                        type=str,
                        default='小明最喜欢的实习地点是哪里?'
                        '你的回答格式应该为“小明最喜欢的实习地点就是________。”')
    parser.add_argument('--document_depth_percent_intervals',
                        type=int,
                        default=35)
    parser.add_argument('--plot',
                        action='store_true',
                        help='Visualize the dataset results')
    # No placeholder default: a fake path would make the "required" check
    # below unreachable and fail later with a file-not-found error.
    parser.add_argument('--csv_file_paths',
                        nargs='*',
                        default=None,
                        help='Paths to CSV files for visualization')

    args = parser.parse_args(argv)

    if args.plot:
        if not args.csv_file_paths:
            print("Error: '--csv_file_paths' is required for visualization.")
            # ``exit`` is injected by the ``site`` module and is not always
            # available; raising SystemExit works everywhere.
            raise SystemExit(1)
        CDMEDataset.visualize(args.csv_file_paths)
    else:
        doc_depth_intervals = args.document_depth_percent_intervals
        CDMEDataset.generate(
            processed_datasets_path=args.processed_datasets_path,
            data_path=args.data_path,
            tokenizer_model=args.tokenizer_model,
            num_records_per_file=args.num_records_per_file,
            length_buffer=args.length_buffer,
            guided=args.guided,
            file_list=args.file_list,
            context_lengths=args.context_lengths,
            needle=args.needle,
            retrieval_question=args.retrieval_question,
            document_depth_percent_intervals=doc_depth_intervals)
# Run the command-line interface only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()