OpenCompass/opencompass/datasets/needlebench/parallel.py
Mo Li f2af49337d
[Feature] Add ATC Choice Version (#1019)
* Squashed commit of the following:

commit c48ad194c3976dc63d1b60d8c8ab2d5ff9e1cbfe
Author: DseidLi <2568818204@qq.com>
Date:   Tue Apr 2 16:57:43 2024 +0800

    add atc_choice

commit 3ac6efea29619573e6fac8fa3cce464853dcead0
Merge: 2d4e559 8e3a9c3
Author: DseidLi <2568818204@qq.com>
Date:   Tue Apr 2 16:41:38 2024 +0800

    Merge branch 'atc_choice' into atc_add_choice

commit 8e3a9c396a3e5546d3faf584183f6fd60b974d5e
Merge: 150a036 0a6a03f
Author: DseidLi <2568818204@qq.com>
Date:   Tue Mar 26 04:47:07 2024 +0800

    Merge branch 'main' into atc_choice

    Conflicts:
    	configs/summarizers/needlebench.py
    	opencompass/datasets/needlebench/multi.py
    	opencompass/datasets/needlebench/origin.py
    	opencompass/datasets/needlebench/parallel.py

commit 150a036d6d990f26a57c974d1af83d88c31a0f9d
Merge: 8d6ac9a 940dd18
Author: DseidLi <2568818204@qq.com>
Date:   Wed Mar 20 03:49:08 2024 +0800

    Merge branch 'needlebench_fix' into atc_choice

commit 8d6ac9a1a43b1c9d0f0ea27e7d58968a203ea898
Author: DseidLi <2568818204@qq.com>
Date:   Wed Mar 20 03:41:49 2024 +0800

    optimize needlebench code

commit 940dd18a4270f24bc69edd2a780182c68918e1a9
Author: DseidLi <2568818204@qq.com>
Date:   Wed Mar 20 03:39:46 2024 +0800

    fix vllm

commit d8be6877bc41051f3edcc0421c462c834c0f1c9a
Merge: ecad78a 2527fda
Author: DseidLi <2568818204@qq.com>
Date:   Tue Mar 19 21:07:08 2024 +0800

    Merge remote-tracking branch 'origin/add_1M_dataset' into atc_choice

commit 2527fda8a5
Author: DseidLi <2568818204@qq.com>
Date:   Tue Mar 19 16:03:40 2024 +0800

    add model configs

commit 75425acdf8
Author: DseidLi <2568818204@qq.com>
Date:   Tue Mar 19 16:02:15 2024 +0800

    add prompt position args

commit 367ba1ba61
Author: DseidLi <2568818204@qq.com>
Date:   Wed Feb 28 21:40:00 2024 +0800

    add Needlebench-1000K configs

commit ecad78af14c4bb00fe325779114b384c57ab30bf
Author: DseidLi <2568818204@qq.com>
Date:   Thu Mar 14 22:08:32 2024 +0800

    fix atc

commit 08772c0787b18872abadc9ffec3223941a5ee0c2
Merge: 9f3f8cf caf1cf8
Author: DseidLi <2568818204@qq.com>
Date:   Thu Mar 14 22:07:28 2024 +0800

    Merge branch 'main' into atc_choice

    Conflicts:
    	configs/datasets/needlebench/readme.md
    	configs/datasets/needlebench/readme_zh-CN.md
    	configs/summarizers/needlebench.py
    	opencompass/datasets/needlebench/atc.py
    	opencompass/summarizers/needlebench.py

commit 9f3f8cfb4452722734d334114ac1d14110e57406
Author: DseidLi <2568818204@qq.com>
Date:   Thu Mar 14 21:35:53 2024 +0800

    add atc-choice test

commit 52be7c1202376b4e09821188b826f1a805328129
Author: DseidLi <2568818204@qq.com>
Date:   Wed Mar 6 02:54:15 2024 +0800

    update needlebench randomseed and add vllm qwen14b

commit fc1effce596ae2e5ece4933e8cd34aef8e64a6f9
Merge: 4e747ed caf1cf8
Author: DseidLi <2568818204@qq.com>
Date:   Wed Mar 6 02:51:14 2024 +0800

    Merge branch 'main' into add_model_configs

commit 31834f9b23af3354ac3581ec86d693d0f05cdd1c
Merge: 7dabc82 120bf8b
Author: DseidLi <2568818204@qq.com>
Date:   Sun Mar 3 23:29:42 2024 +0800

    Merge branch 'main' of https://github.com/open-compass/opencompass into atc_choice

commit 4e747ed1988ddbcfcc7fff334601259ade72d363
Author: DseidLi <2568818204@qq.com>
Date:   Sun Mar 3 22:15:25 2024 +0800

    add internlm2-lmdeploy model and gemma configs

commit 7dabc828123d711c8cf834d6aab4137bb55e85ed
Author: DseidLi <2568818204@qq.com>
Date:   Sat Mar 2 17:26:15 2024 +0800

    add atc choice version -ZH

commit 996f8ae43d
Author: DseidLi <2568818204@qq.com>
Date:   Wed Feb 28 16:58:56 2024 +0800

    update readme for needlebench

commit f7266e873c
Author: DseidLi <2568818204@qq.com>
Date:   Wed Feb 28 16:44:53 2024 +0800

    move readme.md

commit 1c7375681d
Author: DseidLi <2568818204@qq.com>
Date:   Wed Feb 28 16:38:31 2024 +0800

    fix linting error

commit b6524f3ebf
Author: DseidLi <2568818204@qq.com>
Date:   Wed Feb 28 16:33:51 2024 +0800

    lint summarizer

commit c0d1190e39
Author: DseidLi <2568818204@qq.com>
Date:   Wed Feb 28 16:29:03 2024 +0800

    add needlebench intro, fix summarizer

commit 0965baf785
Author: DseidLi <2568818204@qq.com>
Date:   Mon Feb 26 13:31:26 2024 +0800

    fix bug in needlebench summarizer

commit 5d32b31eb8
Author: DseidLi <2568818204@qq.com>
Date:   Sat Feb 24 03:19:08 2024 +0800

    update act prompt

commit af82a7f085
Merge: 32bf9fe 53fe788
Author: DseidLi <2568818204@qq.com>
Date:   Fri Feb 23 17:50:32 2024 +0800

    Merge remote-tracking branch 'upstream/main' into needlebench

commit 32bf9fe802
Author: DseidLi <2568818204@qq.com>
Date:   Fri Feb 23 17:31:32 2024 +0800

    simplify needlebench 32k, 128k, 200k for eval

commit a7cb025e05
Author: DseidLi <2568818204@qq.com>
Date:   Fri Feb 23 14:48:58 2024 +0800

    add needlebench

* fix summarizer

* remove repeated code

* remove chinese comments
2024-04-07 15:46:20 +08:00

312 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import random
from pathlib import Path
import tiktoken
from datasets import Dataset
from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET
def get_unique_entries(file_path,
                       n,
                       language,
                       unique_arg1=False,
                       unique_arg2=False,
                       unique_combination=False):
    """Sample up to ``n`` entries of the given language from a JSONL file.

    All lines are read and shuffled (via the global ``random`` state), then
    scanned in order. Depending on the flags, an entry is skipped when its
    ``arg1``, ``arg2``, or (``arg1``, ``arg2``) pair was already selected,
    so the result is unique along the requested dimensions. Malformed JSON
    lines are silently ignored.
    """
    picked_arg1 = set()
    picked_arg2 = set()
    picked_pairs = set()
    selected = []

    with open(file_path, 'r', encoding='utf-8') as file:
        all_lines = file.readlines()

    random.shuffle(all_lines)

    for raw_line in all_lines:
        try:
            entry = json.loads(raw_line.strip())
        except json.JSONDecodeError:
            # Best-effort parsing: skip broken lines rather than abort.
            continue
        if entry.get('language') != language:
            continue

        # When a uniqueness flag is off, its key collapses to '' and the
        # corresponding check below is disabled entirely.
        key1 = entry.get('arg1', '') if unique_arg1 else ''
        key2 = entry.get('arg2', '') if unique_arg2 else ''
        pair = (key1, key2) if unique_combination else ''

        arg1_ok = not unique_arg1 or key1 not in picked_arg1
        arg2_ok = not unique_arg2 or key2 not in picked_arg2
        pair_ok = not unique_combination or pair not in picked_pairs
        if arg1_ok and arg2_ok and pair_ok:
            picked_arg1.add(key1)
            picked_arg2.add(key2)
            picked_pairs.add(pair)
            selected.append(entry)

        if len(selected) == n:
            break

    return selected
@LOAD_DATASET.register_module()
class NeedleBenchParallelDataset(BaseDataset):
    """NeedleBench dataset with several needles inserted in parallel.

    Each generated sample hides one needle per requested depth inside a
    long haystack built from the corpus files, then asks a single combined
    question covering all needles at once.
    """

    @staticmethod
    def load(
        path: str,
        needle_file_name: str,
        length: int,
        depths: list[int],
        tokenizer_model: str,
        file_list: list[str],
        num_repeats_per_file: int,
        length_buffer: int,
        guide: bool,
        language: str,
        position: str = 'End',
    ):
        """Build the parallel-needle dataset.

        Args:
            path: Directory containing the ``*.jsonl`` haystack/needle files.
            needle_file_name: Name of the jsonl file holding the needles.
            length: Target total context length in tokens.
            depths: Relative insertion depths (0-100), one needle each.
            tokenizer_model: tiktoken model name used for token counting.
            file_list: Haystack file names to draw context text from.
            num_repeats_per_file: Samples generated per haystack file.
            length_buffer: Tokens reserved for prompt/answer overhead.
            guide: Whether to inject a think-first hint into the question.
            language: ``'Chinese'`` or ``'English'``.
            position: Question placement relative to the document,
                ``'End'`` (default) or ``'Start'``.

        Returns:
            A ``datasets.Dataset`` with ``prompt`` and ``answer`` columns;
            each answer encodes keywords and depths as
            ``'kw1*kw2*...#d1*d2*...'``.

        Raises:
            FileNotFoundError: If ``needle_file_name`` is not in ``path``.
            ValueError: On unsupported ``language`` or ``position``.
        """
        data = {'prompt': [], 'answer': []}
        tokenizer = tiktoken.encoding_for_model(tokenizer_model)

        # Locate the needle file; fail loudly instead of hitting an
        # unbound-variable NameError below when it is missing.
        needle_file_path = None
        for file in Path(path).glob('*.jsonl'):
            if file.name == needle_file_name:
                needle_file_path = file
        if needle_file_path is None:
            raise FileNotFoundError(
                f'Needle file {needle_file_name!r} not found in {path}')

        # One needle per depth; unique arg1/arg2 so answer keywords and
        # questions do not collide between needles.
        predefined_needles_bak = get_unique_entries(needle_file_path,
                                                    len(depths),
                                                    language,
                                                    unique_arg1=True,
                                                    unique_arg2=True,
                                                    unique_combination=True)

        def _generate_context(tokens_context, depths, needles):
            """Insert each needle at its depth (percent of context length)."""
            insertion_points = [
                int(len(tokens_context) * (depth / 100)) for depth in depths
            ]

            cumulative_inserted_length = 0
            for i, needle in enumerate(needles):
                needle_tokens = _get_tokens_from_context(needle)
                # Shift by tokens already inserted so earlier needles do
                # not displace the later insertion points.
                current_insertion_point = min(
                    insertion_points[i] + cumulative_inserted_length,
                    len(tokens_context))

                tokens_context = tokens_context[:current_insertion_point] + \
                    needle_tokens + tokens_context[current_insertion_point:]
                cumulative_inserted_length += len(needle_tokens)

            return _decode_tokens(tokens_context)

        def _get_tokens_from_context(context):
            """Tokenize a string, or each string of a list independently."""
            if isinstance(context, list):
                return [tokenizer.encode(item) for item in context]
            return tokenizer.encode(context)

        def _decode_tokens(tokens):
            return tokenizer.decode(tokens)

        def _modify_retrieval_question(retrieval_question):
            """Inject a think-first hint before the answer-format clause."""
            if language == 'Chinese':
                parts = retrieval_question.split('请按照')
                return (parts[0] + '在回答之前,请思考文档中与此问题'
                        '最相关的内容是什么。请按照' + parts[1])
            elif language == 'English':
                parts = retrieval_question.split('Please answer in the format')
                return (parts[0] + 'Before answering, please consider'
                        ' what in the document is most relevant to this question.'
                        ' Please answer in the format' + parts[1])
            else:
                raise ValueError(f"Language '{language}' is not supported.")

        def _generate_prompt(context, retrieval_question):
            """Assemble the final prompt with the question at `position`."""
            if guide:
                retrieval_question = _modify_retrieval_question(
                    retrieval_question)

            if language == 'Chinese':
                if position == 'End':
                    prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                              '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                              ',或重复你的回答\n请先仔细阅读下面的文档再依次回答'
                              f'最后提出的问题\n用户现在给你的文档是{context}\n\n'
                              f'现在请问:{retrieval_question}\n')
                elif position == 'Start':
                    prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                              '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                              ',或重复你的回答\n请先仔细阅读下面的文档再依次回答'
                              f'最后提出的问题\n现在请问:{retrieval_question}\n\n'
                              f'用户现在给你的文档是{context}\n')
                else:
                    raise ValueError(f'Unsupported position {position}. '
                                     'Position must be "End" or "Start".')
            elif language == 'English':
                if position == 'End':
                    prompt = (
                        'You are an intelligent AI assistant skilled in '
                        'answering user questions.\n'
                        'Please keep your answers concise and clear. Do not'
                        ' talk about irrelevant topics or repeat your '
                        'answers.\n'
                        f'The document given to you by the user is {context}'
                        f'\n\nNow, the questions are: {retrieval_question}\n')
                elif position == 'Start':
                    prompt = (
                        'You are an intelligent AI assistant skilled in '
                        'answering user questions.\n'
                        'Please keep your answers concise and clear. Do not'
                        ' talk about irrelevant topics or repeat your '
                        'answers.\n'
                        f'\nNow, the questions are: {retrieval_question}\n\n'
                        f'The document given to you by the user is {context}')
                else:
                    raise ValueError(f'Unsupported position {position}. '
                                     'Position must be "End" or "Start".')
            else:
                raise ValueError(f"Language '{language}' is not supported.")

            return prompt

        files = Path(path).glob('*.jsonl')
        for file in files:
            if file.name not in file_list:
                continue

            with open(file, 'r', encoding='utf-8') as f:
                lines_bak = [json.loads(line.strip()) for line in f]
            lines = lines_bak.copy()
            for counter in range(num_repeats_per_file):
                # Re-seed before each shuffle so every repeat index yields
                # a reproducible ordering of haystack and needles alike.
                random.seed(counter)
                random.shuffle(lines)

                predefined_needles = predefined_needles_bak.copy()
                random.seed(counter)
                random.shuffle(predefined_needles)

                needles = [
                    '\n' + item['needle'] + '\n'
                    for item in predefined_needles
                ]
                keywords = [item['arg2'] for item in predefined_needles]
                if language == 'Chinese':
                    # Questions end with the full-width question mark; the
                    # answer format template ends at the full-width period.
                    questions = ''.join([
                        item['retrieval_question'].split('？')[0] + '？'
                        for item in predefined_needles
                    ])

                    answers_format = ''.join([
                        item['retrieval_question'].split("'")[1].split('。')[0]
                        for item in predefined_needles
                    ])
                    retrieval_question = questions + "请按照'" + \
                        answers_format + "'的格式回答。"
                elif language == 'English':
                    questions = ''.join([
                        item['retrieval_question'].split('?')[0] + '?'
                        for item in predefined_needles
                    ])

                    answers_format = ''.join([
                        item['retrieval_question'].split("'")[1].split('.')[0]
                        for item in predefined_needles
                    ])
                    retrieval_question = questions + \
                        "Please answer in the format of '" + \
                        answers_format + "'"
                else:
                    raise ValueError(
                        f"Language '{language}' is not supported.")

                # Budget: total length minus buffer minus the needles
                # themselves, clamped at zero for very short contexts.
                context_length = length - length_buffer
                target_length_per_record = context_length - \
                    sum(len(tokens) for tokens
                        in _get_tokens_from_context(needles))
                target_length_per_record = max(target_length_per_record, 0)

                accumulated_tokens = []
                for line in lines:
                    tokens_current_line = _get_tokens_from_context(
                        line['text'])
                    accumulated_tokens.extend(tokens_current_line)

                    if len(accumulated_tokens) >= target_length_per_record:
                        break

                processed_text = _generate_context(
                    accumulated_tokens[:target_length_per_record], depths,
                    needles)

                processed_prompt = _generate_prompt(processed_text,
                                                    retrieval_question)

                data['prompt'].append(processed_prompt)
                data['answer'].append('*'.join(keywords) + '#' +
                                      '*'.join(map(str, depths)))

        dataset = Dataset.from_dict({
            'prompt': data['prompt'],
            'answer': data['answer'],
        })
        return dataset
class NeedleBenchParallelEvaluator(BaseEvaluator):
    """Scores parallel-needle predictions by keyword recall per depth."""

    def levenshtein_distance(self, s1, s2):
        """Return the edit distance between two strings.

        Utility method; the current ``score`` implementation relies on
        exact substring matching instead of fuzzy matching.
        """
        if len(s1) < len(s2):
            return self.levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    def score(self, predictions, gold):
        """Compute per-depth keyword-hit scores.

        Each gold string is ``'kw1*kw2*...#d1*d2*...'``. A keyword found
        verbatim in its prediction credits ``100 / len(predictions)``
        points to the corresponding depth, so a depth whose keyword is
        found in every prediction scores 100.

        Args:
            predictions: Model outputs, one per sample.
            gold: Reference strings, one per sample; all samples are
                assumed to share the same depth list.

        Returns:
            Dict with one ``Depth<d>`` entry per depth, an (empty)
            ``details`` list, and the ``average_score`` across depths —
            or an ``error`` entry on length mismatch.
        """
        if len(predictions) != len(gold):
            return {'error': 'predictions and gold have different lengths'}

        details = []
        # Depths are identical for every sample, so read them once.
        depths = [int(i) for i in gold[0].split('#')[1].split('*')]
        scores_by_depth = {depth: 0 for depth in depths}

        for prediction, reference in zip(predictions, gold):
            keywords = reference.split('#')[0].split('*')
            for keyword, depth in zip(keywords, depths):
                if keyword in prediction:
                    scores_by_depth[depth] += 100 / len(predictions)

        average_score = sum(scores_by_depth.values()) / len(scores_by_depth)

        flattened_scores = {
            'Depth' + str(depth): score
            for depth, score in scores_by_depth.items()
        }

        result = {
            **flattened_scores,
            'details': details,
            'average_score': average_score,
        }
        return result