mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

* Add NeedleInAHaystack Test * Apply pre-commit formatting * Update configs/eval_hf_internlm_chat_20b_cdme.py Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com> * add needle in haystack test * update needle in haystack test * update plot function in tools_needleinahaystack.py * optimizing needleinahaystack dataset generation strategy * modify minor formatting issues * add English version support * change NeedleInAHaystackDataset to dynamic loading * change NeedleInAHaystackDataset to dynamic loading * fix needleinahaystack test eval bug * fix needleinahaystack config bug --------- Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>
138 lines
5.0 KiB
Python
138 lines
5.0 KiB
Python
import argparse
|
|
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
import seaborn as sns
|
|
from matplotlib.colors import LinearSegmentedColormap
|
|
|
|
|
|
class CDMEDataset():
    """Plotting helpers for NeedleInAHaystack result CSV files."""

    @staticmethod
    def visualize(csv_file_paths):
        """Draw and save one score heatmap per model for each result CSV.

        Every CSV must contain a ``dataset`` column whose values embed the
        context length and the needle depth in the form
        ``...Length<int>Depth<number>_...``. Columns from the fifth one
        onward are treated as per-model score columns.

        For each model a heatmap (depth x context length, colour scale fixed
        to 0-100) is drawn with the per-context-length mean score overlaid as
        a line, saved as ``<csv path without '.csv'>_<model>.png``, shown via
        ``plt.show()`` and then closed.

        Args:
            csv_file_paths: Iterable of CSV file paths (strings).
        """
        for file_path in csv_file_paths:
            df = pd.read_csv(file_path)

            # Split 'dataset' column to
            # get 'Context Length' and 'Document Depth'
            df['Context Length'] = df['dataset'].apply(
                lambda x: int(x.split('Length')[1].split('Depth')[0]))
            df['Document Depth'] = df['dataset'].apply(
                lambda x: float(x.split('Depth')[1].split('_')[0]))

            # Exclude 'Context Length' and 'Document Depth' columns
            model_columns = [
                col for col in df.columns
                if col not in ['Context Length', 'Document Depth']
            ]

            # Strip only a trailing '.csv'. str.replace() would also rewrite
            # a '.csv' occurring earlier in the path (e.g. in a directory
            # name) and would leave the path unchanged -- i.e. pointing at
            # the input file -- when there is no '.csv' at all.
            if file_path.endswith('.csv'):
                output_stem = file_path[:-len('.csv')]
            else:
                output_stem = file_path

            # The first four columns are skipped; presumably they hold
            # non-score metadata -- confirm against the CSV producer.
            for model_name in model_columns[4:]:
                model_df = df[['Document Depth', 'Context Length',
                               model_name]].copy()
                model_df.rename(columns={model_name: 'Score'}, inplace=True)

                # Create pivot table
                pivot_table = pd.pivot_table(model_df,
                                             values='Score',
                                             index=['Document Depth'],
                                             columns=['Context Length'],
                                             aggfunc='mean')

                # Calculate mean scores (one per context length column)
                mean_scores = pivot_table.mean().values

                # Calculate overall score
                overall_score = mean_scores.mean()

                # Create heatmap and line plot
                plt.figure(figsize=(17.5, 8))
                ax = plt.gca()
                cmap = LinearSegmentedColormap.from_list(
                    'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F'])

                # Draw heatmap
                sns.heatmap(pivot_table,
                            cmap=cmap,
                            ax=ax,
                            cbar_kws={'label': 'Score'},
                            vmin=0,
                            vmax=100)

                # Set line plot data; the +0.5 offset puts each point at the
                # horizontal centre of its heatmap cell.
                x_data = [i + 0.5 for i in range(len(mean_scores))]
                y_data = mean_scores

                # Create twin axis for line plot
                ax2 = ax.twinx()
                # Draw line plot
                ax2.plot(x_data,
                         y_data,
                         color='white',
                         marker='o',
                         linestyle='-',
                         linewidth=2,
                         markersize=8,
                         label='Average Depth Score')
                # Set y-axis range to match the heatmap colour scale
                ax2.set_ylim(0, 100)

                # Hide the twin axis ticks and labels
                ax2.set_yticklabels([])
                ax2.set_yticks([])

                # Add legend
                ax2.legend(loc='upper left')

                # Set chart title and labels
                # NOTE(review): the title hard-codes '8K' regardless of the
                # context lengths actually present in the data.
                ax.set_title(f'{model_name} 8K Context Performance\n' +
                             'Fact Retrieval Across Context Lengths ' +
                             '("Needle In A Haystack")')
                ax.set_xlabel('Token Limit')
                ax.set_ylabel('Depth Percent')
                ax.set_xticklabels(pivot_table.columns.values, rotation=45)
                ax.set_yticklabels(pivot_table.index.values, rotation=0)
                # Add overall score as a subtitle
                plt.text(0.5,
                         -0.13, f'Overall Score for {model_name}: '
                         f'{overall_score:.2f}',
                         ha='center',
                         va='center',
                         transform=ax.transAxes,
                         fontsize=13)

                # Save heatmap as PNG
                png_file_path = f'{output_stem}_{model_name}.png'
                plt.savefig(png_file_path)
                plt.show()

                plt.close()  # Close figure to prevent memory leaks

                # Print saved PNG file path
                print(f'Heatmap for {model_name} saved as: {png_file_path}')
|
|
|
|
|
|
def main():
    """Parse command-line arguments and, if requested, plot result CSVs.

    With ``--plot`` every path given through ``--csv_file_paths`` is handed
    to ``CDMEDataset.visualize``. Without ``--plot`` the function returns
    without doing anything.

    Raises:
        SystemExit: With status 1 when ``--plot`` is set but
            ``--csv_file_paths`` was supplied with no values.
    """
    # The two adjacent literals in the original were joined without a space,
    # which rendered as "NeedleInAHaystackTest Plots" in the help output.
    parser = argparse.ArgumentParser(
        description='Generate NeedleInAHaystack Test Plots')

    parser.add_argument('--plot',
                        action='store_true',
                        help='Visualize the dataset results')
    parser.add_argument('--csv_file_paths',
                        nargs='*',
                        default=['path/to/your/result.csv'],
                        help='Paths to CSV files for visualization')

    args = parser.parse_args()

    if not args.plot:
        return

    # nargs='*' yields an empty list when the flag is given without values.
    if not args.csv_file_paths:
        print("Error: '--csv_file_paths' is required for visualization.")
        # The bare exit() helper is injected by the site module for
        # interactive use; raising SystemExit gives the same exit status
        # without depending on it.
        raise SystemExit(1)

    CDMEDataset.visualize(args.csv_file_paths)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the command-line interface only when this file is executed as a
    # script, not when it is imported as a module.
    main()
|