2023-12-23 12:00:51 +08:00
|
|
|
import argparse
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import pandas as pd
|
|
|
|
import seaborn as sns
|
|
|
|
from matplotlib.colors import LinearSegmentedColormap
|
|
|
|
|
|
|
|
|
|
|
|
class CDMEDataset:
    """Visualization helpers for CDME ("Needle In A Haystack") results.

    The class only groups static methods; it is never instantiated here.
    """

    # Number of leading CSV columns that are not per-model score columns.
    # NOTE(review): presumably the summary metadata columns (the first of
    # which must include 'dataset') — confirm against the CSV producer.
    _NUM_META_COLUMNS = 4

    @staticmethod
    def visualize(path: 'str | list[str]', dataset_length: str):
        """Render one heatmap PNG per model column of each results CSV.

        Args:
            path: A single CSV path (``str`` or path-like) or an iterable of
                CSV paths. Each CSV needs a ``dataset`` column whose values
                contain ``Length<int>Depth<number>``, where the depth may be
                followed by ``_``. Every column after the first
                ``_NUM_META_COLUMNS`` columns is plotted as one model's scores.
            dataset_length: Label (e.g. ``'8K'``) used in the chart title.

        For each model column a PNG named ``<csv path without extension>_
        <model>.png`` is written next to the CSV and its path is printed.
        """
        import os  # local import: the file header does not import os

        # A bare string is itself iterable; without this it would be read
        # one character at a time.
        if isinstance(path, (str, os.PathLike)):
            paths = [path]
        else:
            paths = path

        for file_path in paths:
            df = CDMEDataset._load_results(file_path)

            # Exclude the parsed 'Context Length' and 'Document Depth' columns
            model_columns = [
                col for col in df.columns
                if col not in ('Context Length', 'Document Depth')
            ]

            # Derive the output name from the path minus its extension, so a
            # non-'.csv' input can never be reused (and overwritten) as the
            # PNG path, and '.csv' inside directory names is left alone.
            root = os.path.splitext(file_path)[0]
            for model_name in model_columns[CDMEDataset._NUM_META_COLUMNS:]:
                CDMEDataset._plot_model(df, model_name, dataset_length,
                                        f'{root}_{model_name}.png')

    @staticmethod
    def _load_results(file_path) -> pd.DataFrame:
        """Read a results CSV and add parsed length/depth columns."""
        df = pd.read_csv(file_path)
        # 'Length<int>Depth...' -> integer context length
        df['Context Length'] = df['dataset'].apply(
            lambda x: int(x.split('Length')[1].split('Depth')[0]))
        # '...Depth<number>_...' -> float document depth
        df['Document Depth'] = df['dataset'].apply(
            lambda x: float(x.split('Depth')[1].split('_')[0]))
        return df

    @staticmethod
    def _plot_model(df, model_name, dataset_length, png_file_path):
        """Draw, save and show the depth/length heatmap of one model column."""
        model_df = df[['Document Depth', 'Context Length', model_name]].copy()
        model_df.rename(columns={model_name: 'Score'}, inplace=True)

        # Rows: document depth; columns: context length; cells: mean score.
        pivot = pd.pivot_table(model_df,
                               values='Score',
                               index=['Document Depth'],
                               columns=['Context Length'],
                               aggfunc='mean')

        # Mean over depths for each context length, then over all lengths.
        mean_scores = pivot.mean().values
        overall_score = mean_scores.mean()

        fig, ax = plt.subplots(figsize=(15.5, 8))
        try:
            cmap = LinearSegmentedColormap.from_list(
                'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F'])

            # xticklabels/yticklabels=True put a tick on every cell, so the
            # full label lists applied below always line up with the ticks
            # (seaborn's default 'auto' may drop ticks on dense grids).
            sns.heatmap(pivot,
                        cmap=cmap,
                        ax=ax,
                        cbar_kws={'label': 'Score'},
                        vmin=0,
                        vmax=100,
                        xticklabels=True,
                        yticklabels=True)

            # Overlay the per-length mean on a twin y-axis, one marker
            # centred on each heatmap column.
            x_data = [i + 0.5 for i in range(len(mean_scores))]
            ax2 = ax.twinx()
            ax2.plot(x_data,
                     mean_scores,
                     color='white',
                     marker='o',
                     linestyle='-',
                     linewidth=2,
                     markersize=8,
                     label='Average Depth Score')
            ax2.set_ylim(0, 100)
            # Hide the twin axis' own ticks and labels
            ax2.set_yticklabels([])
            ax2.set_yticks([])
            ax2.legend(loc='upper left')

            ax.set_title(f'{model_name} {dataset_length} Context '
                         'Performance\nFact Retrieval Across '
                         'Context Lengths ("Needle In A Haystack")')
            ax.set_xlabel('Token Limit')
            ax.set_ylabel('Depth Percent')
            ax.set_xticklabels(pivot.columns.values, rotation=45)
            ax.set_yticklabels(pivot.index.values, rotation=0)

            # Overall score as a subtitle below the heatmap (axes coords).
            ax.text(0.5,
                    -0.13,
                    f'Overall Score for {model_name}: '
                    f'{overall_score:.2f}',
                    ha='center',
                    va='center',
                    transform=ax.transAxes,
                    fontsize=13)

            fig.tight_layout()
            fig.subplots_adjust(right=1)
            plt.draw()
            fig.savefig(png_file_path)
            plt.show()
        finally:
            plt.close(fig)  # release the figure even if plotting fails

        print(f'Heatmap for {model_name} saved as: {png_file_path}')
|
2023-12-23 12:00:51 +08:00
|
|
|
|
|
|
|
|
|
|
|
def main(argv=None):
    """Command-line entry point: parse arguments and render the plots.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behavior.

    Exits with status 1 and an error message on stderr when ``--path`` is
    given without any values.
    """
    import sys  # local import: the file header only imports argparse

    parser = argparse.ArgumentParser(
        description='Generate NeedleInAHaystack Test Plots')
    parser.add_argument('--path',
                        nargs='*',
                        default=['path/to/your/result.csv'],
                        help='Paths to CSV files for visualization')
    parser.add_argument('--dataset_length',
                        default='8K',
                        type=str,
                        help='Dataset_length for visualization')
    args = parser.parse_args(argv)

    # `--path` with no values yields an empty list (nargs='*').
    if not args.path:
        # sys.exit(msg) writes msg to stderr and exits with status 1; the
        # bare `exit` builtin comes from `site` and is not always defined.
        sys.exit("Error: '--path' is required for visualization.")

    CDMEDataset.visualize(args.path, args.dataset_length)
|
2023-12-23 12:00:51 +08:00
|
|
|
|
|
|
|
|
|
|
|
# Launch the command-line interface only when run as a script, so importing
# this module has no side effects.
if __name__ == '__main__':
    main()
|