OpenCompass/opencompass/datasets/needlebench/multi.py

233 lines
8.7 KiB
Python
Raw Normal View History

import json
2024-02-23 14:48:58 +08:00
import os
import random
from pathlib import Path
import tiktoken
from datasets import Dataset
from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
2024-02-23 14:48:58 +08:00
from opencompass.registry import LOAD_DATASET
def get_random_needles(file_path, needle_count):
    """Pick one random needle record with the requested derivation count.

    The JSON file at ``file_path`` is loaded and narrowed down to the records
    whose ``derivation_count`` field equals ``needle_count``; one of those
    records is then drawn with ``random.choice``.

    Args:
        file_path: Path of the JSON file holding the needle records.
        needle_count: Value a record's ``derivation_count`` must equal.

    Returns:
        A dict with the keys ``needles`` (the record's ``derivations``),
        ``answer`` and ``retrieval_question`` (the record's ``question``),
        or ``None`` when no record matches.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    candidates = []
    for entry in records:
        if entry.get('derivation_count') == needle_count:
            candidates.append(entry)

    # Nothing to draw from: signal it with None, without touching the RNG.
    if not candidates:
        return None

    chosen = random.choice(candidates)
    return dict(
        needles=chosen['derivations'],
        answer=chosen['answer'],
        retrieval_question=chosen['question'],
    )
@LOAD_DATASET.register_module()
class NeedleBenchMultiDataset(BaseDataset):
    """Builds multi-needle NeedleBench prompts from ``*.jsonl`` corpora.

    Each generated row consists of a prompt (haystack text with several
    needle sentences inserted, followed by a retrieval question) and an
    answer string of the form ``'<answer>*<keyword>'``.
    """

    @staticmethod
    def load(
        path: str,
        length: int,
        depth: int,
        tokenizer_model: str,
        file_list: 'list[str]',
        num_repeats_per_file: int,
        length_buffer: int,
        guide: bool,
        language: str,
        needle_file_name: str,
        num_needles: int,
        diff: int,
    ):
        """Generate the prompt/answer dataset.

        Args:
            path: Directory containing the ``*.jsonl`` haystack files and the
                needle file.
            length: Token budget of a sample; ``length_buffer`` is subtracted
                from it to obtain the context length.
            depth: Position of the first needle, as a percentage of the
                haystack token count.
            tokenizer_model: Model name passed to
                ``tiktoken.encoding_for_model``.
            file_list: Names of the ``*.jsonl`` files in ``path`` to use;
                other files are ignored.
            num_repeats_per_file: Number of samples generated per file.
            length_buffer: Tokens reserved out of ``length``.
            guide: When true, a "think about the most relevant content" hint
                is appended to the retrieval question.
            language: ``'Chinese'`` or ``'English'``; selects the prompt
                template.
            needle_file_name: Name of the JSON needle file inside ``path``.
            num_needles: ``derivation_count`` a needle record must have to be
                eligible.
            diff: Spacing between consecutive needles, as a percentage of the
                haystack token count.

        Returns:
            A ``datasets.Dataset`` with the columns ``prompt`` and ``answer``.

        Raises:
            ValueError: If ``language`` is not supported, or if the needle
                file has no record whose ``derivation_count`` equals
                ``num_needles``.
        """
        data = {'prompt': [], 'answer': []}
        tokenizer = tiktoken.encoding_for_model(tokenizer_model)

        def _generate_context(tokens_context, depth_percent, needles):
            """Insert every needle into the token list and decode to text."""
            tokens_needle = [
                _get_tokens_from_context(needle) for needle in needles
            ]
            insertion_points = []
            total_length = len(tokens_context)

            # Plan the insertion points: the first needle sits at
            # ``depth_percent`` of the haystack, each following needle
            # ``diff`` percent of the haystack after the end of the previous
            # one, capped at the end of the (grown) context.
            for i, needle_tokens in enumerate(tokens_needle):
                if i == 0:
                    insertion_point = int(total_length * (depth_percent / 100))
                else:
                    insertion_point = int(insertion_points[i - 1] +
                                          len(tokens_needle[i - 1]) +
                                          total_length * (diff / 100))
                insertion_point = min(
                    insertion_point,
                    total_length + sum(len(tn) for tn in tokens_needle[:i]))
                insertion_points.append(insertion_point)

            # NOTE(review): the planned points above already account for the
            # length of the preceding needle, and the inner loop below shifts
            # the later points by it once more. Out-of-range slice bounds are
            # clamped by Python, so this never fails; it is kept as-is so the
            # generated contexts stay unchanged -- confirm whether the extra
            # shift is intended.
            for i, needle_tokens in enumerate(tokens_needle):
                tokens_context = tokens_context[:insertion_points[i]] \
                    + needle_tokens + tokens_context[insertion_points[i]:]
                for j in range(i + 1, len(insertion_points)):
                    insertion_points[j] += len(needle_tokens)

            new_context = _decode_tokens(tokens_context)
            return new_context

        def _get_tokens_from_context(context):
            """Tokenize a string, or each string of a list of strings."""
            if isinstance(context, list):
                return [tokenizer.encode(item) for item in context]
            else:
                return tokenizer.encode(context)

        def _decode_tokens(tokens):
            """Decode a token list back into text."""
            return tokenizer.decode(tokens)

        def _modify_retrieval_question(retrieval_question):
            """Append the language-specific guidance hint to the question."""
            if language == 'Chinese':
                guide_retrieval_question = (retrieval_question +
                                            '在回答之前,请思考文档中与此问题'
                                            '最相关的内容是什么。')
                return guide_retrieval_question
            elif language == 'English':
                guide_retrieval_question = (
                    retrieval_question + 'Before answering, please consider'
                    ' what in the document is most relevant to this question.')
                return guide_retrieval_question
            else:
                raise ValueError(f"Language '{language}' is not supported.")

        def _generate_prompt(context, retrieval_question):
            """Wrap context and question into the language-specific prompt."""
            if guide:
                retrieval_question = _modify_retrieval_question(
                    retrieval_question)

            if language == 'Chinese':
                prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                          '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                          ',或重复你的回答\n'
                          f'用户现在给你的文档是{context}\n\n'
                          f'现在请问:{retrieval_question}')
            elif language == 'English':
                prompt = ('You are an intelligent AI assistant skilled in '
                          'answering user questions.\n'
                          'Please keep your answers concise and clear. Do not'
                          ' talk about irrelevant topics or repeat your '
                          'answers.\n'
                          f'The document given to you by the user is {context}'
                          f'\n\nNow, the question is: {retrieval_question}')
            else:
                raise ValueError(f"Language '{language}' is not supported.")

            return prompt

        # Directory listing order is filesystem dependent; sorting keeps the
        # row order of the resulting dataset reproducible across machines.
        files = sorted(Path(path).glob('*.jsonl'))
        needle_file_path = os.path.join(path, needle_file_name)

        for file in files:
            if file.name not in file_list:
                continue

            with open(file, 'r', encoding='utf-8') as f:
                # Skip blank lines instead of failing in json.loads on them.
                lines = [
                    json.loads(line.strip()) for line in f if line.strip()
                ]

            for counter in range(num_repeats_per_file):
                # Re-seeding per repeat makes both the shuffle and the needle
                # draw below depend only on ``counter``. The list is shuffled
                # in place, so each repeat starts from the previous order.
                random.seed(counter)
                random.shuffle(lines)

                random_needle_data = get_random_needles(
                    needle_file_path, num_needles)
                if random_needle_data is None:
                    # Fail with a clear message rather than a TypeError from
                    # subscripting None below.
                    raise ValueError(
                        'No needle record with derivation_count == '
                        f'{num_needles} found in {needle_file_path}.')

                needles = [
                    '\n' + needle + '\n'
                    for needle in random_needle_data['needles']
                ]
                answer = random_needle_data['answer']
                keyword = answer
                retrieval_question = random_needle_data['retrieval_question']

                # Haystack budget: what is left of the context length after
                # reserving room for all needles (never negative).
                context_length = length - length_buffer
                target_length_per_record = context_length - \
                    sum(len(tokens) for tokens
                        in _get_tokens_from_context(needles))
                target_length_per_record = max(target_length_per_record, 0)

                # Concatenate shuffled lines until the budget is reached; the
                # overshoot is trimmed by the slice below.
                accumulated_tokens = []
                for line in lines:
                    tokens_current_line = _get_tokens_from_context(
                        line['text'])
                    accumulated_tokens.extend(tokens_current_line)

                    if len(accumulated_tokens) >= target_length_per_record:
                        break

                processed_text = _generate_context(
                    accumulated_tokens[:target_length_per_record], depth,
                    needles)

                processed_prompt = _generate_prompt(processed_text,
                                                    retrieval_question)

                data['prompt'].append(processed_prompt)
                data['answer'].append(answer + '*' + keyword)

        dataset = Dataset.from_dict({
            'prompt': data['prompt'],
            'answer': data['answer'],
        })
        return dataset
2024-02-23 14:48:58 +08:00
class NeedleBenchMultiEvaluator(BaseEvaluator):
    """Scores predictions by the share of reference keywords they contain.

    References are expected in the form ``'<answer>*<keyword>'``.
    """

    def levenshtein_distance(self, s1, s2):
        """Return the Levenshtein edit distance between ``s1`` and ``s2``.

        Only two rows of the dynamic-programming table are kept at a time.
        """
        # Keep the shorter sequence as ``s2`` so the rows stay short.
        if len(s1) < len(s2):
            return self.levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                # ``c1 != c2`` adds 1 only when the characters differ.
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    @staticmethod
    def _split_reference(reference):
        """Split ``'<answer>*<keyword>'`` into ``(answer, keyword)``.

        A reference with exactly one ``'*'`` is split at it. Malformed
        references are handled instead of raising on unpacking:

        * no ``'*'``: the whole reference serves as answer and keyword;
        * several ``'*'``: if the reference is two identical halves joined
          by a ``'*'`` (an answer that itself contains ``'*'``, stored as
          ``answer + '*' + answer``), it is split at the middle; otherwise
          at the first ``'*'``.
        """
        answer, separator, keyword = reference.partition('*')
        if not separator:
            return reference, reference

        if '*' in keyword:
            middle = len(reference) // 2
            if (len(reference) % 2 == 1 and reference[middle] == '*'
                    and reference[:middle] == reference[middle + 1:]):
                return reference[:middle], reference[middle + 1:]

        return answer, keyword

    def score(self, predictions, gold):
        """Compute the average keyword-match score.

        Args:
            predictions: Predicted strings.
            gold: Reference strings of the form ``'<answer>*<keyword>'``.

        Returns:
            ``{'error': ...}`` when the two inputs differ in length;
            otherwise ``{'score': average, 'details': [...]}`` where each
            sample scores between 0 and 100 (the percentage of
            whitespace-separated, lower-cased keywords found as substrings
            of the lower-cased prediction).
        """
        if len(predictions) != len(gold):
            return {'error': 'predictions and gold have different lengths'}

        total_score = 0
        details = []

        for prediction, reference in zip(predictions, gold):
            _, keyword = self._split_reference(reference)
            keywords = keyword.lower().split()
            prediction = prediction.lower()

            # Every keyword carries the same weight; no keywords -> score 0.
            keyword_score = 100 / len(keywords) if keywords else 0

            matched_keywords = sum(1 for kword in keywords
                                   if kword in prediction)
            score = matched_keywords * keyword_score

            detail = {
                'pred': prediction,
                'answer': reference,
                'matched_keywords': matched_keywords,
                'score': score
            }

            total_score += score
            details.append(detail)

        average_score = total_score / len(predictions) if predictions else 0
        return {'score': average_score, 'details': details}