mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

* Squashed commit of the following: commit c48ad194c3976dc63d1b60d8c8ab2d5ff9e1cbfe Author: DseidLi <2568818204@qq.com> Date: Tue Apr 2 16:57:43 2024 +0800 add atc_choice commit 3ac6efea29619573e6fac8fa3cce464853dcead0 Merge:2d4e559
8e3a9c3 Author: DseidLi <2568818204@qq.com> Date: Tue Apr 2 16:41:38 2024 +0800 Merge branch 'atc_choice' into atc_add_choice commit 8e3a9c396a3e5546d3faf584183f6fd60b974d5e Merge: 150a0360a6a03f
Author: DseidLi <2568818204@qq.com> Date: Tue Mar 26 04:47:07 2024 +0800 Merge branch 'main' into atc_choice Conflicts: configs/summarizers/needlebench.py opencompass/datasets/needlebench/multi.py opencompass/datasets/needlebench/origin.py opencompass/datasets/needlebench/parallel.py commit 150a036d6d990f26a57c974d1af83d88c31a0f9d Merge: 8d6ac9a 940dd18 Author: DseidLi <2568818204@qq.com> Date: Wed Mar 20 03:49:08 2024 +0800 Merge branch 'needlebench_fix' into atc_choice commit 8d6ac9a1a43b1c9d0f0ea27e7d58968a203ea898 Author: DseidLi <2568818204@qq.com> Date: Wed Mar 20 03:41:49 2024 +0800 optimize needlebench code commit 940dd18a4270f24bc69edd2a780182c68918e1a9 Author: DseidLi <2568818204@qq.com> Date: Wed Mar 20 03:39:46 2024 +0800 fix vllm commit d8be6877bc41051f3edcc0421c462c834c0f1c9a Merge: ecad78a2527fda
Author: DseidLi <2568818204@qq.com> Date: Tue Mar 19 21:07:08 2024 +0800 Merge remote-tracking branch 'origin/add_1M_dataset' into atc_choice commit2527fda8a5
Author: DseidLi <2568818204@qq.com> Date: Tue Mar 19 16:03:40 2024 +0800 add model configs commit75425acdf8
Author: DseidLi <2568818204@qq.com> Date: Tue Mar 19 16:02:15 2024 +0800 add prompt postion args commit367ba1ba61
Author: DseidLi <2568818204@qq.com> Date: Wed Feb 28 21:40:00 2024 +0800 add Needlebench-1000K configs commit ecad78af14c4bb00fe325779114b384c57ab30bf Author: DseidLi <2568818204@qq.com> Date: Thu Mar 14 22:08:32 2024 +0800 fix atc commit 08772c0787b18872abadc9ffec3223941a5ee0c2 Merge: 9f3f8cfcaf1cf8
Author: DseidLi <2568818204@qq.com> Date: Thu Mar 14 22:07:28 2024 +0800 Merge branch 'main' into atc_choice Conflicts: configs/datasets/needlebench/readme.md configs/datasets/needlebench/readme_zh-CN.md configs/summarizers/needlebench.py opencompass/datasets/needlebench/atc.py opencompass/summarizers/needlebench.py commit 9f3f8cfb4452722734d334114ac1d14110e57406 Author: DseidLi <2568818204@qq.com> Date: Thu Mar 14 21:35:53 2024 +0800 add atc-choice test commit 52be7c1202376b4e09821188b826f1a805328129 Author: DseidLi <2568818204@qq.com> Date: Wed Mar 6 02:54:15 2024 +0800 update needlebench randomseed and add vllm qwen14b commit fc1effce596ae2e5ece4933e8cd34aef8e64a6f9 Merge: 4e747edcaf1cf8
Author: DseidLi <2568818204@qq.com> Date: Wed Mar 6 02:51:14 2024 +0800 Merge branch 'main' into add_model_configs commit 31834f9b23af3354ac3581ec86d693d0f05cdd1c Merge: 7dabc82120bf8b
Author: DseidLi <2568818204@qq.com> Date: Sun Mar 3 23:29:42 2024 +0800 Merge branch 'main' of https://github.com/open-compass/opencompass into atc_choice commit 4e747ed1988ddbcfcc7fff334601259ade72d363 Author: DseidLi <2568818204@qq.com> Date: Sun Mar 3 22:15:25 2024 +0800 add internlm2-lmdeploy model and gemma configs commit 7dabc828123d711c8cf834d6aab4137bb55e85ed Author: DseidLi <2568818204@qq.com> Date: Sat Mar 2 17:26:15 2024 +0800 add atc choice version -ZH commit996f8ae43d
Author: DseidLi <2568818204@qq.com> Date: Wed Feb 28 16:58:56 2024 +0800 update readme for needlebench commitf7266e873c
Author: DseidLi <2568818204@qq.com> Date: Wed Feb 28 16:44:53 2024 +0800 move readme.md commit1c7375681d
Author: DseidLi <2568818204@qq.com> Date: Wed Feb 28 16:38:31 2024 +0800 fix linting error commitb6524f3ebf
Author: DseidLi <2568818204@qq.com> Date: Wed Feb 28 16:33:51 2024 +0800 lint summarizer commitc0d1190e39
Author: DseidLi <2568818204@qq.com> Date: Wed Feb 28 16:29:03 2024 +0800 add needlebench intro, fix summarizer commit0965baf785
Author: DseidLi <2568818204@qq.com> Date: Mon Feb 26 13:31:26 2024 +0800 fix bug in needlebench summarizer commit5d32b31eb8
Author: DseidLi <2568818204@qq.com> Date: Sat Feb 24 03:19:08 2024 +0800 update act prompt commitaf82a7f085
Merge:32bf9fe
53fe788
Author: DseidLi <2568818204@qq.com> Date: Fri Feb 23 17:50:32 2024 +0800 Merge remote-tracking branch 'upstream/main' into needlebench commit32bf9fe802
Author: DseidLi <2568818204@qq.com> Date: Fri Feb 23 17:31:32 2024 +0800 simplify needlebench 32k, 128k, 200k for eval commita7cb025e05
Author: DseidLi <2568818204@qq.com> Date: Fri Feb 23 14:48:58 2024 +0800 add needlebench * fix summarizer * remove repeated code * remove chinese comments
278 lines
11 KiB
Python
278 lines
11 KiB
Python
import json
|
|
import os
|
|
import random
|
|
import re
|
|
from pathlib import Path
|
|
|
|
import tiktoken
|
|
from datasets import Dataset
|
|
|
|
from opencompass.datasets.base import BaseDataset
|
|
from opencompass.openicl import BaseEvaluator
|
|
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
|
|
|
|
|
|
def get_random_line_by_language(counter, file_path, language):
    """Pick a deterministic pseudo-random needle record for one language.

    Args:
        counter (int): RNG seed, so the same counter always selects the
            same record.
        file_path (str): Path to a JSONL file whose records contain
            'language', 'needle', 'retrieval_question' and 'arg2' fields.
        language (str): Only records whose 'language' field equals this
            value are candidates.

    Returns:
        dict | None: ``{'needle', 'retrieval_question', 'keyword'}`` drawn
        from a matching record, or ``None`` when no record matches.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        # Parse each line exactly once (previously every line was parsed
        # twice: once for the language filter, once for the value).
        lines = []
        for raw_line in file:
            record = json.loads(raw_line.strip())
            if record['language'] == language:
                lines.append(record)

    if not lines:
        return None

    random.seed(counter)
    random_line = random.choice(lines)
    return {
        'needle': random_line['needle'],
        'retrieval_question': random_line['retrieval_question'],
        'keyword': random_line['arg2'],
    }
|
|
|
|
|
|
@LOAD_DATASET.register_module()
class NeedleBenchOriginDataset(BaseDataset):
    """Single-needle NeedleBench dataset.

    Builds long-context prompts by inserting one "needle" sentence into a
    haystack assembled from corpus JSONL files, at a controlled token
    depth, and pairs each prompt with the expected answer.
    """

    @staticmethod
    def load(
        path: str,
        length: int,
        depth: int,
        tokenizer_model: str,
        file_list: list[str],
        num_repeats_per_file: int,
        length_buffer: int,
        guide: bool,
        language: str,
        needle_file_name: str,
        position: str = 'End',
    ):
        """Build the prompt/answer dataset.

        Args:
            path: Directory containing the haystack ``*.jsonl`` files and
                the needle file.
            length: Target context length in tokens.
            depth: Needle insertion depth as a percentage (0-100).
            tokenizer_model: tiktoken model name used to tokenize text.
            file_list: Haystack file names (within ``path``) to include.
            num_repeats_per_file: Samples generated per haystack file.
            length_buffer: Token budget subtracted from ``length`` before
                filling the context.
            guide: If True, inject a "think about the most relevant
                content first" hint into the retrieval question.
            language: 'Chinese' or 'English'.
            needle_file_name: JSONL needle file name (within ``path``).
            position: Where the question goes relative to the document:
                'End' (default) or 'Start'.

        Returns:
            datasets.Dataset: columns 'prompt' and 'answer', where answer
            is ``needle + '*' + keyword``.

        Raises:
            ValueError: For an unsupported ``language`` or ``position``.
        """
        data = {'prompt': [], 'answer': []}
        tokenizer = tiktoken.encoding_for_model(tokenizer_model)

        def _generate_context(tokens_context, depth_percent, needle):
            # Splice the needle's tokens into the context at
            # depth_percent% of its length, then decode back to text.
            tokens_needle = _get_tokens_from_context(needle)
            insertion_point = int(len(tokens_context) * (depth_percent / 100))
            tokens_context = (tokens_context[:insertion_point] +
                              tokens_needle + tokens_context[insertion_point:])
            new_context = _decode_tokens(tokens_context)
            return new_context

        def _get_tokens_from_context(context):
            return tokenizer.encode(context)

        def _decode_tokens(tokens):
            return tokenizer.decode(tokens)

        def _modify_retrieval_question(retrieval_question):
            # Insert a "consider the most relevant content first" hint
            # while preserving the original answer-format instruction.
            if language == 'Chinese':
                parts = retrieval_question.split('请按照')
                guide_retrieval_question = (parts[0] + '在回答之前,请思考文档中与此问题'
                                            '最相关的内容是什么。请按照' + parts[1])
                return guide_retrieval_question
            elif language == 'English':
                parts = retrieval_question.split('Please answer in the format')
                guide_retrieval_question = (
                    parts[0] + 'Before answering, please consider'
                    ' what in the document is most relevant to this question.'
                    ' Please answer in the format' + parts[1])
                return guide_retrieval_question
            else:
                raise ValueError(f"Language '{language}' is not supported.")

        def _generate_prompt(context, retrieval_question):
            if guide:
                retrieval_question = _modify_retrieval_question(
                    retrieval_question)

            if language == 'Chinese':
                if position == 'End':
                    prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                              '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                              ',或重复你的回答\n'
                              f'用户现在给你的文档是{context}\n\n'
                              f'现在请问:{retrieval_question}')
                elif position == 'Start':
                    # BUGFIX: a stray comma previously turned this into a
                    # 2-tuple instead of one concatenated prompt string
                    # (the English 'Start' branch concatenates correctly).
                    prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                              '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                              ',或重复你的回答\n'
                              f'现在请问:{retrieval_question}'
                              f'用户现在给你的文档是{context}\n\n')
                else:
                    raise ValueError('Unsupported position. '
                                     'Position must be "End" or "Start".')
            elif language == 'English':
                if position == 'End':
                    prompt = ('You are an intelligent AI assistant skilled in '
                              'answering user questions.\n'
                              'Please keep your answers concise and clear. Do '
                              'not talk about irrelevant topics or repeat '
                              'your answers.\nThe document '
                              f'given to you by the user is {context}\n\n'
                              f'Now, the question is: {retrieval_question}')
                elif position == 'Start':
                    prompt = ('You are an intelligent AI assistant skilled in '
                              'answering user questions.\n'
                              'Please keep your answers concise and clear. Do '
                              'not talk about irrelevant topics or repeat '
                              'your answers.\n'
                              f'Now, the question is: {retrieval_question}'
                              'The document given to you by the user'
                              f' is {context}\n\n')
                else:
                    raise ValueError(f'Unsupported position {position}. '
                                     'Position must be "End" or "Start".')
            else:
                raise ValueError(f"Language '{language}' is not supported.")

            return prompt

        files = Path(path).glob('*.jsonl')
        for file in files:
            if file.name not in file_list:
                continue

            with open(file, 'r', encoding='utf-8') as f:
                lines_bak = [json.loads(line.strip()) for line in f]
            lines = lines_bak.copy()
            for counter in range(num_repeats_per_file):
                # Reseed per repeat so the shuffled haystack order is
                # reproducible across runs.
                random.seed(counter)
                random.shuffle(lines)

                needle_file_path = os.path.join(path, needle_file_name)
                random_needle = get_random_line_by_language(
                    counter, needle_file_path, language)
                needle = '\n' + random_needle['needle'] + '\n'
                retrieval_question = random_needle['retrieval_question']
                keyword = random_needle['keyword']

                # Token budget for haystack text = target length minus
                # buffer minus the needle itself (never negative).
                context_length = length - length_buffer
                target_length_per_record = context_length - len(
                    _get_tokens_from_context(needle))
                target_length_per_record = max(target_length_per_record, 0)
                accumulated_tokens = []
                for line in lines:
                    tokens_current_line = _get_tokens_from_context(
                        line['text'])
                    accumulated_tokens.extend(tokens_current_line)

                    if len(accumulated_tokens) >= target_length_per_record:
                        break

                processed_text = _generate_context(
                    accumulated_tokens[:target_length_per_record], depth,
                    needle)

                processed_prompt = _generate_prompt(processed_text,
                                                    retrieval_question)

                data['prompt'].append(processed_prompt)
                data['answer'].append(needle + '*' + keyword)

        dataset = Dataset.from_dict({
            'prompt': data['prompt'],
            'answer': data['answer'],
        })
        return dataset
|
|
|
|
|
|
class NeedleBenchOriginEvaluator(BaseEvaluator):
    """Scores NeedleBench predictions with a keyword check backed by a
    normalized Levenshtein similarity."""

    def __init__(self, use_trim=False):
        # When True, each prediction is clipped to roughly the reference
        # length before the edit-distance comparison.
        self.use_trim = use_trim

    @staticmethod
    def _trim_prediction(prediction, reference):
        """Clip *prediction* to a window sized by *reference*.

        The prediction is first cut at 120% of the reference length; if
        the reference's final character occurs past the 80% mark, the cut
        is tightened to just after its first such occurrence.

        Args:
            prediction (str): The prediction string.
            reference (str): The reference string.

        Returns:
            str: The trimmed prediction string.
        """
        lower = int(0.8 * len(reference))
        upper = int(1.2 * len(reference))
        clipped = prediction[:upper]

        tail = clipped[lower:]
        if len(clipped) > lower and reference[-1] in tail:
            cut = lower + tail.index(reference[-1]) + 1
            clipped = clipped[:cut]

        return clipped

    def levenshtein_distance(self, s1, s2):
        """Classic rolling-row dynamic-programming edit distance."""
        # Keep the longer string as s1 so the rolling row spans s2.
        if len(s1) < len(s2):
            return self.levenshtein_distance(s2, s1)

        if not s2:
            return len(s1)

        previous_row = range(len(s2) + 1)
        for i, ch1 in enumerate(s1):
            current_row = [i + 1]
            for j, ch2 in enumerate(s2):
                cost_insert = previous_row[j + 1] + 1
                cost_delete = current_row[j] + 1
                cost_substitute = previous_row[j] + (ch1 != ch2)
                current_row.append(
                    min(cost_insert, cost_delete, cost_substitute))
            previous_row = current_row

        return previous_row[-1]

    def score(self, predictions, gold):
        """Score each prediction against its '*'-joined gold entry.

        Gold entries look like ``<needle>*<keyword>``. A verbatim keyword
        hit in the raw prediction forces a score of 100; a miss keeps 20%
        of the edit-distance similarity.
        """

        if len(predictions) != len(gold):
            return {'error': 'predictions and gold have different lengths'}

        total_score = 0
        details = []
        for prediction, reference in zip(predictions, gold):
            keyword = reference.split('*')[1]
            reference = reference.split('*')[0]
            raw_prediction = prediction
            # Whitespace is irrelevant to the comparison, so drop it.
            prediction = re.sub(r'\s+', '', prediction)
            reference = re.sub(r'\s+', '', reference)

            if self.use_trim:
                prediction = NeedleBenchOriginEvaluator._trim_prediction(
                    prediction, reference)

            edit_distance = self.levenshtein_distance(prediction, reference)
            max_len = max(len(prediction), len(reference))
            if max_len != 0:
                score = 100 * (1 - edit_distance / max_len)
            else:
                score = 100

            if keyword in raw_prediction:
                print(f'{keyword} is in {prediction}')
                score = 100
            else:
                print(f'{keyword} is not in {prediction}')
                score = 0.2 * score

            details.append({
                'pred': prediction,
                'answer': reference,
                'edit_distance': edit_distance,
                'score': score
            })
            total_score += score

        average_score = total_score / len(predictions) if predictions else 0
        return {'score': average_score, 'details': details}
|
|
|
|
|
|
@TEXT_POSTPROCESSORS.register_module('needlebench')
def needlebench_postprocess(text: str) -> str:
    """Identity post-processor: model output is scored exactly as-is."""
    return text
|
|
|
|
|
|
@TEXT_POSTPROCESSORS.register_module('needlebench_dataset')
def needlebench_dataset_postprocess(text: str) -> str:
    """Identity post-processor: dataset text passes through unchanged."""
    return text
|