Mirror of https://github.com/open-compass/opencompass.git, synced 2025-05-30 16:03:24 +08:00.
[Feature] Update plot function in tools_needleinahaystack.py (#747)
* Add NeedleInAHaystack Test
* Apply pre-commit formatting
* Update configs/eval_hf_internlm_chat_20b_cdme.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* add needle in haystack test
* update needle in haystack test
* update plot function in tools_needleinahaystack.py
* optimizing needleinahaystack dataset generation strategy
* modify minor formatting issues

---------

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>
This commit is contained in:
parent 327951087f
commit 17b8e929dd
@@ -2,37 +2,64 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
 from opencompass.datasets.cdme.cdme import CDMEDataset,CDMEEvaluator,cdme_postprocess,cdme_dataset_postprocess
 import os
 import math


 def logistic(x, L=100, x0=50, k=0.1):
     return round(L / (1 + math.exp(-k * (x - x0))), 3)


 def generate_linear_space(start, end, num):
     step = (end - start) / (num - 1)
     return [start + step * i for i in range(num)]


 def generate_depth_percents(intervals, interval_type):
     if interval_type == 'linear':
         return generate_linear_space(0, 100, intervals)
     elif interval_type == 'sigmoid':
         linear_space = generate_linear_space(0, 100, intervals)
         return [logistic(x) for x in linear_space]
     else:
         raise ValueError('Unsupported interval type')


 cdme_reader_cfg = dict(input_columns=['prompt'], output_column='answer')

 cdme_infer_cfg = dict(
     prompt_template=dict(
-        type=PromptTemplate,
-        template=
-        '''{prompt}'''),
+        type=PromptTemplate,
+        template='''{prompt}'''),
     retriever=dict(type=ZeroRetriever),
     inferencer=dict(type=GenInferencer, max_out_len=512))

 cdme_eval_cfg = dict(
-    evaluator=dict(type=CDMEEvaluator),
-    pred_postprocessor=dict(type=cdme_postprocess),
-    dataset_postprocessor=dict(type=cdme_dataset_postprocess))
+    evaluator=dict(type=CDMEEvaluator),
+    pred_postprocessor=dict(type=cdme_postprocess),
+    dataset_postprocessor=dict(type=cdme_dataset_postprocess),
+    pred_role='BOT')

 context_lengths = list(range(1000, 9000, 1000))
 document_depth_percent_intervals = 35
 document_depth_percent_interval_type = "linear"

 base_path = './data/CDME/processed'
 cdme_datasets = []

-for folder in os.listdir(base_path):
-    if os.path.isdir(os.path.join(base_path, folder)):
+for original_context_length in context_lengths:
+    for depth_percent in generate_depth_percents(
+            document_depth_percent_intervals,
+            document_depth_percent_interval_type):
         dataset_dict = dict(
-            abbr=f'CDME_{folder}',
+            abbr=f'CDME_Length{original_context_length}'
+            f'Depth{int(depth_percent)}',
             type=CDMEDataset,
-            path=os.path.join(base_path, folder),
+            path=base_path,
+            length=original_context_length,
+            depth=int(depth_percent),
             reader_cfg=cdme_reader_cfg,
             infer_cfg=cdme_infer_cfg,
             eval_cfg=cdme_eval_cfg
         )
         cdme_datasets.append(dataset_dict)
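For reference, a small standalone sketch (not part of the commit) of what the depth-percent helpers above produce. The 5-point interval count is only for brevity; the config itself uses document_depth_percent_intervals = 35.

import math

def logistic(x, L=100, x0=50, k=0.1):
    return round(L / (1 + math.exp(-k * (x - x0))), 3)

def generate_linear_space(start, end, num):
    step = (end - start) / (num - 1)
    return [start + step * i for i in range(num)]

# 'linear' spreads needle depths evenly through the document ...
print(generate_linear_space(0, 100, 5))
# -> [0.0, 25.0, 50.0, 75.0, 100.0]

# ... while 'sigmoid' clusters them toward the middle of the context
print([logistic(x) for x in generate_linear_space(0, 100, 5)])
# -> [0.669, 7.586, 50.0, 92.414, 99.331]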
@@ -13,15 +13,15 @@ from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
 class CDMEDataset(BaseDataset):

     @staticmethod
-    def load(path: str):
+    def load(path: str, length: int, depth: int):
         data = {'prompt': [], 'answer': []}
         for file in Path(path).glob('*.jsonl'):
             with open(file, 'r', encoding='utf-8') as f:
                 for line in f:
                     line = json.loads(line.strip())
-                    data['prompt'].append(line['prompt'])
-                    data['answer'].append(line['answer'])
+                    if line['length'] == length and line['depth'] == depth:
+                        data['prompt'].append(line['prompt'])
+                        data['answer'].append(line['answer'])

         dataset = Dataset.from_dict({
             'prompt': data['prompt'],
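A hypothetical illustration (record values are made up, not taken from the commit) of the record shape the new load() signature relies on: each processed JSONL line carries its own 'length' and 'depth', so a single file can back every (length, depth) dataset defined in the config above.

import json

# One made-up line of a processed *.jsonl file
raw = json.dumps({
    'prompt': 'long context with the needle sentence inserted somewhere...',
    'answer': 'needle sentence',
    'length': 2000,   # original context length used when the record was built
    'depth': 50,      # int(depth_percent) at which the needle was placed
}, ensure_ascii=False)

line = json.loads(raw)
# Mirrors the new filter in load(): keep the record only for the matching dataset
if line['length'] == 2000 and line['depth'] == 50:
    print('record belongs to the Length2000Depth50 dataset')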
@@ -50,21 +50,20 @@ class CDMEDatasetProcessor:
         with open(file, 'r', encoding='utf-8') as f:
             lines = [json.loads(line.strip()) for line in f]

-        for original_context_length in context_lengths:
-            context_length = original_context_length - self.length_buffer
-            target_length_per_record = context_length - len(
-                self._get_tokens_from_context(needle))
-            for depth_percent in self._generate_depth_percents(
-                    document_depth_percent_intervals,
-                    document_depth_percent_interval_type):
-                output_file = (Path(self.output_path) /
-                               f'Length{original_context_length}'
-                               f'Depth{int(depth_percent)}' /
-                               f'{file.stem}_Length{original_context_length}'
-                               f'_Depth{int(depth_percent)}{file.suffix}')
+        # set output file
+        output_file = Path(self.output_path) / f'{file.stem}_processed.jsonl'
+        output_file.parent.mkdir(parents=True, exist_ok=True)
+
+        with open(output_file, 'w', encoding='utf-8') as out_f:
+            for original_context_length in context_lengths:
+                context_length = original_context_length - self.length_buffer
+                target_length_per_record = context_length - len(
+                    self._get_tokens_from_context(needle))
+
+                for depth_percent in self._generate_depth_percents(
+                        document_depth_percent_intervals,
+                        document_depth_percent_interval_type):
+
-                output_file.parent.mkdir(parents=True, exist_ok=True)
-                with open(output_file, 'w', encoding='utf-8') as out_f:
                     counter = 0
                     accumulated_tokens = []
                     for line in lines:
@@ -73,7 +72,6 @@ class CDMEDatasetProcessor:
                         accumulated_tokens.extend(tokens_current_line)

                         if len(accumulated_tokens) >= target_length_per_record:
-
                             processed_text = self._generate_context(
                                 accumulated_tokens[:target_length_per_record],
                                 depth_percent, needle)
@@ -83,7 +81,9 @@ class CDMEDatasetProcessor:
                             json.dump(
                                 {
                                     'prompt': processed_prompt,
-                                    'answer': needle
+                                    'answer': needle,
+                                    'length': original_context_length,
+                                    'depth': int(depth_percent),
                                 },
                                 out_f,
                                 ensure_ascii=False)
@@ -91,7 +91,7 @@ class CDMEDatasetProcessor:
                             counter += 1
                             if counter >= self.num_records_per_file:
                                 break
-                            # Reset the accumulated tokens for the next record
+                            # reset accumulated_tokens for next record
                             accumulated_tokens = []

     def _generate_context(self, tokens_context, depth_percent, needle):
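A rough, self-contained sketch (the file name, lengths, and depths below are placeholders, not values from the commit) of the output layout the rewritten processor produces: one '<stem>_processed.jsonl' per source file, with every (length, depth) combination stored as tagged records instead of one directory per combination.

import json
from pathlib import Path

output_file = Path('./data/CDME/processed') / 'zh_finance_processed.jsonl'
output_file.parent.mkdir(parents=True, exist_ok=True)

with open(output_file, 'w', encoding='utf-8') as out_f:
    for length in (1000, 2000):          # stand-in for context_lengths
        for depth in (0, 50, 100):       # stand-in for the generated depth percents
            json.dump({'prompt': f'placeholder prompt (length={length}, depth={depth})',
                       'answer': 'placeholder needle',
                       'length': length,
                       'depth': depth},
                      out_f,
                      ensure_ascii=False)
            out_f.write('\n')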
@@ -175,43 +175,104 @@ class CDMEDataset():
     def visualize(csv_file_paths):
         for file_path in csv_file_paths:
             df = pd.read_csv(file_path)
-            model_name = df.columns[4]
-            # Process the data
-            df['Context Length'] = df['dataset'].apply(lambda x: int(
-                x.replace('CDME_', '').split('Depth')[0].replace('Length', ''))
-                                                        )

+            # Split 'dataset' column to
+            # get 'Context Length' and 'Document Depth'
+            df['Context Length'] = df['dataset'].apply(
+                lambda x: int(x.split('Length')[1].split('Depth')[0]))
             df['Document Depth'] = df['dataset'].apply(
-                lambda x: float(x.replace('CDME_', '').split('Depth')[1]))
-            df = df[['Document Depth', 'Context Length', model_name]]\
-                .rename(columns={model_name: 'Score'})
+                lambda x: float(x.split('Depth')[1].split('_')[0]))

-            # Create pivot table
-            pivot_table = pd.pivot_table(df,
-                                         values='Score',
-                                         index=['Document Depth'],
-                                         columns=['Context Length'],
-                                         aggfunc='mean')
+            # Exclude 'Context Length' and 'Document Depth' columns
+            model_columns = [
+                col for col in df.columns
+                if col not in ['Context Length', 'Document Depth']
+            ]

-            # Create a heatmap for visualization
-            cmap = LinearSegmentedColormap.from_list(
-                'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F'])
-            plt.figure(figsize=(17.5, 8))
-            sns.heatmap(pivot_table, cmap=cmap, cbar_kws={'label': 'Score'})
-            plt.title(f'{model_name} 8K Context Performance\n'
-                      'Fact Retrieval Across'
-                      'Context Lengths ("Needle In A Haystack")')
-            plt.xlabel('Token Limit')
-            plt.ylabel('Depth Percent')
-            plt.xticks(rotation=45)
-            plt.yticks(rotation=0)
-            plt.tight_layout()
+            for model_name in model_columns[4:]:
+                model_df = df[['Document Depth', 'Context Length',
+                               model_name]].copy()
+                model_df.rename(columns={model_name: 'Score'}, inplace=True)

-            # Save the heatmap as a PNG file
-            png_file_path = file_path.replace('.csv', '.png')
-            plt.savefig(png_file_path)
-            plt.close()  # Close the plot to prevent memory leaks
-            # Print the path to the saved PNG file
-            print(f'Heatmap saved as: {png_file_path}')
+                # Create pivot table
+                pivot_table = pd.pivot_table(model_df,
+                                             values='Score',
+                                             index=['Document Depth'],
+                                             columns=['Context Length'],
+                                             aggfunc='mean')
+
+                # Calculate mean scores
+                mean_scores = pivot_table.mean().values
+
+                # Calculate overall score
+                overall_score = mean_scores.mean()
+
+                # Create heatmap and line plot
+                plt.figure(figsize=(17.5, 8))
+                ax = plt.gca()
+                cmap = LinearSegmentedColormap.from_list(
+                    'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F'])
+
+                # Draw heatmap
+                sns.heatmap(pivot_table,
+                            cmap=cmap,
+                            ax=ax,
+                            cbar_kws={'label': 'Score'},
+                            vmin=0,
+                            vmax=100)
+
+                # Set line plot data
+                x_data = [i + 0.5 for i in range(len(mean_scores))]
+                y_data = mean_scores
+
+                # Create twin axis for line plot
+                ax2 = ax.twinx()
+                # Draw line plot
+                ax2.plot(x_data,
+                         y_data,
+                         color='white',
+                         marker='o',
+                         linestyle='-',
+                         linewidth=2,
+                         markersize=8,
+                         label='Average Depth Score')
+                # Set y-axis range
+                ax2.set_ylim(0, 100)
+
+                # Hide original y-axis ticks and labels
+                ax2.set_yticklabels([])
+                ax2.set_yticks([])
+
+                # Add legend
+                ax2.legend(loc='upper left')
+
+                # Set chart title and labels
+                ax.set_title(f'{model_name} 8K Context Performance\n' +
+                             'Fact Retrieval Across Context Lengths ' +
+                             '("Needle In A Haystack")')
+                ax.set_xlabel('Token Limit')
+                ax.set_ylabel('Depth Percent')
+                ax.set_xticklabels(pivot_table.columns.values, rotation=45)
+                ax.set_yticklabels(pivot_table.index.values, rotation=0)
+                # Add overall score as a subtitle
+                plt.text(0.5,
+                         -0.13, f'Overall Score for {model_name}: '
+                         f'{overall_score:.2f}',
+                         ha='center',
+                         va='center',
+                         transform=ax.transAxes,
+                         fontsize=13)
+
+                # Save heatmap as PNG
+                png_file_path = file_path.replace('.csv', f'_{model_name}.png')
+                # plt.tight_layout()
+                plt.savefig(png_file_path)
+                plt.show()
+
+                plt.close()  # Close figure to prevent memory leaks
+
+                # Print saved PNG file path
+                print(f'Heatmap for {model_name} saved as: {png_file_path}')


 def main():
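A minimal, standalone sketch of the plotting idea in the new visualize(): pivot one model's scores by depth and context length, draw the heatmap, and overlay the per-column mean on a twin axis. The scores below are dummy data and the colormap is a generic stand-in, not the commit's custom palette.

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.DataFrame({'Document Depth': [0, 50, 100, 0, 50, 100],
                   'Context Length': [1000, 1000, 1000, 2000, 2000, 2000],
                   'Score': [100, 90, 80, 70, 60, 50]})

pivot_table = pd.pivot_table(df, values='Score', index=['Document Depth'],
                             columns=['Context Length'], aggfunc='mean')
mean_scores = pivot_table.mean().values       # one average per context length
overall_score = mean_scores.mean()

ax = sns.heatmap(pivot_table, cmap='viridis', vmin=0, vmax=100,
                 cbar_kws={'label': 'Score'})
ax2 = ax.twinx()                              # twin axis for the average line
ax2.plot([i + 0.5 for i in range(len(mean_scores))], mean_scores,
         color='white', marker='o', label='Average Depth Score')
ax2.set_ylim(0, 100)
ax2.legend(loc='upper left')

print(f'Overall Score: {overall_score:.2f}')  # -> 75.00 for the dummy data
plt.savefig('demo_heatmap.png')
plt.close()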
@@ -224,7 +285,7 @@ def main():
     parser.add_argument('--tokenizer_model', type=str, default='gpt-4')
     parser.add_argument('--num_records_per_file', type=int, default=10)
     parser.add_argument('--length_buffer', type=int, default=200)
-    parser.add_argument('--guided', type=bool, default=True)
+    parser.add_argument('--guided', type=bool, default=False)
     parser.add_argument('--file_list', nargs='*', default=['zh_finance.jsonl'])
    parser.add_argument('--context_lengths',
                        nargs='*',