[Feature] Make NeedleBench available on HF (#1364)

* update_lint

* update_huggingface format

* fix bug

* update docs
Mo Li 2024-07-25 19:01:56 +08:00 committed by GitHub
parent c3c02c2960
commit 69aa2f2d57
5 changed files with 70 additions and 24 deletions


@@ -18,6 +18,8 @@ Within the `NeedleBench` framework of `OpenCompass`, we have designed a series o
 ### Evaluation Steps
 
+> Note: In the latest code, OpenCompass automatically loads the dataset from the [Hugging Face API](https://huggingface.co/datasets/opencompass/NeedleBench), so you can **skip** the following steps for manually downloading and placing the dataset.
+
 1. Download the dataset from [here](https://github.com/open-compass/opencompass/files/14741330/needlebench.zip).
 
 2. Place the downloaded files in the `opencompass/data/needlebench/` directory. The expected file structure in the `needlebench` directory is shown below:
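A quick way to confirm that the automatic loading works is to fetch a single file straight from the Hub; a minimal sketch, assuming only that the `huggingface_hub` package is installed (`needles.jsonl` is one of the file names listed in the code changes below):

```python
from huggingface_hub import hf_hub_download

# Fetch one NeedleBench file from the Hugging Face dataset repo;
# the return value is the path to the locally cached copy.
local_path = hf_hub_download(repo_id='opencompass/NeedleBench',
                             filename='needles.jsonl',
                             repo_type='dataset')
print(local_path)
```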


@@ -18,6 +18,8 @@
 ### Evaluation Steps
 
+> Note: In the latest code, OpenCompass automatically loads the dataset from the [Hugging Face API](https://huggingface.co/datasets/opencompass/NeedleBench), so you can skip the manual download and placement steps below.
+
 1. Download the dataset from [here](https://github.com/open-compass/opencompass/files/14741330/needlebench.zip).
 
 2. Place the downloaded files in the `opencompass/data/needlebench/` directory. The expected file structure in the `needlebench` directory is shown below:


@@ -1,10 +1,10 @@
 import json
 import os
 import random
-from pathlib import Path
 
 import tiktoken
 from datasets import Dataset
+from huggingface_hub import hf_hub_download
 
 from opencompass.datasets.base import BaseDataset
 from opencompass.openicl import BaseEvaluator
@@ -37,7 +37,7 @@ class NeedleBenchMultiDataset(BaseDataset):
     @staticmethod
     def load(
-        path: str,
+        path: str,  # deprecated
         length: int,
         depth: int,
         tokenizer_model: str,
@@ -152,13 +152,28 @@ class NeedleBenchMultiDataset(BaseDataset):
             return prompt
 
-        files = Path(path).glob('*.jsonl')
-        needle_file_path = os.path.join(path, needle_file_name)
-        for file in files:
-            if file.name not in file_list:
+        repo_id = 'opencompass/NeedleBench'
+        file_names = [
+            'PaulGrahamEssays.jsonl', 'multi_needle_reasoning_en.json',
+            'multi_needle_reasoning_zh.json', 'zh_finance.jsonl',
+            'zh_game.jsonl', 'zh_general.jsonl', 'zh_government.jsonl',
+            'zh_movie.jsonl', 'zh_tech.jsonl'
+        ]
+        downloaded_files = []
+        base_file_path = ''
+        for file_name in file_names:
+            file_path = hf_hub_download(repo_id=repo_id,
+                                        filename=file_name,
+                                        repo_type='dataset')
+            downloaded_files.append(file_path)
+            base_file_path = '/'.join(file_path.split('/')[:-1])
+
+        needle_file_path = os.path.join(base_file_path, needle_file_name)
+        for file_path in downloaded_files:
+            if file_path.split('/')[-1] not in file_list:
                 continue
 
-            with open(file, 'r', encoding='utf-8') as f:
+            with open(file_path, 'r', encoding='utf-8') as f:
                 lines_bak = [json.loads(line.strip()) for line in f]
             lines = lines_bak.copy()
             for counter in range(num_repeats_per_file):
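Because every file in `file_names` comes from the same dataset revision, all returned paths share one local snapshot directory, which is why the parent directory of the last downloaded file can resolve `needle_file_name`. A minimal sketch of that pattern, using file names from the list above; `os.path.dirname` is the portable equivalent of splitting on `'/'`:

```python
import os

from huggingface_hub import hf_hub_download

# Files from one dataset revision are cached under a single snapshot
# directory, so the parent of any downloaded file locates its siblings.
file_path = hf_hub_download(repo_id='opencompass/NeedleBench',
                            filename='zh_finance.jsonl',
                            repo_type='dataset')
base_file_path = os.path.dirname(file_path)  # same as '/'.join(parts[:-1])
needle_file_path = os.path.join(base_file_path,
                                'multi_needle_reasoning_en.json')
print(needle_file_path)
```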


@@ -2,10 +2,10 @@ import json
 import os
 import random
 import re
-from pathlib import Path
 
 import tiktoken
 from datasets import Dataset
+from huggingface_hub import hf_hub_download
 
 from opencompass.datasets.base import BaseDataset
 from opencompass.openicl import BaseEvaluator
@@ -36,7 +36,7 @@ class NeedleBenchOriginDataset(BaseDataset):
     @staticmethod
     def load(
-        path: str,
+        path: str,  # deprecated
         length: int,
         depth: int,
         tokenizer_model: str,
@@ -128,18 +128,33 @@ class NeedleBenchOriginDataset(BaseDataset):
             return prompt
 
-        files = Path(path).glob('*.jsonl')
-        for file in files:
-            if file.name not in file_list:
+        repo_id = 'opencompass/NeedleBench'
+        file_names = [
+            'PaulGrahamEssays.jsonl', 'needles.jsonl', 'zh_finance.jsonl',
+            'zh_game.jsonl', 'zh_general.jsonl', 'zh_government.jsonl',
+            'zh_movie.jsonl', 'zh_tech.jsonl'
+        ]
+        downloaded_files = []
+        base_file_path = ''
+        for file_name in file_names:
+            file_path = hf_hub_download(repo_id=repo_id,
+                                        filename=file_name,
+                                        repo_type='dataset')
+            downloaded_files.append(file_path)
+            base_file_path = '/'.join(file_path.split('/')[:-1])
+
+        for file_path in downloaded_files:
+            if file_path.split('/')[-1] not in file_list:
                 continue
 
-            with open(file, 'r', encoding='utf-8') as f:
+            with open(file_path, 'r', encoding='utf-8') as f:
                 lines_bak = [json.loads(line.strip()) for line in f]
             lines = lines_bak.copy()
             for counter in range(num_repeats_per_file):
                 random.seed(counter)
                 random.shuffle(lines)
 
-                needle_file_path = os.path.join(path, needle_file_name)
+                needle_file_path = os.path.join(base_file_path,
+                                                needle_file_name)
                 random_needle = get_random_line_by_language(
                     counter, needle_file_path, language)
                 needle = '\n' + random_needle['needle'] + '\n'
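Unchanged by this commit but visible in the context lines: the loop reseeds Python's global RNG with the repeat counter before each shuffle, so every repeat index deterministically yields the same ordering across runs. A small sketch of that behaviour:

```python
import random

lines = ['a', 'b', 'c', 'd', 'e']

# Reseeding with the repeat counter pins each repeat to a fixed order,
# so repeated evaluation runs see identical shuffles.
for counter in range(2):
    shuffled = lines.copy()
    random.seed(counter)
    random.shuffle(shuffled)
    print(counter, shuffled)
```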


@@ -1,9 +1,9 @@
 import json
 import random
-from pathlib import Path
 
 import tiktoken
 from datasets import Dataset
+from huggingface_hub import hf_hub_download
 
 from opencompass.datasets.base import BaseDataset
 from opencompass.openicl import BaseEvaluator
@@ -57,7 +57,7 @@ class NeedleBenchParallelDataset(BaseDataset):
     @staticmethod
     def load(
-        path: str,
+        path: str,  # deprecated
         needle_file_name: str,
         length: int,
         depths: list[int],
@@ -72,9 +72,22 @@
         data = {'prompt': [], 'answer': []}
         tokenizer = tiktoken.encoding_for_model(tokenizer_model)
 
-        files = Path(path).glob('*.jsonl')
-        for file in files:
-            if file.name == needle_file_name:
+        repo_id = 'opencompass/NeedleBench'
+        file_names = [
+            'PaulGrahamEssays.jsonl', 'needles.jsonl', 'zh_finance.jsonl',
+            'zh_game.jsonl', 'zh_general.jsonl', 'zh_government.jsonl',
+            'zh_movie.jsonl', 'zh_tech.jsonl'
+        ]
+
+        downloaded_files = []
+        for file_name in file_names:
+            file_path = hf_hub_download(repo_id=repo_id,
+                                        filename=file_name,
+                                        repo_type='dataset')
+            downloaded_files.append(file_path)
+
+        for file in downloaded_files:
+            if file.split('/')[-1] == needle_file_name:
                 needle_file_path = file
 
         predefined_needles_bak = get_unique_entries(needle_file_path,
@@ -178,12 +191,11 @@ class NeedleBenchParallelDataset(BaseDataset):
             return prompt
 
-        files = Path(path).glob('*.jsonl')
-        for file in files:
-            if file.name not in file_list:
+        for file_path in downloaded_files:
+            if file_path.split('/')[-1] not in file_list:
                 continue
 
-            with open(file, 'r', encoding='utf-8') as f:
+            with open(file_path, 'r', encoding='utf-8') as f:
                 lines_bak = [json.loads(line.strip()) for line in f]
             lines = lines_bak.copy()
             for counter in range(num_repeats_per_file):
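One portability note on the new lookup: matching `file.split('/')[-1]` assumes POSIX-style separators in the cache path, while `os.path.basename` expresses the same intent on any platform. A small sketch with made-up cache paths standing in for `hf_hub_download` results:

```python
import os

# Hypothetical cache paths for illustration only.
downloaded_files = [
    '/cache/snapshots/abc123/needles.jsonl',
    '/cache/snapshots/abc123/zh_game.jsonl',
]
needle_file_name = 'needles.jsonl'

# os.path.basename is the platform-aware form of file.split('/')[-1].
needle_file_path = next(
    f for f in downloaded_files if os.path.basename(f) == needle_file_name)
print(needle_file_path)
```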