OpenCompass/opencompass/datasets/needlebench/origin.py
Mo Li f2af49337d
[Feature] Add ATC Choice Version (#1019)
* Squashed commit of the following:

commit c48ad194c3976dc63d1b60d8c8ab2d5ff9e1cbfe
Author: DseidLi <2568818204@qq.com>
Date:   Tue Apr 2 16:57:43 2024 +0800

    add atc_choice

commit 3ac6efea29619573e6fac8fa3cce464853dcead0
Merge: 2d4e559 8e3a9c3
Author: DseidLi <2568818204@qq.com>
Date:   Tue Apr 2 16:41:38 2024 +0800

    Merge branch 'atc_choice' into atc_add_choice

commit 8e3a9c396a3e5546d3faf584183f6fd60b974d5e
Merge: 150a036 0a6a03f
Author: DseidLi <2568818204@qq.com>
Date:   Tue Mar 26 04:47:07 2024 +0800

    Merge branch 'main' into atc_choice

    Conflicts:
    	configs/summarizers/needlebench.py
    	opencompass/datasets/needlebench/multi.py
    	opencompass/datasets/needlebench/origin.py
    	opencompass/datasets/needlebench/parallel.py

commit 150a036d6d990f26a57c974d1af83d88c31a0f9d
Merge: 8d6ac9a 940dd18
Author: DseidLi <2568818204@qq.com>
Date:   Wed Mar 20 03:49:08 2024 +0800

    Merge branch 'needlebench_fix' into atc_choice

commit 8d6ac9a1a43b1c9d0f0ea27e7d58968a203ea898
Author: DseidLi <2568818204@qq.com>
Date:   Wed Mar 20 03:41:49 2024 +0800

    optimize needlebench code

commit 940dd18a4270f24bc69edd2a780182c68918e1a9
Author: DseidLi <2568818204@qq.com>
Date:   Wed Mar 20 03:39:46 2024 +0800

    fix vllm

commit d8be6877bc41051f3edcc0421c462c834c0f1c9a
Merge: ecad78a 2527fda
Author: DseidLi <2568818204@qq.com>
Date:   Tue Mar 19 21:07:08 2024 +0800

    Merge remote-tracking branch 'origin/add_1M_dataset' into atc_choice

commit 2527fda8a5
Author: DseidLi <2568818204@qq.com>
Date:   Tue Mar 19 16:03:40 2024 +0800

    add model configs

commit 75425acdf8
Author: DseidLi <2568818204@qq.com>
Date:   Tue Mar 19 16:02:15 2024 +0800

    add prompt postion args

commit 367ba1ba61
Author: DseidLi <2568818204@qq.com>
Date:   Wed Feb 28 21:40:00 2024 +0800

    add Needlebench-1000K configs

commit ecad78af14c4bb00fe325779114b384c57ab30bf
Author: DseidLi <2568818204@qq.com>
Date:   Thu Mar 14 22:08:32 2024 +0800

    fix atc

commit 08772c0787b18872abadc9ffec3223941a5ee0c2
Merge: 9f3f8cf caf1cf8
Author: DseidLi <2568818204@qq.com>
Date:   Thu Mar 14 22:07:28 2024 +0800

    Merge branch 'main' into atc_choice

    Conflicts:
    	configs/datasets/needlebench/readme.md
    	configs/datasets/needlebench/readme_zh-CN.md
    	configs/summarizers/needlebench.py
    	opencompass/datasets/needlebench/atc.py
    	opencompass/summarizers/needlebench.py

commit 9f3f8cfb4452722734d334114ac1d14110e57406
Author: DseidLi <2568818204@qq.com>
Date:   Thu Mar 14 21:35:53 2024 +0800

    add atc-choice test

commit 52be7c1202376b4e09821188b826f1a805328129
Author: DseidLi <2568818204@qq.com>
Date:   Wed Mar 6 02:54:15 2024 +0800

    update needlebench randomseed and add vllm qwen14b

commit fc1effce596ae2e5ece4933e8cd34aef8e64a6f9
Merge: 4e747ed caf1cf8
Author: DseidLi <2568818204@qq.com>
Date:   Wed Mar 6 02:51:14 2024 +0800

    Merge branch 'main' into add_model_configs

commit 31834f9b23af3354ac3581ec86d693d0f05cdd1c
Merge: 7dabc82 120bf8b
Author: DseidLi <2568818204@qq.com>
Date:   Sun Mar 3 23:29:42 2024 +0800

    Merge branch 'main' of https://github.com/open-compass/opencompass into atc_choice

commit 4e747ed1988ddbcfcc7fff334601259ade72d363
Author: DseidLi <2568818204@qq.com>
Date:   Sun Mar 3 22:15:25 2024 +0800

    add internlm2-lmdeploy model and gemma configs

commit 7dabc828123d711c8cf834d6aab4137bb55e85ed
Author: DseidLi <2568818204@qq.com>
Date:   Sat Mar 2 17:26:15 2024 +0800

    add atc choice version -ZH

commit 996f8ae43d
Author: DseidLi <2568818204@qq.com>
Date:   Wed Feb 28 16:58:56 2024 +0800

    update readme for needlebench

commit f7266e873c
Author: DseidLi <2568818204@qq.com>
Date:   Wed Feb 28 16:44:53 2024 +0800

    move readme.md

commit 1c7375681d
Author: DseidLi <2568818204@qq.com>
Date:   Wed Feb 28 16:38:31 2024 +0800

    fix linting error

commit b6524f3ebf
Author: DseidLi <2568818204@qq.com>
Date:   Wed Feb 28 16:33:51 2024 +0800

    lint summarizer

commit c0d1190e39
Author: DseidLi <2568818204@qq.com>
Date:   Wed Feb 28 16:29:03 2024 +0800

    add needlebench intro, fix summarizer

commit 0965baf785
Author: DseidLi <2568818204@qq.com>
Date:   Mon Feb 26 13:31:26 2024 +0800

    fix bug in needlebench summarizer

commit 5d32b31eb8
Author: DseidLi <2568818204@qq.com>
Date:   Sat Feb 24 03:19:08 2024 +0800

    update act prompt

commit af82a7f085
Merge: 32bf9fe 53fe788
Author: DseidLi <2568818204@qq.com>
Date:   Fri Feb 23 17:50:32 2024 +0800

    Merge remote-tracking branch 'upstream/main' into needlebench

commit 32bf9fe802
Author: DseidLi <2568818204@qq.com>
Date:   Fri Feb 23 17:31:32 2024 +0800

    simplify needlebench 32k, 128k, 200k for eval

commit a7cb025e05
Author: DseidLi <2568818204@qq.com>
Date:   Fri Feb 23 14:48:58 2024 +0800

    add needlebench

* fix summarizer

* remove repeated code

* remove chinese comments
2024-04-07 15:46:20 +08:00


import json
import os
import random
import re
from pathlib import Path

import tiktoken
from datasets import Dataset

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS


def get_random_line_by_language(counter, file_path, language):
    """Draw one needle entry of the given language from a JSONL file.

    ``counter`` seeds the RNG so repeated runs draw the same line.
    Returns ``None`` when the file holds no entry for that language.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        lines = [
            json.loads(line.strip()) for line in file
            if json.loads(line.strip())['language'] == language
        ]

    if lines:
        random.seed(counter)
        random_line = random.choice(lines)
        return {
            'needle': random_line['needle'],
            'retrieval_question': random_line['retrieval_question'],
            'keyword': random_line['arg2']
        }
    else:
        return None
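

# A minimal sketch of the needle JSONL schema implied by the parsing above:
# one JSON object per line, with 'arg2' surfaced to callers as the keyword.
# This helper and its example values are illustrative assumptions, not data
# shipped with the repository.
def _example_needle_line():  # hypothetical helper, never called by the module
    return json.dumps({
        'language': 'English',
        'needle': 'The hidden fact is that pandas mostly eat bamboo.',
        'retrieval_question': 'What do pandas mostly eat?',
        'arg2': 'bamboo',
    })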


@LOAD_DATASET.register_module()
class NeedleBenchOriginDataset(BaseDataset):

    @staticmethod
    def load(
        path: str,
        length: int,
        depth: int,
        tokenizer_model: str,
        file_list: list[str],
        num_repeats_per_file: int,
        length_buffer: int,
        guide: bool,
        language: str,
        needle_file_name: str,
        position: str = 'End',
    ):
        data = {'prompt': [], 'answer': []}
        tokenizer = tiktoken.encoding_for_model(tokenizer_model)

        def _generate_context(tokens_context, depth_percent, needle):
            """Insert the needle at ``depth_percent`` of the token stream."""
            tokens_needle = _get_tokens_from_context(needle)
            insertion_point = int(len(tokens_context) * (depth_percent / 100))
            tokens_context = (tokens_context[:insertion_point] +
                              tokens_needle +
                              tokens_context[insertion_point:])
            new_context = _decode_tokens(tokens_context)
            return new_context

        def _get_tokens_from_context(context):
            return tokenizer.encode(context)

        def _decode_tokens(tokens):
            return tokenizer.decode(tokens)

        def _modify_retrieval_question(retrieval_question):
            if language == 'Chinese':
                parts = retrieval_question.split('请按照')
                guide_retrieval_question = (parts[0] + '在回答之前,请思考文档中与此问题'
                                            '最相关的内容是什么。请按照' + parts[1])
                return guide_retrieval_question
            elif language == 'English':
                parts = retrieval_question.split('Please answer in the format')
                guide_retrieval_question = (
                    parts[0] + 'Before answering, please consider'
                    ' what in the document is most relevant to this question.'
                    ' Please answer in the format' + parts[1])
                return guide_retrieval_question
            else:
                raise ValueError(f"Language '{language}' is not supported.")

        def _generate_prompt(context, retrieval_question):
            if guide:
                retrieval_question = _modify_retrieval_question(
                    retrieval_question)
            # The Chinese prompts mirror the English ones below.
            if language == 'Chinese':
                if position == 'End':
                    prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                              '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                              ',或重复你的回答\n'
                              f'用户现在给你的文档是{context}\n\n'
                              f'现在请问:{retrieval_question}')
                elif position == 'Start':
                    prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                              '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                              ',或重复你的回答\n'
                              f'现在请问:{retrieval_question}'
                              f'用户现在给你的文档是{context}\n\n')
                else:
                    raise ValueError('Unsupported position. '
                                     'Position must be "End" or "Start".')
            elif language == 'English':
                if position == 'End':
                    prompt = ('You are an intelligent AI assistant skilled in '
                              'answering user questions.\n'
                              'Please keep your answers concise and clear. Do '
                              'not talk about irrelevant topics or repeat '
                              'your answers.\nThe document '
                              f'given to you by the user is {context}\n\n'
                              f'Now, the question is: {retrieval_question}')
                elif position == 'Start':
                    prompt = ('You are an intelligent AI assistant skilled in '
                              'answering user questions.\n'
                              'Please keep your answers concise and clear. Do '
                              'not talk about irrelevant topics or repeat '
                              'your answers.\n'
                              f'Now, the question is: {retrieval_question}'
                              'The document given to you by the user'
                              f' is {context}\n\n')
                else:
                    raise ValueError(f'Unsupported position {position}. '
                                     'Position must be "End" or "Start".')
            else:
                raise ValueError(f"Language '{language}' is not supported.")
            return prompt

        files = Path(path).glob('*.jsonl')
        for file in files:
            if file.name not in file_list:
                continue

            with open(file, 'r', encoding='utf-8') as f:
                lines_bak = [json.loads(line.strip()) for line in f]
            lines = lines_bak.copy()
            for counter in range(num_repeats_per_file):
                # Reshuffle the haystack and redraw the needle on each
                # repeat, seeded by the repeat index for reproducibility.
                random.seed(counter)
                random.shuffle(lines)
                needle_file_path = os.path.join(path, needle_file_name)
                random_needle = get_random_line_by_language(
                    counter, needle_file_path, language)
                needle = '\n' + random_needle['needle'] + '\n'
                retrieval_question = random_needle['retrieval_question']
                keyword = random_needle['keyword']

                context_length = length - length_buffer
                target_length_per_record = context_length - len(
                    _get_tokens_from_context(needle))
                target_length_per_record = max(target_length_per_record, 0)

                accumulated_tokens = []
                for line in lines:
                    tokens_current_line = _get_tokens_from_context(
                        line['text'])
                    accumulated_tokens.extend(tokens_current_line)
                    if len(accumulated_tokens) >= target_length_per_record:
                        break

                processed_text = _generate_context(
                    accumulated_tokens[:target_length_per_record], depth,
                    needle)

                processed_prompt = _generate_prompt(processed_text,
                                                    retrieval_question)

                data['prompt'].append(processed_prompt)
                data['answer'].append(needle + '*' + keyword)

        dataset = Dataset.from_dict({
            'prompt': data['prompt'],
            'answer': data['answer'],
        })
        return dataset
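

# A hypothetical invocation sketch, not part of the original module: it shows
# how the arguments fit together. The path, file names, and sizes below are
# assumptions for illustration only.
def _example_load():  # hypothetical helper, never called by the module
    return NeedleBenchOriginDataset.load(
        path='data/needlebench',  # assumed data directory
        length=4000,  # target context size in tokens
        depth=50,  # insert the needle halfway into the context
        tokenizer_model='gpt-4',  # tiktoken model name
        file_list=['example_haystack.jsonl'],  # assumed haystack file
        num_repeats_per_file=10,
        length_buffer=200,  # tokens reserved beyond the haystack
        guide=True,
        language='English',
        needle_file_name='needles.jsonl',  # assumed needle file
        position='End',
    )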


class NeedleBenchOriginEvaluator(BaseEvaluator):

    def __init__(self, use_trim=False):
        self.use_trim = use_trim

    @staticmethod
    def _trim_prediction(prediction, reference):
        """Trims the prediction string based on the length of the reference
        string.

        Args:
            prediction (str): The prediction string.
            reference (str): The reference string.

        Returns:
            str: The trimmed prediction string.
        """
        l08 = int(0.8 * len(reference))
        l12 = int(1.2 * len(reference))
        trimmed_prediction = prediction[:l12]

        if len(trimmed_prediction) > l08 and \
                reference[-1] in trimmed_prediction[l08:]:
            end_pos = l08 + trimmed_prediction[l08:].index(reference[-1]) + 1
            trimmed_prediction = trimmed_prediction[:end_pos]

        return trimmed_prediction

    def levenshtein_distance(self, s1, s2):
        # Classic dynamic-programming edit distance, e.g.
        # levenshtein_distance('kitten', 'sitting') == 3.
        if len(s1) < len(s2):
            return self.levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    def score(self, predictions, gold):
        if len(predictions) != len(gold):
            return {'error': 'predictions and gold have different lengths'}

        total_score = 0
        details = []
        for prediction, reference in zip(predictions, gold):
            # Gold entries are stored as '<needle>*<keyword>'.
            keyword = reference.split('*')[1]
            reference = reference.split('*')[0]
            raw_prediction = prediction
            prediction = re.sub(r'\s+', '', prediction)
            reference = re.sub(r'\s+', '', reference)

            if self.use_trim:
                prediction = NeedleBenchOriginEvaluator._trim_prediction(
                    prediction, reference)

            edit_distance = self.levenshtein_distance(prediction, reference)
            max_len = max(len(prediction), len(reference))
            score = 100 * (1 -
                           edit_distance / max_len) if max_len != 0 else 100

            # A prediction containing the keyword scores full marks;
            # otherwise the edit-distance score is scaled down sharply.
            if keyword in raw_prediction:
                print(f'{keyword} is in {prediction}')
                score = 100
            else:
                print(f'{keyword} is not in {prediction}')
                score = 0.2 * score

            detail = {
                'pred': prediction,
                'answer': reference,
                'edit_distance': edit_distance,
                'score': score
            }
            total_score += score
            details.append(detail)

        average_score = total_score / len(predictions) if predictions else 0
        result = {'score': average_score, 'details': details}
        return result
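

# A minimal usage sketch, not part of the original module: scoring one
# prediction against a gold string in the '<needle>*<keyword>' format that
# the dataset emits above. The strings are illustrative assumptions.
def _example_score():  # hypothetical helper, never called by the module
    evaluator = NeedleBenchOriginEvaluator(use_trim=True)
    result = evaluator.score(
        predictions=['The hidden fact is that pandas mostly eat bamboo.'],
        gold=['\nThe hidden fact is that pandas mostly eat bamboo.\n*bamboo'],
    )
    return result['score']  # 100.0 here, since the keyword is present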


# Identity post-processors, registered so configs can refer to them by name.
@TEXT_POSTPROCESSORS.register_module('needlebench')
def needlebench_postprocess(text: str) -> str:
    return text


@TEXT_POSTPROCESSORS.register_module('needlebench_dataset')
def needlebench_dataset_postprocess(text: str) -> str:
    return text