Mirror of https://github.com/open-compass/opencompass.git
[Feature] Add NeedleInAHaystack Test Support (#714)
* Add NeedleInAHaystack Test
* Apply pre-commit formatting
* Update configs/eval_hf_internlm_chat_20b_cdme.py
* Add needle in haystack test
* Update needle in haystack test

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>
This commit is contained in:
parent 4a2d1926a2
commit 0e24f4213e
38  configs/datasets/cdme/cdme.py  Normal file
@@ -0,0 +1,38 @@
```python
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.cdme.cdme import (CDMEDataset, CDMEEvaluator,
                                            cdme_postprocess,
                                            cdme_dataset_postprocess)
import os

cdme_reader_cfg = dict(input_columns=['prompt'], output_column='answer')

cdme_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template='''{prompt}'''),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=512))

cdme_eval_cfg = dict(
    evaluator=dict(type=CDMEEvaluator),
    pred_postprocessor=dict(type=cdme_postprocess),
    dataset_postprocessor=dict(type=cdme_dataset_postprocess))

base_path = './data/CDME/processed'
cdme_datasets = []

# Each generated Length*/Depth* folder becomes its own dataset entry.
for folder in os.listdir(base_path):
    if os.path.isdir(os.path.join(base_path, folder)):
        dataset_dict = dict(
            abbr=f'CDME_{folder}',
            type=CDMEDataset,
            path=os.path.join(base_path, folder),
            reader_cfg=cdme_reader_cfg,
            infer_cfg=cdme_infer_cfg,
            eval_cfg=cdme_eval_cfg
        )
        cdme_datasets.append(dataset_dict)
```
40  configs/eval_hf_internlm_chat_20b_cdme.py  Normal file
@@ -0,0 +1,40 @@
```python
from opencompass.models import HuggingFaceCausalLM

from mmengine.config import read_base
with read_base():
    from .datasets.cdme.cdme import cdme_datasets

datasets = [*cdme_datasets]

_meta_template = dict(
    round=[
        dict(role='HUMAN', begin='<|User|>:', end='\n'),
        dict(role='BOT', begin='<|Bot|>:', end='<eoa>\n', generate=True),
    ],
)

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='internlm-chat-20b-hf',
        path='internlm/internlm-chat-20b',
        tokenizer_path='internlm/internlm-chat-20b',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            use_fast=False,
            trust_remote_code=True,
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        meta_template=_meta_template,
        run_cfg=dict(num_gpus=2, num_procs=1),
        end_str='<eoa>',
    )
]
```
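The same recipe carries over to other HuggingFace chat models: only the model block needs to change. Below is a hedged sketch; the `abbr`, `path`, and GPU count are illustrative assumptions, not values from this commit, and the meta template must be adapted to the target model's chat format.

```python
# Hypothetical variant: evaluate a different HF chat model on the same datasets.
models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='my-chat-7b-hf',           # assumed name
        path='my-org/my-chat-7b',       # assumed HF repo
        tokenizer_path='my-org/my-chat-7b',
        model_kwargs=dict(trust_remote_code=True, device_map='auto'),
        tokenizer_kwargs=dict(padding_side='left',
                              truncation_side='left',
                              trust_remote_code=True),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        meta_template=_meta_template,   # adjust begin/end tokens to the model
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
```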
201  docs/en/advanced_guides/needleinahaystack_eval.md  Normal file
@@ -0,0 +1,201 @@
# Needle In A Haystack Experiment Evaluation

## Introduction to the Needle In A Haystack Test

The Needle In A Haystack test (inspired by [NeedleInAHaystack](https://github.com/gkamradt/LLMTest_NeedleInAHaystack/blob/main/LLMNeedleHaystackTester.py)) embeds a key piece of information at a random position in a long text and uses the result as the prompt for a large language model (LLM). The test checks whether the LLM can extract that key information from the surrounding text, probing a fundamental capability of LLMs in understanding long documents.
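The construction is simple enough to sketch in a few lines. The snippet below is only an illustration of the idea (the actual generator, shown later in this guide, operates on token IDs rather than characters); `depth_percent` controls where the needle lands.

```python
def insert_needle(haystack: str, needle: str, depth_percent: float) -> str:
    """Splice the needle into the haystack at roughly depth_percent of its length."""
    position = int(len(haystack) * depth_percent / 100)
    return haystack[:position] + needle + haystack[position:]

# A needle at 50% depth sits in the middle of the context:
prompt_body = insert_needle('。'.join(['这是一段很长的中文文档'] * 100),
                            '\n小明最喜欢的实习的地点就是上海人工智能实验室。\n', 50)
```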
## Dataset Overview

The `Skywork/ChineseDomainModelingEval` dataset includes high-quality Chinese articles published from September to October 2023, covering multiple domains. These articles make for a fair and challenging benchmark.

## File Description

The dataset includes files specific to certain domains:

- `zh_finance.jsonl` - Finance
- `zh_game.jsonl` - Gaming
- `zh_government.jsonl` - Government Affairs
- `zh_movie.jsonl` - Movies
- `zh_tech.jsonl` - Technology
- `zh_general.jsonl` - General

These files are used to evaluate the LLM's understanding capabilities in different specific areas.
### Evaluation Steps

1. Download the dataset from [Skywork/ChineseDomainModelingEval](https://huggingface.co/datasets/Skywork/ChineseDomainModelingEval/tree/main).

2. Place the downloaded files in `opencompass/data/CDME/`. The expected file structure in the `CDME` directory is as follows:

```
opencompass/
├── configs
├── docs
├── data
│   └── CDME
│       ├── processed
│       ├── README.md
│       ├── zh_finance.jsonl
│       ├── zh_game.jsonl
│       ├── zh_general.jsonl
│       ├── zh_government.jsonl
│       ├── zh_movie.jsonl
│       └── zh_tech.jsonl
├── LICENSE
├── opencompass
├── outputs
├── run.py
├── more...
```
### Environment Setup

```bash
conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
conda activate opencompass
git clone https://github.com/open-compass/opencompass opencompass
cd opencompass
pip install -e .
```
### Generating the Dataset

Run the following command to generate the dataset:

```bash
python tools/tools_needleinahaystack.py \
    --processed_datasets_path './data/CDME/processed' \
    --data_path './data/CDME' \
    --tokenizer_model 'gpt-4' \
    --num_records_per_file 10 \
    --length_buffer 200 \
    --guided True \
    --file_list 'zh_finance.jsonl' \
    --context_lengths 1000 2000 3000 4000 5000 6000 7000 8000 \
    --needle '\n小明最喜欢的实习的地点就是上海人工智能实验室。\n' \
    --retrieval_question '小明最喜欢的实习地点是哪里?你的回答格式应该为“小明最喜欢的实习地点就是________。”' \
    --document_depth_percent_intervals 35
```
You can set specific parameters when launching `tools/tools_needleinahaystack.py` to select the datasets required for your task. Key parameters include:

- `needle`: The specific text (needle) to be located within the dataset.
- `retrieval_question`: The question used to prompt the model for retrieval.
- `context_lengths`: Specifies the context lengths (in tokens) for the different test scenarios.
- `document_depth_percent_intervals`: The number of interval divisions for document depth, which determines where the "needle" is inserted; see the sketch after this list.
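To make the depth schedule concrete, the following snippet reproduces how `tools/tools_needleinahaystack.py` turns an interval count into insertion depths (its `sigmoid` variant passes the same grid through a logistic curve):

```python
import numpy as np

def depth_percents(intervals: int, interval_type: str = 'linear'):
    """Depth grid used when generating the CDME records."""
    if interval_type == 'linear':
        return np.linspace(0, 100, num=intervals)
    if interval_type == 'sigmoid':
        logistic = lambda x: np.round(100 / (1 + np.exp(-0.1 * (x - 50))), 3)
        return [logistic(x) for x in np.linspace(0, 100, intervals)]
    raise ValueError('Unsupported interval type')

print(depth_percents(5))  # [  0.  25.  50.  75. 100.]
```
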
### Evaluation

For example, to evaluate using the `internlm` model, you can use the following command:

```bash
python run.py configs/eval_hf_internlm_chat_20b_cdme.py --slurm -p partition_name -q auto --max-num-workers 32
```

This command initiates the evaluation process, where the model attempts to find the specified "needle" in the generated dataset. The parameters `-p partition_name -q auto` and `--max-num-workers 32` specify the Slurm queue and the maximum number of worker processes.
### Score Calculation Method

In the `CDMEEvaluator` class, we use two main methods to calculate scores: `levenshtein_distance` and `score`. Here is a detailed introduction to and implementation of these methods.

#### Levenshtein Distance

Levenshtein distance measures the difference between two strings. It is the minimum number of single-character edits (insertions, deletions, or substitutions) required to change one string into the other.
```python
def levenshtein_distance(self, s1, s2):
    if len(s1) < len(s2):
        return self.levenshtein_distance(s2, s1)

    if len(s2) == 0:
        return len(s1)

    previous_row = range(len(s2) + 1)
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = previous_row[j + 1] + 1
            deletions = current_row[j] + 1
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row

    return previous_row[-1]
```
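As a quick sanity check (a minimal snippet, assuming `CDMEEvaluator` is importable from `opencompass.datasets.cdme.cdme` as added in this commit), the classic pair `kitten`/`sitting` requires exactly three edits:

```python
from opencompass.datasets.cdme.cdme import CDMEEvaluator

evaluator = CDMEEvaluator()
# kitten -> sitten -> sittin -> sitting: three single-character edits
assert evaluator.levenshtein_distance('kitten', 'sitting') == 3
```
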
#### Score Calculation

The `score` method accepts lists of predictions and references and computes the edit distance and score for each prediction–reference pair.
```python
def score(self, predictions, references):
    if len(predictions) != len(references):
        return {"error": "predictions and references have different lengths"}

    total_score = 0
    details = []
    for prediction, reference in zip(predictions, references):
        prediction = re.sub(r'\s+', '', prediction)
        reference = re.sub(r'\s+', '', reference)
        edit_distance = self.levenshtein_distance(prediction, reference)
        max_len = max(len(prediction), len(reference))
        score = 100 * (1 - edit_distance / max_len) if max_len != 0 else 100

        detail = {
            "pred": prediction,
            "answer": reference,
            "edit_distance": edit_distance,
            "score": score
        }
        total_score += score
        details.append(detail)

    average_score = total_score / len(predictions) if predictions else 0
    result = {"score": average_score, "details": details}
    return result
```
The method first removes all whitespace characters from the predictions and references, then computes the Levenshtein distance between them. The score is 100 minus the percentage loss implied by the edit distance. Finally, it returns per-prediction details and the average score.
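For instance (a minimal check, again assuming the evaluator is importable), a prediction that drops a single character from the reference incurs an edit distance of 1:

```python
evaluator = CDMEEvaluator()
result = evaluator.score(
    predictions=['小明最喜欢的实习地点就是上海人工智能实验室。'],
    references=['小明最喜欢的实习的地点就是上海人工智能实验室。'])
# One deletion over a 23-character reference: 100 * (1 - 1/23) ≈ 95.65
print(result['score'], result['details'][0]['edit_distance'])
```
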
### Visualization

You can visualize the CSV files in the `outputs` folder using the `tools_needleinahaystack.py` script. For example:

```bash
python tools/tools_needleinahaystack.py \
    --plot \
    --csv_file_paths 'outputs/default/20231216_161457/summary/summary_20231216_161457.csv' 'outputs/default/20231217_022310/summary/summary_20231217_022310.csv'
```

Currently, this approach only supports the CDME dataset, and we welcome community contributions of more datasets.
If you use this method, please add a citation:

```bibtex
@misc{2023opencompass,
    title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
    author={OpenCompass Contributors},
    howpublished={\url{https://github.com/open-compass/opencompass}},
    year={2023}
}

@misc{LLMTest_NeedleInAHaystack,
    title={LLMTest Needle In A Haystack - Pressure Testing LLMs},
    author={gkamradt},
    year={2023},
    howpublished={\url{https://github.com/gkamradt/LLMTest_NeedleInAHaystack}}
}

@misc{wei2023skywork,
    title={Skywork: A More Open Bilingual Foundation Model},
    author={Tianwen Wei and others},
    year={2023},
    eprint={2310.19341},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```
199  docs/zh_cn/advanced_guides/needleinahaystack_eval.md  Normal file
@@ -0,0 +1,199 @@
# Needle In A Haystack Experiment Evaluation

## Introduction to the Needle In A Haystack Test

The Needle In A Haystack test (inspired by [NeedleInAHaystack](https://github.com/gkamradt/LLMTest_NeedleInAHaystack/blob/main/LLMNeedleHaystackTester.py)) inserts key information at random positions in a long text to form prompts for a large language model (LLM), then tests whether the model can extract that key information from the long text. It measures the model's ability to pull key information out of long documents, reflecting a fundamental capability of LLMs in long-text comprehension.

## Dataset Overview

The `Skywork/ChineseDomainModelingEval` dataset collects high-quality Chinese articles published from September to October 2023, covering multiple domains. These articles make for a fair and challenging benchmark.

## File Description

The dataset includes domain-specific files:

- `zh_finance.jsonl` - Finance
- `zh_game.jsonl` - Gaming
- `zh_government.jsonl` - Government Affairs
- `zh_movie.jsonl` - Movies
- `zh_tech.jsonl` - Technology
- `zh_general.jsonl` - General

These files are used to evaluate the LLM's understanding of different specific domains.

### Evaluation Steps

1. Download the dataset from [Skywork/ChineseDomainModelingEval](https://huggingface.co/datasets/Skywork/ChineseDomainModelingEval/tree/main).

2. Place the downloaded files under `opencompass/data/CDME/`. The expected file structure in the `CDME` directory is as follows:

```
opencompass/
├── configs
├── docs
├── data
│   └── CDME
│       ├── processed
│       ├── README.md
│       ├── zh_finance.jsonl
│       ├── zh_game.jsonl
│       ├── zh_general.jsonl
│       ├── zh_government.jsonl
│       ├── zh_movie.jsonl
│       └── zh_tech.jsonl
├── LICENSE
├── opencompass
├── outputs
├── run.py
├── more...
```

### Environment Setup

```bash
conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
conda activate opencompass
git clone https://github.com/open-compass/opencompass opencompass
cd opencompass
pip install -e .
```

### Generating the Dataset

Run the following command to generate the dataset:

```bash
python tools/tools_needleinahaystack.py \
    --processed_datasets_path './data/CDME/processed' \
    --data_path './data/CDME' \
    --tokenizer_model 'gpt-4' \
    --num_records_per_file 10 \
    --length_buffer 200 \
    --guided True \
    --file_list 'zh_finance.jsonl' \
    --context_lengths 1000 2000 3000 4000 5000 6000 7000 8000 \
    --needle '\n小明最喜欢的实习的地点就是上海人工智能实验室。\n' \
    --retrieval_question '小明最喜欢的实习地点是哪里?你的回答格式应该为“小明最喜欢的实习地点就是________。”' \
    --document_depth_percent_intervals 35
```

You can set specific parameters when launching `tools/tools_needleinahaystack.py` to select the datasets required for your task. Key parameters include:

- `needle`: the specified text (needle) to find in the dataset.
- `retrieval_question`: the question used to prompt the model for retrieval.
- `context_lengths`: the context lengths (in tokens) for the different test scenarios.
- `document_depth_percent_intervals`: the number of interval divisions of document depth, used to determine where to insert the "needle".

### Evaluation

For example, to evaluate with the `internlm` model, you can use the following command:

```bash
python run.py configs/eval_hf_internlm_chat_20b_cdme.py --slurm -p partition_name -q auto --max-num-workers 32
```

This command starts the evaluation flow, in which the model tries to find the specified "needle" in the generated dataset. The parameters `-p partition_name -q auto` and `--max-num-workers 32` specify the Slurm queue and the maximum number of worker processes.

### Score Calculation Method

In the `CDMEEvaluator` class, we use two main methods to compute scores: `levenshtein_distance` and `score`. Below is a detailed introduction to and implementation of these methods.

#### Levenshtein Distance

Levenshtein distance is a way to measure the difference between two strings. It is the minimum number of single-character edits (insertions, deletions, or substitutions) required to transform one string into the other.

```python
def levenshtein_distance(self, s1, s2):
    if len(s1) < len(s2):
        return self.levenshtein_distance(s2, s1)

    if len(s2) == 0:
        return len(s1)

    previous_row = range(len(s2) + 1)
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = previous_row[j + 1] + 1
            deletions = current_row[j] + 1
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row

    return previous_row[-1]
```

#### Score Calculation

The `score` method accepts two lists, predictions and references, and computes the edit distance and score for each prediction–reference pair.

```python
def score(self, predictions, references):
    if len(predictions) != len(references):
        return {"error": "predictions and references have different lengths"}

    total_score = 0
    details = []
    for prediction, reference in zip(predictions, references):
        prediction = re.sub(r'\s+', '', prediction)
        reference = re.sub(r'\s+', '', reference)
        edit_distance = self.levenshtein_distance(prediction, reference)
        max_len = max(len(prediction), len(reference))
        score = 100 * (1 - edit_distance / max_len) if max_len != 0 else 100

        detail = {
            "pred": prediction,
            "answer": reference,
            "edit_distance": edit_distance,
            "score": score
        }
        total_score += score
        details.append(detail)

    average_score = total_score / len(predictions) if predictions else 0
    result = {"score": average_score, "details": details}
    return result
```

The method first removes all whitespace characters from the predictions and references, then computes the Levenshtein distance between them. The score is 100 minus the percentage loss based on the edit distance. Finally, it returns per-prediction details and the average score.

### Visualization

You can use the `tools_needleinahaystack.py` script to plot the CSV files in the `outputs` folder. For example:

```bash
python tools/tools_needleinahaystack.py \
    --plot \
    --csv_file_paths 'outputs/default/20231216_161457/summary/summary_20231216_161457.csv' 'outputs/default/20231217_022310/summary/summary_20231217_022310.csv'
```

Currently, this approach only supports the CDME dataset, and we welcome community contributions of more datasets.

If you use this method, please add a citation:

```bibtex
@misc{2023opencompass,
    title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
    author={OpenCompass Contributors},
    howpublished={\url{https://github.com/open-compass/opencompass}},
    year={2023}
}

@misc{LLMTest_NeedleInAHaystack,
    title={LLMTest Needle In A Haystack - Pressure Testing LLMs},
    author={gkamradt},
    year={2023},
    howpublished={\url{https://github.com/gkamradt/LLMTest_NeedleInAHaystack}}
}

@misc{wei2023skywork,
    title={Skywork: A More Open Bilingual Foundation Model},
    author={Tianwen Wei and Liang Zhao and Lichang Zhang and Bo Zhu and Lijie Wang and Haihua Yang and Biye Li and Cheng Cheng and Weiwei Lü and Rui Hu and Chenxia Li and Liu Yang and Xilin Luo and Xuejie Wu and Lunan Liu and Wenjun Cheng and Peng Cheng and Jianhao Zhang and Xiaoyu Zhang and Lei Lin and Xiaokun Wang and Yutuan Ma and Chuanhai Dong and Yanqi Sun and Yifu Chen and Yongyi Peng and Xiaojuan Liang and Shuicheng Yan and Han Fang and Yahui Zhou},
    year={2023},
    eprint={2310.19341},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```
91  opencompass/datasets/cdme/cdme.py  Normal file
@@ -0,0 +1,91 @@
```python
import json
import re
from pathlib import Path

from datasets import Dataset

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS


@LOAD_DATASET.register_module()
class CDMEDataset(BaseDataset):

    @staticmethod
    def load(path: str):
        data = {'prompt': [], 'answer': []}
        for file in Path(path).glob('*.jsonl'):
            with open(file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = json.loads(line.strip())
                    data['prompt'].append(line['prompt'])
                    data['answer'].append(line['answer'])

        dataset = Dataset.from_dict({
            'prompt': data['prompt'],
            'answer': data['answer'],
        })
        return dataset


class CDMEEvaluator(BaseEvaluator):

    def levenshtein_distance(self, s1, s2):
        if len(s1) < len(s2):
            return self.levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    def score(self, predictions, references):
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different lengths'
            }

        total_score = 0
        details = []
        for prediction, reference in zip(predictions, references):
            prediction = re.sub(r'\s+', '', prediction)
            reference = re.sub(r'\s+', '', reference)
            edit_distance = self.levenshtein_distance(prediction, reference)
            max_len = max(len(prediction), len(reference))
            score = 100 * (1 -
                           edit_distance / max_len) if max_len != 0 else 100

            detail = {
                'pred': prediction,
                'answer': reference,
                'edit_distance': edit_distance,
                'score': score
            }
            total_score += score
            details.append(detail)

        average_score = total_score / len(predictions) if predictions else 0
        result = {'score': average_score, 'details': details}
        return result


@TEXT_POSTPROCESSORS.register_module('cdme')
def cdme_postprocess(text: str) -> str:
    return text


@TEXT_POSTPROCESSORS.register_module('cdme_dataset')
def cdme_dataset_postprocess(text: str) -> str:
    return text
```
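Putting the pieces together, here is a minimal end-to-end sketch; the folder name below is a placeholder for one of the generated `Length*Depth*` directories, and a real evaluation is driven through the config rather than by hand:

```python
from opencompass.datasets.cdme.cdme import CDMEDataset, CDMEEvaluator

# Load every *.jsonl record under one processed Length/Depth folder.
dataset = CDMEDataset.load('./data/CDME/processed/Length2000Depth50')  # placeholder path
print(dataset[0]['answer'])

# Score a model output against the gold needle.
evaluator = CDMEEvaluator()
print(evaluator.score(predictions=['小明最喜欢的实习地点就是上海人工智能实验室。'],
                      references=[dataset[0]['answer']]))
```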
276  tools/tools_needleinahaystack.py  Normal file
@@ -0,0 +1,276 @@
```python
import argparse
import json
import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tiktoken
from matplotlib.colors import LinearSegmentedColormap


class CDMEDatasetProcessor:

    def __init__(self,
                 path,
                 output_path,
                 tokenizer_model='gpt-4',
                 num_records_per_file=10,
                 length_buffer=200,
                 guided=False,
                 file_list=[]):
        self.path = path
        self.output_path = output_path
        self.tokenizer = tiktoken.encoding_for_model(tokenizer_model)
        self.num_records_per_file = num_records_per_file
        self.length_buffer = length_buffer
        self.guided = guided
        self.file_list = file_list

    def process_files(self,
                      context_lengths,
                      needle,
                      retrieval_question,
                      document_depth_percent_intervals,
                      document_depth_percent_interval_type='linear'):
        files = Path(self.path).glob('*.jsonl')
        for file in files:
            if os.path.basename(file) in self.file_list:
                self.process_file(file, context_lengths, needle,
                                  retrieval_question,
                                  document_depth_percent_intervals,
                                  document_depth_percent_interval_type)

    def process_file(self, file, context_lengths, needle, retrieval_question,
                     document_depth_percent_intervals,
                     document_depth_percent_interval_type):
        with open(file, 'r', encoding='utf-8') as f:
            lines = [json.loads(line.strip()) for line in f]

        for original_context_length in context_lengths:
            context_length = original_context_length - self.length_buffer
            target_length_per_record = context_length - len(
                self._get_tokens_from_context(needle))
            for depth_percent in self._generate_depth_percents(
                    document_depth_percent_intervals,
                    document_depth_percent_interval_type):
                output_file = (Path(self.output_path) /
                               f'Length{original_context_length}'
                               f'Depth{int(depth_percent)}' /
                               f'{file.stem}_Length{original_context_length}'
                               f'_Depth{int(depth_percent)}{file.suffix}')

                output_file.parent.mkdir(parents=True, exist_ok=True)
                with open(output_file, 'w', encoding='utf-8') as out_f:
                    counter = 0
                    accumulated_tokens = []
                    for line in lines:
                        tokens_current_line = self._get_tokens_from_context(
                            line['text'])
                        accumulated_tokens.extend(tokens_current_line)

                        if len(accumulated_tokens) >= target_length_per_record:

                            processed_text = self._generate_context(
                                accumulated_tokens[:target_length_per_record],
                                depth_percent, needle)

                            processed_prompt = self._generate_prompt(
                                processed_text, retrieval_question)
                            json.dump(
                                {
                                    'prompt': processed_prompt,
                                    'answer': needle
                                },
                                out_f,
                                ensure_ascii=False)
                            out_f.write('\n')
                            counter += 1
                            if counter >= self.num_records_per_file:
                                break
                            # Reset the accumulated tokens for the next record
                            accumulated_tokens = []

    def _generate_context(self, tokens_context, depth_percent, needle):
        tokens_needle = self._get_tokens_from_context(needle)

        # Insert the needle into the context at the specified depth percent
        insertion_point = int(len(tokens_context) * (depth_percent / 100))
        tokens_context = (tokens_context[:insertion_point] + tokens_needle +
                          tokens_context[insertion_point:])

        # Decode the tokens back to text
        new_context = self._decode_tokens(tokens_context)
        return new_context

    def _get_tokens_from_context(self, context):
        return self.tokenizer.encode(context)

    def _decode_tokens(self, tokens):
        return self.tokenizer.decode(tokens)

    def _generate_prompt(self, context, retrieval_question):
        if self.guided:
            prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                      '请保持你的回答简洁清楚。不要说和下面文档中的无关的话,或重复你的回答\n'
                      f'用户现在给你的文档是{context}\n\n'
                      f'现在请问:{retrieval_question}'
                      f'提示:文档中与该问题最相关的句子是_______')
        else:
            prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                      '请保持你的回答简洁清楚。不要说和下面文档中的无关的话,或重复你的回答\n'
                      f'用户现在给你的文档是{context}\n\n'
                      f'现在请问:{retrieval_question}')
        return prompt

    def _generate_depth_percents(self, intervals, interval_type):
        if interval_type == 'linear':
            return np.linspace(0, 100, num=intervals)
        elif interval_type == 'sigmoid':
            return [self._logistic(x) for x in np.linspace(0, 100, intervals)]
        else:
            raise ValueError('Unsupported interval type')

    @staticmethod
    def _logistic(x, L=100, x0=50, k=0.1):
        return np.round(L / (1 + np.exp(-k * (x - x0))), 3)


class CDMEDataset():

    @staticmethod
    def generate(processed_datasets_path, data_path, tokenizer_model,
                 num_records_per_file, length_buffer, guided, file_list,
                 context_lengths, needle, retrieval_question,
                 document_depth_percent_intervals):
        # Check if the processed datasets directory exists
        if os.path.exists(processed_datasets_path):
            shutil.rmtree(processed_datasets_path)
            print('The existing processed datasets directory '
                  f'{processed_datasets_path} has been '
                  'removed for a fresh start.')
        else:
            print('No existing processed datasets directory found at'
                  f' {processed_datasets_path}. '
                  'Starting with a fresh directory.')

        processor = CDMEDatasetProcessor(
            path=data_path,
            output_path=processed_datasets_path,
            tokenizer_model=tokenizer_model,
            num_records_per_file=num_records_per_file,
            length_buffer=length_buffer,
            guided=guided,
            file_list=file_list)

        processor.process_files(context_lengths, needle, retrieval_question,
                                document_depth_percent_intervals)

        print('Datasets have been created.')

    @staticmethod
    def visualize(csv_file_paths):
        for file_path in csv_file_paths:
            df = pd.read_csv(file_path)
            model_name = df.columns[4]
            # Parse Length/Depth back out of the dataset abbreviations
            df['Context Length'] = df['dataset'].apply(lambda x: int(
                x.replace('CDME_', '').split('Depth')[0].replace('Length', ''))
            )
            df['Document Depth'] = df['dataset'].apply(
                lambda x: float(x.replace('CDME_', '').split('Depth')[1]))
            df = df[['Document Depth', 'Context Length', model_name]]\
                .rename(columns={model_name: 'Score'})

            # Create pivot table
            pivot_table = pd.pivot_table(df,
                                         values='Score',
                                         index=['Document Depth'],
                                         columns=['Context Length'],
                                         aggfunc='mean')

            # Create a heatmap for visualization
            cmap = LinearSegmentedColormap.from_list(
                'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F'])
            plt.figure(figsize=(17.5, 8))
            sns.heatmap(pivot_table, cmap=cmap, cbar_kws={'label': 'Score'})
            plt.title(f'{model_name} 8K Context Performance\n'
                      'Fact Retrieval Across '
                      'Context Lengths ("Needle In A Haystack")')
            plt.xlabel('Token Limit')
            plt.ylabel('Depth Percent')
            plt.xticks(rotation=45)
            plt.yticks(rotation=0)
            plt.tight_layout()

            # Save the heatmap as a PNG file
            png_file_path = file_path.replace('.csv', '.png')
            plt.savefig(png_file_path)
            plt.close()  # Close the plot to prevent memory leaks
            # Print the path to the saved PNG file
            print(f'Heatmap saved as: {png_file_path}')


def main():
    parser = argparse.ArgumentParser(description='Generate CDMEDataset.')

    parser.add_argument('--processed_datasets_path',
                        type=str,
                        default='./data/CDME/processed')
    parser.add_argument('--data_path', type=str, default='./data/CDME')
    parser.add_argument('--tokenizer_model', type=str, default='gpt-4')
    parser.add_argument('--num_records_per_file', type=int, default=10)
    parser.add_argument('--length_buffer', type=int, default=200)
    # Caveat: argparse's type=bool treats any non-empty string (including
    # 'False') as True, so '--guided False' still enables guided prompts.
    parser.add_argument('--guided', type=bool, default=True)
    parser.add_argument('--file_list', nargs='*', default=['zh_finance.jsonl'])
    parser.add_argument('--context_lengths',
                        nargs='*',
                        type=int,
                        default=list(range(1000, 9000, 1000)))
    parser.add_argument('--needle',
                        type=str,
                        default='\n小明最喜欢的实习的地点就是上海人工智能实验室。\n')
    parser.add_argument('--retrieval_question',
                        type=str,
                        default='小明最喜欢的实习地点是哪里?'
                        '你的回答格式应该为“小明最喜欢的实习地点就是________。”')
    parser.add_argument('--document_depth_percent_intervals',
                        type=int,
                        default=35)
    parser.add_argument('--plot',
                        action='store_true',
                        help='Visualize the dataset results')
    parser.add_argument('--csv_file_paths',
                        nargs='*',
                        default=['path/to/your/result.csv'],
                        help='Paths to CSV files for visualization')

    args = parser.parse_args()

    if args.plot:
        if not args.csv_file_paths:
            print("Error: '--csv_file_paths' is required for visualization.")
            exit(1)
        CDMEDataset.visualize(args.csv_file_paths)

    else:
        doc_depth_intervals = args.document_depth_percent_intervals
        CDMEDataset.generate(
            processed_datasets_path=args.processed_datasets_path,
            data_path=args.data_path,
            tokenizer_model=args.tokenizer_model,
            num_records_per_file=args.num_records_per_file,
            length_buffer=args.length_buffer,
            guided=args.guided,
            file_list=args.file_list,
            context_lengths=args.context_lengths,
            needle=args.needle,
            retrieval_question=args.retrieval_question,
            document_depth_percent_intervals=doc_depth_intervals)


if __name__ == '__main__':
    main()
```
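If you prefer to drive the generation from Python rather than the CLI, a rough sketch follows (an assumption, not a documented entry point: it requires the repository root on `PYTHONPATH` so that `tools/` is importable; the arguments mirror the parser defaults above):

```python
from tools.tools_needleinahaystack import CDMEDataset

CDMEDataset.generate(
    processed_datasets_path='./data/CDME/processed',
    data_path='./data/CDME',
    tokenizer_model='gpt-4',
    num_records_per_file=10,
    length_buffer=200,
    guided=True,
    file_list=['zh_finance.jsonl'],
    context_lengths=list(range(1000, 9000, 1000)),
    needle='\n小明最喜欢的实习的地点就是上海人工智能实验室。\n',
    retrieval_question=('小明最喜欢的实习地点是哪里?'
                        '你的回答格式应该为“小明最喜欢的实习地点就是________。”'),
    document_depth_percent_intervals=35)
```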