# Source: OpenCompass/opencompass/datasets/needlebench/parallel.py
# (312 lines, 12 KiB, Python)

import json
import random
from pathlib import Path
import tiktoken
from datasets import Dataset
from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET
def get_unique_entries(file_path,
                       n,
                       language,
                       unique_arg1=False,
                       unique_arg2=False,
                       unique_combination=False):
    """Randomly pick up to ``n`` entries of one language from a JSONL file.

    The file is read completely, its lines are shuffled with the global
    ``random`` state, and entries are accepted in that shuffled order until
    ``n`` have been collected or the file is exhausted.

    Args:
        file_path: Path of a JSON-lines file; each line is expected to hold
            one JSON object.
        n: Maximum number of entries to return. ``n <= 0`` yields ``[]``.
        language: Only entries whose ``'language'`` field equals this value
            are considered.
        unique_arg1: If true, no two returned entries share an ``'arg1'``.
        unique_arg2: If true, no two returned entries share an ``'arg2'``.
        unique_combination: If true, no two returned entries share the same
            ``('arg1', 'arg2')`` pair.

    Returns:
        A list of at most ``n`` decoded entries (dicts).
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()
    # Shuffle even when nothing will be returned so the global RNG is
    # consumed the same way regardless of ``n``.
    random.shuffle(lines)

    seen_arg1 = set()
    seen_arg2 = set()
    seen_combinations = set()
    results = []

    for line in lines:
        # Checked before processing so that n <= 0 returns nothing.
        if len(results) >= n:
            break

        try:
            entry = json.loads(line.strip())
        except json.JSONDecodeError:
            # Malformed and blank lines are skipped on purpose.
            continue

        # Valid JSON that is not an object (e.g. a list) has no fields.
        if not isinstance(entry, dict):
            continue
        if entry.get('language') != language:
            continue

        arg1 = entry.get('arg1', '')
        arg2 = entry.get('arg2', '')

        if unique_arg1 and arg1 in seen_arg1:
            continue
        if unique_arg2 and arg2 in seen_arg2:
            continue
        # The pair is built from the real field values, independent of the
        # per-argument flags, so pair uniqueness works on its own.
        if unique_combination and (arg1, arg2) in seen_combinations:
            continue

        # Only track (and therefore hash) the values a flag asks about.
        if unique_arg1:
            seen_arg1.add(arg1)
        if unique_arg2:
            seen_arg2.add(arg2)
        if unique_combination:
            seen_combinations.add((arg1, arg2))
        results.append(entry)

    return results
@LOAD_DATASET.register_module()
class NeedleBenchParallelDataset(BaseDataset):
    """NeedleBench dataset that hides several needles in one long context.

    Each sample contains one needle per requested depth and a combined
    question asking about all of them.
    """

    @staticmethod
    def load(
        path: str,
        needle_file_name: str,
        length: int,
        depths: list[int],
        tokenizer_model: str,
        file_list: list[str],
        num_repeats_per_file: int,
        length_buffer: int,
        guide: bool,
        language: str,
        position: str = 'End',
    ):
        """Build the prompt/answer dataset.

        Args:
            path: Directory searched (non-recursively) for ``*.jsonl`` files.
            needle_file_name: Name of the needle JSONL file inside ``path``.
                Its entries are read through ``get_unique_entries`` and must
                provide ``'needle'``, ``'arg2'`` and ``'retrieval_question'``.
            length: Total token budget of a sample.
            depths: Insertion depths, in percent of the haystack token count.
                One needle is requested per depth.
            tokenizer_model: Model name given to
                ``tiktoken.encoding_for_model``.
            file_list: Names of the haystack files in ``path`` to use; each
                line of such a file is a JSON object with a ``'text'`` field.
            num_repeats_per_file: Samples generated per haystack file. The
                repeat index is used as the random seed.
            length_buffer: Tokens subtracted from ``length`` before the
                haystack is sized.
            guide: If true, a "think about the most relevant content first"
                sentence is inserted before the answer-format instruction.
            language: ``'Chinese'`` or ``'English'``.
            position: ``'End'`` puts the question after the document,
                ``'Start'`` puts it before.

        Returns:
            A ``Dataset`` with the columns ``'prompt'`` and ``'answer'``.
            Each answer is the needles' ``'arg2'`` values joined by ``'*'``,
            then ``'#'``, then the depths joined by ``'*'``.

        Raises:
            FileNotFoundError: If ``needle_file_name`` is not among the
                ``*.jsonl`` files in ``path``.
            ValueError: If ``language`` or ``position`` is unsupported.
        """
        data = {'prompt': [], 'answer': []}
        tokenizer = tiktoken.encoding_for_model(tokenizer_model)

        needle_file_path = None
        for file in Path(path).glob('*.jsonl'):
            if file.name == needle_file_name:
                needle_file_path = file
        if needle_file_path is None:
            # Fail with a clear message instead of an unbound-variable error.
            raise FileNotFoundError(
                f"Needle file '{needle_file_name}' not found in '{path}'.")

        predefined_needles_bak = get_unique_entries(needle_file_path,
                                                    len(depths),
                                                    language,
                                                    unique_arg1=True,
                                                    unique_arg2=True,
                                                    unique_combination=True)

        def _generate_context(tokens_context, depths, needles):
            """Insert every needle at its depth and decode back to text."""
            # Insertion points are computed on the untouched haystack ...
            insertion_points = [
                int(len(tokens_context) * (depth / 100)) for depth in depths
            ]

            cumulative_inserted_length = 0
            for i, needle in enumerate(needles):
                needle_tokens = _get_tokens_from_context(needle)
                # ... and shifted by what earlier needles already added.
                current_insertion_point = min(
                    insertion_points[i] + cumulative_inserted_length,
                    len(tokens_context))
                tokens_context = (tokens_context[:current_insertion_point] +
                                  needle_tokens +
                                  tokens_context[current_insertion_point:])
                cumulative_inserted_length += len(needle_tokens)

            return _decode_tokens(tokens_context)

        def _get_tokens_from_context(context):
            """Tokenize a string, or each string of a list separately."""
            if isinstance(context, list):
                return [tokenizer.encode(item) for item in context]
            return tokenizer.encode(context)

        def _decode_tokens(tokens):
            """Decode a token list back to text."""
            return tokenizer.decode(tokens)

        def _modify_retrieval_question(retrieval_question):
            """Insert the guiding hint before the answer-format instruction."""
            if language == 'Chinese':
                parts = retrieval_question.split('请按照')
                guide_retrieval_question = (parts[0] + '在回答之前,请思考文档中与此问题'
                                            '最相关的内容是什么。请按照' + parts[1])
                return guide_retrieval_question
            elif language == 'English':
                parts = retrieval_question.split('Please answer in the format')
                guide_retrieval_question = (
                    parts[0] + 'Before answering, please consider'
                    ' what in the document is most relevant to this question.'
                    ' Please answer in the format' + parts[1])
                return guide_retrieval_question
            else:
                raise ValueError(f"Language '{language}' is not supported.")

        def _generate_prompt(context, retrieval_question):
            """Wrap document and question in the language-specific template."""
            if guide:
                retrieval_question = _modify_retrieval_question(
                    retrieval_question)

            if language == 'Chinese':
                if position == 'End':
                    prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                              '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                              ',或重复你的回答\n请先仔细阅读下面的文档再依次回答'
                              f'最后提出的问题\n用户现在给你的文档是{context}\n\n'
                              f'现在请问:{retrieval_question}\n')
                elif position == 'Start':
                    prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                              '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                              ',或重复你的回答\n请先仔细阅读下面的文档再依次回答'
                              f'最后提出的问题\n现在请问:{retrieval_question}\n\n'
                              f'用户现在给你的文档是{context}\n')
                else:
                    raise ValueError(f'Unsupported position {position}. '
                                     'Position must be "End" or "Start".')
            elif language == 'English':
                if position == 'End':
                    prompt = (
                        'You are an intelligent AI assistant skilled in '
                        'answering user questions.\n'
                        'Please keep your answers concise and clear. Do not'
                        ' talk about irrelevant topics or repeat your '
                        'answers.\n'
                        f'The document given to you by the user is {context}'
                        f'\n\nNow, the questions are: {retrieval_question}\n')
                elif position == 'Start':
                    prompt = (
                        'You are an intelligent AI assistant skilled in '
                        'answering user questions.\n'
                        'Please keep your answers concise and clear. Do not'
                        ' talk about irrelevant topics or repeat your '
                        'answers.\n'
                        f'\nNow, the questions are: {retrieval_question}\n\n'
                        f'The document given to you by the user is {context}')
                else:
                    raise ValueError(f'Unsupported position {position}. '
                                     'Position must be "End" or "Start".')
            else:
                raise ValueError(f"Language '{language}' is not supported.")

            return prompt

        for file in Path(path).glob('*.jsonl'):
            if file.name not in file_list:
                continue

            with open(file, 'r', encoding='utf-8') as f:
                # Blank lines (e.g. a trailing newline) are not JSON records.
                lines_bak = [
                    json.loads(line.strip()) for line in f if line.strip()
                ]
            lines = lines_bak.copy()

            for counter in range(num_repeats_per_file):
                # The repeat index seeds both shuffles.
                random.seed(counter)
                random.shuffle(lines)
                predefined_needles = predefined_needles_bak.copy()
                random.seed(counter)
                random.shuffle(predefined_needles)

                needles = [
                    '\n' + item['needle'] + '\n' for item in predefined_needles
                ]
                keywords = [item['arg2'] for item in predefined_needles]

                if language == 'Chinese':
                    # Separators are the fullwidth question mark and the
                    # ideographic full stop, mirroring '?' and '.' in the
                    # English branch; an empty separator would make
                    # str.split raise ValueError.
                    questions = ''.join([
                        item['retrieval_question'].split('？')[0] + '？'
                        for item in predefined_needles
                    ])
                    answers_format = ''.join([
                        item['retrieval_question'].split("'")[1].split('。')[0]
                        for item in predefined_needles
                    ])
                    retrieval_question = questions + "请按照'" + \
                        answers_format + "'的格式回答。"
                elif language == 'English':
                    questions = ''.join([
                        item['retrieval_question'].split('?')[0] + '?'
                        for item in predefined_needles
                    ])
                    answers_format = ''.join([
                        item['retrieval_question'].split("'")[1].split('.')[0]
                        for item in predefined_needles
                    ])
                    retrieval_question = questions + \
                        "Please answer in the format of '" + \
                        answers_format + "'"
                else:
                    raise ValueError(
                        f"Language '{language}' is not supported.")

                # Haystack budget: total length minus buffer minus the
                # tokens the needles themselves will occupy.
                context_length = length - length_buffer
                target_length_per_record = context_length - sum(
                    len(tokens)
                    for tokens in _get_tokens_from_context(needles))
                target_length_per_record = max(target_length_per_record, 0)

                accumulated_tokens = []
                for line in lines:
                    tokens_current_line = _get_tokens_from_context(
                        line['text'])
                    accumulated_tokens.extend(tokens_current_line)
                    if len(accumulated_tokens) >= target_length_per_record:
                        break

                processed_text = _generate_context(
                    accumulated_tokens[:target_length_per_record], depths,
                    needles)
                processed_prompt = _generate_prompt(processed_text,
                                                    retrieval_question)

                data['prompt'].append(processed_prompt)
                data['answer'].append('*'.join(keywords) + '#' +
                                      '*'.join(map(str, depths)))

        dataset = Dataset.from_dict({
            'prompt': data['prompt'],
            'answer': data['answer'],
        })
        return dataset
class NeedleBenchParallelEvaluator(BaseEvaluator):
    """Scores parallel-needle predictions by keyword hits per depth."""

    def levenshtein_distance(self, s1, s2):
        """Return the Levenshtein edit distance between ``s1`` and ``s2``."""
        # Keep the longer sequence as s1 so rows have len(s2) + 1 cells.
        if len(s1) < len(s2):
            return self.levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    def score(self, predictions, gold):
        """Compute the per-depth and average keyword hit rate.

        Args:
            predictions: Model outputs, one string per sample.
            gold: Reference strings of the form
                ``'<kw>*<kw>*...#<depth>*<depth>*...'``.

        Returns:
            A dict with one ``'Depth<d>'`` entry per depth (percentage of
            samples whose prediction contains that depth's keyword), an
            ``'average_score'`` over the depths and an empty ``'details'``
            list. If the inputs are empty or differ in length, a dict with
            a single ``'error'`` entry is returned instead.
        """
        if len(predictions) != len(gold):
            return {'error': 'predictions and gold have different lengths'}
        if not gold:
            # Nothing to score; gold[0] below would raise IndexError.
            return {'error': 'predictions and gold are empty'}

        details = []
        # The depth list is read from the first reference and applied to
        # every sample.
        depths = [int(i) for i in gold[0].split('#')[1].split('*')]
        scores_by_depth = {depth: 0 for depth in depths}
        # Each hit contributes an equal share of 100 percent.
        score_per_hit = 100 / len(predictions)

        for prediction, reference in zip(predictions, gold):
            keywords = reference.split('#')[0].split('*')
            for keyword, depth in zip(keywords, depths):
                if keyword in prediction:
                    scores_by_depth[depth] += score_per_hit

        average_score = sum(scores_by_depth.values()) / len(scores_by_depth)

        flattened_scores = {
            'Depth' + str(depth): score
            for depth, score in scores_by_depth.items()
        }

        result = {
            **flattened_scores, 'details': details,
            'average_score': average_score
        }
        return result