From 17b8e929dd80670c890c63de0efaa52ed35e8308 Mon Sep 17 00:00:00 2001 From: Mo Li <82895469+DseidLi@users.noreply.github.com> Date: Fri, 29 Dec 2023 18:51:09 +0800 Subject: [PATCH] [Feature] Update plot function in tools_needleinahaystack.py (#747) * Add NeedleInAHaystack Test * Apply pre-commit formatting * Update configs/eval_hf_internlm_chat_20b_cdme.py Co-authored-by: Songyang Zhang * add needle in haystack test * update needle in haystack test * update plot function in tools_needleinahaystack.py * optimizing needleinahaystack dataset generation strategy * modify minor formatting issues --------- Co-authored-by: Songyang Zhang --- configs/datasets/cdme/cdme.py | 53 +++++++--- opencompass/datasets/cdme/cdme.py | 8 +- tools/tools_needleinahaystack.py | 163 ++++++++++++++++++++---------- 3 files changed, 156 insertions(+), 68 deletions(-) diff --git a/configs/datasets/cdme/cdme.py b/configs/datasets/cdme/cdme.py index f2eee21d..9b7aea91 100644 --- a/configs/datasets/cdme/cdme.py +++ b/configs/datasets/cdme/cdme.py @@ -2,37 +2,64 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate from opencompass.openicl.icl_retriever import ZeroRetriever from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.datasets.cdme.cdme import CDMEDataset,CDMEEvaluator,cdme_postprocess,cdme_dataset_postprocess -import os +import math + + +def logistic(x, L=100, x0=50, k=0.1): + return round(L / (1 + math.exp(-k * (x - x0))), 3) + + +def generate_linear_space(start, end, num): + step = (end - start) / (num - 1) + return [start + step * i for i in range(num)] + + +def generate_depth_percents(intervals, interval_type): + if interval_type == 'linear': + return generate_linear_space(0, 100, intervals) + elif interval_type == 'sigmoid': + linear_space = generate_linear_space(0, 100, intervals) + return [logistic(x) for x in linear_space] + else: + raise ValueError('Unsupported interval type') + cdme_reader_cfg = dict(input_columns=['prompt'], 
output_column='answer') cdme_infer_cfg = dict( prompt_template=dict( -type=PromptTemplate, - template= - '''{prompt}'''), + type=PromptTemplate, + template='''{prompt}'''), retriever=dict(type=ZeroRetriever), inferencer=dict(type=GenInferencer, max_out_len=512)) cdme_eval_cfg = dict( - evaluator=dict(type=CDMEEvaluator), - pred_postprocessor=dict(type=cdme_postprocess), - dataset_postprocessor=dict(type=cdme_dataset_postprocess)) - + evaluator=dict(type=CDMEEvaluator), + pred_postprocessor=dict(type=cdme_postprocess), + dataset_postprocessor=dict(type=cdme_dataset_postprocess), + pred_role='BOT') +context_lengths = list(range(1000, 9000, 1000)) +document_depth_percent_intervals = 35 +document_depth_percent_interval_type = "linear" base_path = './data/CDME/processed' cdme_datasets = [] -for folder in os.listdir(base_path): - if os.path.isdir(os.path.join(base_path, folder)): + +for original_context_length in context_lengths: + for depth_percent in generate_depth_percents( + document_depth_percent_intervals, + document_depth_percent_interval_type): dataset_dict = dict( - abbr=f'CDME_{folder}', + abbr=f'CDME_Length{original_context_length}' + f'Depth{int(depth_percent)}', type=CDMEDataset, - path=os.path.join(base_path, folder), + path=base_path, + length=original_context_length, + depth=int(depth_percent), reader_cfg=cdme_reader_cfg, infer_cfg=cdme_infer_cfg, eval_cfg=cdme_eval_cfg ) cdme_datasets.append(dataset_dict) - diff --git a/opencompass/datasets/cdme/cdme.py b/opencompass/datasets/cdme/cdme.py index f0d16f24..b4a673c4 100644 --- a/opencompass/datasets/cdme/cdme.py +++ b/opencompass/datasets/cdme/cdme.py @@ -13,15 +13,15 @@ from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS class CDMEDataset(BaseDataset): @staticmethod - def load(path: str): - + def load(path: str, length: int, depth: int): data = {'prompt': [], 'answer': []} for file in Path(path).glob('*.jsonl'): with open(file, 'r', encoding='utf-8') as f: for line in f: line = 
json.loads(line.strip()) - data['prompt'].append(line['prompt']) - data['answer'].append(line['answer']) + if line['length'] == length and line['depth'] == depth: + data['prompt'].append(line['prompt']) + data['answer'].append(line['answer']) dataset = Dataset.from_dict({ 'prompt': data['prompt'], diff --git a/tools/tools_needleinahaystack.py b/tools/tools_needleinahaystack.py index e09d3c4b..064a80d1 100644 --- a/tools/tools_needleinahaystack.py +++ b/tools/tools_needleinahaystack.py @@ -50,21 +50,20 @@ class CDMEDatasetProcessor: with open(file, 'r', encoding='utf-8') as f: lines = [json.loads(line.strip()) for line in f] - for original_context_length in context_lengths: - context_length = original_context_length - self.length_buffer - target_length_per_record = context_length - len( - self._get_tokens_from_context(needle)) - for depth_percent in self._generate_depth_percents( - document_depth_percent_intervals, - document_depth_percent_interval_type): - output_file = (Path(self.output_path) / - f'Length{original_context_length}' - f'Depth{int(depth_percent)}' / - f'{file.stem}_Length{original_context_length}' - f'_Depth{int(depth_percent)}{file.suffix}') + # set output file + output_file = Path(self.output_path) / f'{file.stem}_processed.jsonl' + output_file.parent.mkdir(parents=True, exist_ok=True) + + with open(output_file, 'w', encoding='utf-8') as out_f: + for original_context_length in context_lengths: + context_length = original_context_length - self.length_buffer + target_length_per_record = context_length - len( + self._get_tokens_from_context(needle)) + + for depth_percent in self._generate_depth_percents( + document_depth_percent_intervals, + document_depth_percent_interval_type): - output_file.parent.mkdir(parents=True, exist_ok=True) - with open(output_file, 'w', encoding='utf-8') as out_f: counter = 0 accumulated_tokens = [] for line in lines: @@ -73,7 +72,6 @@ class CDMEDatasetProcessor: accumulated_tokens.extend(tokens_current_line) if 
len(accumulated_tokens) >= target_length_per_record: - processed_text = self._generate_context( accumulated_tokens[:target_length_per_record], depth_percent, needle) @@ -83,7 +81,9 @@ class CDMEDatasetProcessor: json.dump( { 'prompt': processed_prompt, - 'answer': needle + 'answer': needle, + 'length': original_context_length, + 'depth': int(depth_percent), }, out_f, ensure_ascii=False) @@ -91,7 +91,7 @@ class CDMEDatasetProcessor: counter += 1 if counter >= self.num_records_per_file: break - # Reset the accumulated tokens for the next record + # reset accumulated_tokens for next record accumulated_tokens = [] def _generate_context(self, tokens_context, depth_percent, needle): @@ -175,43 +175,104 @@ class CDMEDataset(): def visualize(csv_file_paths): for file_path in csv_file_paths: df = pd.read_csv(file_path) - model_name = df.columns[4] - # Process the data - df['Context Length'] = df['dataset'].apply(lambda x: int( - x.replace('CDME_', '').split('Depth')[0].replace('Length', '')) - ) + + # Split 'dataset' column to + # get 'Context Length' and 'Document Depth' + df['Context Length'] = df['dataset'].apply( + lambda x: int(x.split('Length')[1].split('Depth')[0])) df['Document Depth'] = df['dataset'].apply( - lambda x: float(x.replace('CDME_', '').split('Depth')[1])) - df = df[['Document Depth', 'Context Length', model_name]]\ - .rename(columns={model_name: 'Score'}) + lambda x: float(x.split('Depth')[1].split('_')[0])) - # Create pivot table - pivot_table = pd.pivot_table(df, - values='Score', - index=['Document Depth'], - columns=['Context Length'], - aggfunc='mean') + # Exclude 'Context Length' and 'Document Depth' columns + model_columns = [ + col for col in df.columns + if col not in ['Context Length', 'Document Depth'] + ] - # Create a heatmap for visualization - cmap = LinearSegmentedColormap.from_list( - 'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F']) - plt.figure(figsize=(17.5, 8)) - sns.heatmap(pivot_table, cmap=cmap, cbar_kws={'label': 'Score'}) - 
plt.title(f'{model_name} 8K Context Performance\n' - 'Fact Retrieval Across' - 'Context Lengths ("Needle In A Haystack")') - plt.xlabel('Token Limit') - plt.ylabel('Depth Percent') - plt.xticks(rotation=45) - plt.yticks(rotation=0) - plt.tight_layout() + for model_name in model_columns[4:]: + model_df = df[['Document Depth', 'Context Length', + model_name]].copy() + model_df.rename(columns={model_name: 'Score'}, inplace=True) - # Save the heatmap as a PNG file - png_file_path = file_path.replace('.csv', '.png') - plt.savefig(png_file_path) - plt.close() # Close the plot to prevent memory leaks - # Print the path to the saved PNG file - print(f'Heatmap saved as: {png_file_path}') + # Create pivot table + pivot_table = pd.pivot_table(model_df, + values='Score', + index=['Document Depth'], + columns=['Context Length'], + aggfunc='mean') + + # Calculate mean scores + mean_scores = pivot_table.mean().values + + # Calculate overall score + overall_score = mean_scores.mean() + + # Create heatmap and line plot + plt.figure(figsize=(17.5, 8)) + ax = plt.gca() + cmap = LinearSegmentedColormap.from_list( + 'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F']) + + # Draw heatmap + sns.heatmap(pivot_table, + cmap=cmap, + ax=ax, + cbar_kws={'label': 'Score'}, + vmin=0, + vmax=100) + + # Set line plot data + x_data = [i + 0.5 for i in range(len(mean_scores))] + y_data = mean_scores + + # Create twin axis for line plot + ax2 = ax.twinx() + # Draw line plot + ax2.plot(x_data, + y_data, + color='white', + marker='o', + linestyle='-', + linewidth=2, + markersize=8, + label='Average Depth Score') + # Set y-axis range + ax2.set_ylim(0, 100) + + # Hide original y-axis ticks and labels + ax2.set_yticklabels([]) + ax2.set_yticks([]) + + # Add legend + ax2.legend(loc='upper left') + + # Set chart title and labels + ax.set_title(f'{model_name} 8K Context Performance\n' + + 'Fact Retrieval Across Context Lengths ' + + '("Needle In A Haystack")') + ax.set_xlabel('Token Limit') + 
ax.set_ylabel('Depth Percent') + ax.set_xticklabels(pivot_table.columns.values, rotation=45) + ax.set_yticklabels(pivot_table.index.values, rotation=0) + # Add overall score as a subtitle + plt.text(0.5, + -0.13, f'Overall Score for {model_name}: ' + f'{overall_score:.2f}', + ha='center', + va='center', + transform=ax.transAxes, + fontsize=13) + + # Save heatmap as PNG + png_file_path = file_path.replace('.csv', f'_{model_name}.png') + # plt.tight_layout() + plt.savefig(png_file_path) + plt.show() + + plt.close() # Close figure to prevent memory leaks + + # Print saved PNG file path + print(f'Heatmap for {model_name} saved as: {png_file_path}') def main(): @@ -224,7 +285,7 @@ def main(): parser.add_argument('--tokenizer_model', type=str, default='gpt-4') parser.add_argument('--num_records_per_file', type=int, default=10) parser.add_argument('--length_buffer', type=int, default=200) - parser.add_argument('--guided', type=bool, default=True) + parser.add_argument('--guided', type=bool, default=False) parser.add_argument('--file_list', nargs='*', default=['zh_finance.jsonl']) parser.add_argument('--context_lengths', nargs='*',