[Feature] Make NeedleBench available on HF (#1364)

* update_lint

* update_huggingface format

* fix bug

* update docs
Mo Li 2024-07-25 19:01:56 +08:00 committed by GitHub
parent c3c02c2960
commit 69aa2f2d57
5 changed files with 70 additions and 24 deletions


@@ -18,6 +18,8 @@ Within the `NeedleBench` framework of `OpenCompass`, we have designed a series o
 ### Evaluation Steps

+> Note: In the latest code, OpenCompass loads the dataset automatically from the [Huggingface API](https://huggingface.co/datasets/opencompass/NeedleBench), so you can **skip** the following manual download and placement steps.
+
 1. Download the dataset from [here](https://github.com/open-compass/opencompass/files/14741330/needlebench.zip).
 2. Place the downloaded files in the `opencompass/data/needlebench/` directory. The expected file structure in the `needlebench` directory is shown below:
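As a quick illustration of what the automatic loading in the note above amounts to, here is a minimal sketch of fetching a single NeedleBench file from the Hugging Face Hub. The `huggingface_hub.hf_hub_download` call is the real API used by this commit; the choice of `needles.jsonl` as the example file is ours, not the exact code path OpenCompass runs:

```python
# Minimal sketch of the automatic dataset loading described in the note above.
from huggingface_hub import hf_hub_download

# Downloads the file on first use, then serves it from the local HF cache;
# the return value is the absolute path of the cached copy.
local_path = hf_hub_download(repo_id='opencompass/NeedleBench',
                             filename='needles.jsonl',
                             repo_type='dataset')
print(local_path)
```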


@@ -18,6 +18,8 @@
 ### Evaluation Steps

+> Note: In the latest code, OpenCompass loads the dataset automatically from the [Huggingface API](https://huggingface.co/datasets/opencompass/NeedleBench), so you can skip the following manual download and placement steps.
+
 1. Download the dataset from [here](https://github.com/open-compass/opencompass/files/14741330/needlebench.zip).
 2. Place the downloaded files in the `opencompass/data/needlebench/` directory. The expected file structure in the `needlebench` directory is shown below:


@@ -1,10 +1,10 @@
 import json
 import os
 import random
-from pathlib import Path

 import tiktoken
 from datasets import Dataset
+from huggingface_hub import hf_hub_download

 from opencompass.datasets.base import BaseDataset
 from opencompass.openicl import BaseEvaluator
@@ -37,7 +37,7 @@ class NeedleBenchMultiDataset(BaseDataset):

     @staticmethod
     def load(
-        path: str,
+        path: str,  # deprecated
         length: int,
         depth: int,
         tokenizer_model: str,
@@ -152,13 +152,28 @@ class NeedleBenchMultiDataset(BaseDataset):
             return prompt

-        files = Path(path).glob('*.jsonl')
-        needle_file_path = os.path.join(path, needle_file_name)
-        for file in files:
-            if file.name not in file_list:
+        repo_id = 'opencompass/NeedleBench'
+        file_names = [
+            'PaulGrahamEssays.jsonl', 'multi_needle_reasoning_en.json',
+            'multi_needle_reasoning_zh.json', 'zh_finance.jsonl',
+            'zh_game.jsonl', 'zh_general.jsonl', 'zh_government.jsonl',
+            'zh_movie.jsonl', 'zh_tech.jsonl'
+        ]
+
+        downloaded_files = []
+        base_file_path = ''
+        for file_name in file_names:
+            file_path = hf_hub_download(repo_id=repo_id,
+                                        filename=file_name,
+                                        repo_type='dataset')
+            downloaded_files.append(file_path)
+            base_file_path = '/'.join(file_path.split('/')[:-1])
+
+        needle_file_path = os.path.join(base_file_path, needle_file_name)
+        for file_path in downloaded_files:
+            if file_path.split('/')[-1] not in file_list:
                 continue
-            with open(file, 'r', encoding='utf-8') as f:
+            with open(file_path, 'r', encoding='utf-8') as f:
                 lines_bak = [json.loads(line.strip()) for line in f]
                 lines = lines_bak.copy()
                 for counter in range(num_repeats_per_file):
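The `NeedleBenchMultiDataset.load` hunk above swaps local `Path(path).glob('*.jsonl')` discovery for explicit `hf_hub_download` calls. Below is a standalone sketch of that pattern; the two-file `file_names` subset and the `file_list` filter are illustrative values only, and we use `os.path.dirname`/`os.path.basename` as the portable equivalents of the patch's `'/'.join(...)` and `split('/')[-1]` string handling:

```python
import os
from huggingface_hub import hf_hub_download

repo_id = 'opencompass/NeedleBench'
file_names = ['PaulGrahamEssays.jsonl', 'zh_finance.jsonl']  # illustrative subset
file_list = ['zh_finance.jsonl']  # only these files are actually read

downloaded_files = []
base_file_path = ''
for file_name in file_names:
    # Returns the absolute path of the locally cached copy.
    file_path = hf_hub_download(repo_id=repo_id,
                                filename=file_name,
                                repo_type='dataset')
    downloaded_files.append(file_path)
    # Same result as '/'.join(file_path.split('/')[:-1]) in the patch,
    # but independent of the path separator.
    base_file_path = os.path.dirname(file_path)

for file_path in downloaded_files:
    if os.path.basename(file_path) not in file_list:
        continue  # skip files this configuration does not need
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f]
```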


@@ -2,10 +2,10 @@ import json
 import os
 import random
 import re
-from pathlib import Path

 import tiktoken
 from datasets import Dataset
+from huggingface_hub import hf_hub_download

 from opencompass.datasets.base import BaseDataset
 from opencompass.openicl import BaseEvaluator
@@ -36,7 +36,7 @@ class NeedleBenchOriginDataset(BaseDataset):

     @staticmethod
     def load(
-        path: str,
+        path: str,  # deprecated
         length: int,
         depth: int,
         tokenizer_model: str,
@@ -128,18 +128,33 @@ class NeedleBenchOriginDataset(BaseDataset):
             return prompt

-        files = Path(path).glob('*.jsonl')
-        for file in files:
-            if file.name not in file_list:
-                continue
-            with open(file, 'r', encoding='utf-8') as f:
+        repo_id = 'opencompass/NeedleBench'
+        file_names = [
+            'PaulGrahamEssays.jsonl', 'needles.jsonl', 'zh_finance.jsonl',
+            'zh_game.jsonl', 'zh_general.jsonl', 'zh_government.jsonl',
+            'zh_movie.jsonl', 'zh_tech.jsonl'
+        ]
+
+        downloaded_files = []
+        base_file_path = ''
+        for file_name in file_names:
+            file_path = hf_hub_download(repo_id=repo_id,
+                                        filename=file_name,
+                                        repo_type='dataset')
+            downloaded_files.append(file_path)
+            base_file_path = '/'.join(file_path.split('/')[:-1])
+
+        for file_path in downloaded_files:
+            if file_path.split('/')[-1] not in file_list:
+                continue
+            with open(file_path, 'r', encoding='utf-8') as f:
                 lines_bak = [json.loads(line.strip()) for line in f]
                 lines = lines_bak.copy()
                 for counter in range(num_repeats_per_file):
                     random.seed(counter)
                     random.shuffle(lines)
-                    needle_file_path = os.path.join(path, needle_file_name)
+                    needle_file_path = os.path.join(base_file_path,
+                                                    needle_file_name)
                     random_needle = get_random_line_by_language(
                         counter, needle_file_path, language)
                     needle = '\n' + random_needle['needle'] + '\n'
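One detail worth noting in the `NeedleBenchOriginDataset.load` hunk: seeding the RNG with the repeat counter before each shuffle makes every repetition deterministic across runs, so the same `counter` always yields the same permutation. A self-contained demonstration of that reproducibility (the records here are made-up stand-ins for the JSONL lines):

```python
import random

lines = ['a', 'b', 'c', 'd']  # made-up stand-ins for the JSONL records
num_repeats_per_file = 3

for counter in range(num_repeats_per_file):
    shuffled = lines.copy()
    random.seed(counter)   # same counter -> same permutation on every run
    random.shuffle(shuffled)
    print(counter, shuffled)
```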


@@ -1,9 +1,9 @@
 import json
 import random
-from pathlib import Path

 import tiktoken
 from datasets import Dataset
+from huggingface_hub import hf_hub_download

 from opencompass.datasets.base import BaseDataset
 from opencompass.openicl import BaseEvaluator
@@ -57,7 +57,7 @@ class NeedleBenchParallelDataset(BaseDataset):

     @staticmethod
     def load(
-        path: str,
+        path: str,  # deprecated
         needle_file_name: str,
         length: int,
         depths: list[int],
@@ -72,9 +72,22 @@ class NeedleBenchParallelDataset(BaseDataset):
         data = {'prompt': [], 'answer': []}
         tokenizer = tiktoken.encoding_for_model(tokenizer_model)

-        files = Path(path).glob('*.jsonl')
-        for file in files:
-            if file.name == needle_file_name:
+        repo_id = 'opencompass/NeedleBench'
+        file_names = [
+            'PaulGrahamEssays.jsonl', 'needles.jsonl', 'zh_finance.jsonl',
+            'zh_game.jsonl', 'zh_general.jsonl', 'zh_government.jsonl',
+            'zh_movie.jsonl', 'zh_tech.jsonl'
+        ]
+
+        downloaded_files = []
+        for file_name in file_names:
+            file_path = hf_hub_download(repo_id=repo_id,
+                                        filename=file_name,
+                                        repo_type='dataset')
+            downloaded_files.append(file_path)
+
+        for file in downloaded_files:
+            if file.split('/')[-1] == needle_file_name:
                 needle_file_path = file

         predefined_needles_bak = get_unique_entries(needle_file_path,
@@ -178,12 +191,11 @@ class NeedleBenchParallelDataset(BaseDataset):
             return prompt

-        files = Path(path).glob('*.jsonl')
-        for file in files:
-            if file.name not in file_list:
+        for file_path in downloaded_files:
+            if file_path.split('/')[-1] not in file_list:
                 continue
-            with open(file, 'r', encoding='utf-8') as f:
+            with open(file_path, 'r', encoding='utf-8') as f:
                 lines_bak = [json.loads(line.strip()) for line in f]
                 lines = lines_bak.copy()
                 for counter in range(num_repeats_per_file):
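Finally, in the first `NeedleBenchParallelDataset.load` hunk, `needle_file_path` is only bound when some downloaded file's basename equals `needle_file_name`; if nothing matches, the later `get_unique_entries` call would hit an unbound name. A sketch of a stricter lookup that fails loudly instead; the `find_needle_file` helper is our hypothetical illustration, not part of OpenCompass:

```python
import os

def find_needle_file(downloaded_files, needle_file_name):
    """Return the cached path whose basename matches, or fail loudly."""
    for file_path in downloaded_files:
        if os.path.basename(file_path) == needle_file_name:
            return file_path
    raise FileNotFoundError(
        f'{needle_file_name} was not among the downloaded files')

# Usage with made-up cache paths:
paths = ['/cache/needles.jsonl', '/cache/zh_finance.jsonl']
print(find_needle_file(paths, 'needles.jsonl'))
```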