This commit is contained in:
Mor-Li 2025-04-25 17:51:03 +08:00
parent 081c185b8f
commit 5b1a5fa596
6 changed files with 145 additions and 99 deletions

View File

@ -4,20 +4,24 @@ import os
import random import random
import re import re
from enum import Enum from enum import Enum
from datasets import Dataset from datasets import Dataset
from opencompass.datasets.base import BaseDataset from opencompass.datasets.base import BaseDataset
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS from opencompass.datasets.needlebench.atc_elder_only import (
from opencompass.datasets.needlebench.atc_elder_only import clean_atc_answer, needlebench_atc_postprocess_v2, NeedleBenchATCEvaluator NeedleBenchATCEvaluator, clean_atc_answer, needlebench_atc_postprocess_v2)
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
TEXT_POSTPROCESSORS)
from opencompass.utils import get_data_path from opencompass.utils import get_data_path
# 定义问题类型枚举 # 定义问题类型枚举
class QuestionType(Enum): class QuestionType(Enum):
ELDEST_ANCESTOR = 0 # 最年长祖先 ELDEST_ANCESTOR = 0 # 最年长祖先
NTH_ANCESTOR = 1 # N级祖先 NTH_ANCESTOR = 1 # N级祖先
NTH_DESCENDANT = 2 # N级子节点 NTH_DESCENDANT = 2 # N级子节点
RELATIONSHIP_DISTANCE = 3 # 关系距离 RELATIONSHIP_DISTANCE = 3 # 关系距离
# 定义关系术语的代数映射(一代关系还是两代关系) # 定义关系术语的代数映射(一代关系还是两代关系)
relationship_generation_map_zh = { relationship_generation_map_zh = {
@ -92,7 +96,7 @@ relationship_templates_en = [
"but also {B}'s guardian."), "but also {B}'s guardian."),
('For {B}, {A} is not just a {relationship}, ' ('For {B}, {A} is not just a {relationship}, '
'but also a friend.'), 'but also a friend.'),
"For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.", 'For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.',
] ]
# Eldest ancestor problem template # Eldest ancestor problem template
@ -249,23 +253,24 @@ Now, the scrambled family relationships are provided below:
Given the scrambled family relationships described above, what is the relationship distance between '{person_a}' and '{person_b}'? Given the scrambled family relationships described above, what is the relationship distance between '{person_a}' and '{person_b}'?
""" """
@LOAD_DATASET.register_module() @LOAD_DATASET.register_module()
class NeedleBenchATCDataset(BaseDataset): class NeedleBenchATCDataset(BaseDataset):
@staticmethod @staticmethod
def load( def load(
path, path,
file_name: str, file_name: str,
num_needles: int, num_needles: int,
language: str, language: str,
repeats: int, repeats: int,
# This parameter cannot be passed through mmengine because it is blocked as lazy # This parameter cannot be passed through mmengine because it is blocked as lazy
question_types: list[QuestionType] = [ question_types: list[QuestionType] = [
QuestionType.ELDEST_ANCESTOR, QuestionType.ELDEST_ANCESTOR,
QuestionType.NTH_ANCESTOR, QuestionType.NTH_ANCESTOR,
QuestionType.NTH_DESCENDANT, QuestionType.NTH_DESCENDANT,
QuestionType.RELATIONSHIP_DISTANCE, QuestionType.RELATIONSHIP_DISTANCE,
], # Support specifying a list of question types ], # Support specifying a list of question types
): ):
data = {'prompt': [], 'answer': [], 'question_type': []} data = {'prompt': [], 'answer': [], 'question_type': []}
path = get_data_path(path) path = get_data_path(path)
@ -282,7 +287,7 @@ class NeedleBenchATCDataset(BaseDataset):
# Ensure question_types is not empty # Ensure question_types is not empty
if not question_types: if not question_types:
raise ValueError('question_types cannot be empty') raise ValueError('question_types cannot be empty')
for question_type in question_types: for question_type in question_types:
# Generate the specified number of examples for each question type # Generate the specified number of examples for each question type
for i in range(repeats): for i in range(repeats):
@ -290,11 +295,11 @@ class NeedleBenchATCDataset(BaseDataset):
# Use the enum value of the question type multiplied by 10000 as the base to ensure non-overlapping seed ranges # Use the enum value of the question type multiplied by 10000 as the base to ensure non-overlapping seed ranges
seed = (i + 1) + (10000 * question_type.value) seed = (i + 1) + (10000 * question_type.value)
random.seed(seed) random.seed(seed)
# Randomly select the specified number of names from all names # Randomly select the specified number of names from all names
# The number of names is num_needles + 1 # The number of names is num_needles + 1
names = random.sample(all_names, num_needles+1) names = random.sample(all_names, num_needles + 1)
# Select the corresponding relationship terms and templates according to the language # Select the corresponding relationship terms and templates according to the language
if language == 'Chinese': if language == 'Chinese':
relationship_terms = relationship_terms_zh_CN relationship_terms = relationship_terms_zh_CN
@ -305,10 +310,13 @@ class NeedleBenchATCDataset(BaseDataset):
relationship_templates = relationship_templates_en relationship_templates = relationship_templates_en
relationship_map = relationship_generation_map_en relationship_map = relationship_generation_map_en
else: else:
raise ValueError('Unsupported language specified. ' raise ValueError(
'Please choose either "Chinese" or "English".') 'Unsupported language specified. '
'Please choose either "Chinese" or "English".')
def generate_chain_family_story(names, templates, relationship_terms, relationship_map): def generate_chain_family_story(names, templates,
relationship_terms,
relationship_map):
story = '' story = ''
relationships = [] relationships = []
total_generations = 0 # Track the total generational difference total_generations = 0 # Track the total generational difference
@ -317,25 +325,30 @@ class NeedleBenchATCDataset(BaseDataset):
template = random.choice(templates) template = random.choice(templates)
relation_term = random.choice(relationship_terms) relation_term = random.choice(relationship_terms)
relation = template.format(A=names[i], relation = template.format(A=names[i],
B=names[i + 1], B=names[i + 1],
relationship=relation_term) relationship=relation_term)
story += f'{relation}*' story += f'{relation}*'
# Get the generation difference for this relationship # Get the generation difference for this relationship
gen_diff = relationship_map.get(relation_term, 1) # Default to 1 generation gen_diff = relationship_map.get(
relation_term, 1) # Default to 1 generation
total_generations += gen_diff total_generations += gen_diff
# Record relationship information for later use # Record relationship information for later use
relationships.append((names[i], names[i + 1], relation_term, gen_diff)) relationships.append(
(names[i], names[i + 1], relation_term, gen_diff))
return story, relationships, total_generations return story, relationships, total_generations
chain_story, relationships, total_generations = generate_chain_family_story( chain_story, relationships, total_generations = generate_chain_family_story(
names, relationship_templates, relationship_terms, relationship_map) names, relationship_templates, relationship_terms,
relationship_map)
# Split the chain_story into a list of fragments # Split the chain_story into a list of fragments
family_story_fragments = chain_story.split('*') family_story_fragments = chain_story.split('*')
family_story_fragments = [f for f in family_story_fragments if f] family_story_fragments = [
f for f in family_story_fragments if f
]
# Shuffle the list of fragments # Shuffle the list of fragments
random.shuffle(family_story_fragments) random.shuffle(family_story_fragments)
@ -348,15 +361,19 @@ class NeedleBenchATCDataset(BaseDataset):
last_person = names[-1] last_person = names[-1]
if language == 'Chinese': if language == 'Chinese':
prompt = shuffled_story_with_prompt_zh_CN.format( prompt = shuffled_story_with_prompt_zh_CN.format(
shuffled_story=shuffled_story, last_person=last_person) shuffled_story=shuffled_story,
last_person=last_person)
else: else:
prompt = shuffled_story_with_prompt_en.format( prompt = shuffled_story_with_prompt_en.format(
shuffled_story=shuffled_story, last_person=last_person) shuffled_story=shuffled_story,
answer = names[0] # The first person is the eldest ancestor last_person=last_person)
answer = names[
0] # The first person is the eldest ancestor
elif question_type == QuestionType.NTH_ANCESTOR: elif question_type == QuestionType.NTH_ANCESTOR:
# Nth ancestor question - trace from the youngest person to the oldest # Nth ancestor question - trace from the youngest person to the oldest
person = names[-1] # The youngest person (end of the chain) person = names[
-1] # The youngest person (end of the chain)
n = total_generations # Use the calculated total generational difference n = total_generations # Use the calculated total generational difference
if language == 'Chinese': if language == 'Chinese':
prompt = nth_ancestor_prompt_zh_CN.format( prompt = nth_ancestor_prompt_zh_CN.format(
@ -364,7 +381,8 @@ class NeedleBenchATCDataset(BaseDataset):
else: else:
prompt = nth_ancestor_prompt_en.format( prompt = nth_ancestor_prompt_en.format(
shuffled_story=shuffled_story, person=person, n=n) shuffled_story=shuffled_story, person=person, n=n)
answer = names[0] # The oldest person (start of the chain) is the nth ancestor answer = names[
0] # The oldest person (start of the chain) is the nth ancestor
elif question_type == QuestionType.NTH_DESCENDANT: elif question_type == QuestionType.NTH_DESCENDANT:
# Nth descendant question - trace from the oldest person to the youngest # Nth descendant question - trace from the oldest person to the youngest
@ -376,7 +394,8 @@ class NeedleBenchATCDataset(BaseDataset):
else: else:
prompt = nth_descendant_prompt_en.format( prompt = nth_descendant_prompt_en.format(
shuffled_story=shuffled_story, person=person, n=n) shuffled_story=shuffled_story, person=person, n=n)
answer = names[-1] # The youngest person (end of the chain) is the nth descendant answer = names[
-1] # The youngest person (end of the chain) is the nth descendant
elif question_type == QuestionType.RELATIONSHIP_DISTANCE: elif question_type == QuestionType.RELATIONSHIP_DISTANCE:
# Relationship distance question - calculate the relationship distance between the two ends of the chain # Relationship distance question - calculate the relationship distance between the two ends of the chain
@ -384,10 +403,14 @@ class NeedleBenchATCDataset(BaseDataset):
person_b = names[-1] # The youngest person person_b = names[-1] # The youngest person
if language == 'Chinese': if language == 'Chinese':
prompt = relationship_distance_prompt_zh_CN.format( prompt = relationship_distance_prompt_zh_CN.format(
shuffled_story=shuffled_story, person_a=person_a, person_b=person_b) shuffled_story=shuffled_story,
person_a=person_a,
person_b=person_b)
else: else:
prompt = relationship_distance_prompt_en.format( prompt = relationship_distance_prompt_en.format(
shuffled_story=shuffled_story, person_a=person_a, person_b=person_b) shuffled_story=shuffled_story,
person_a=person_a,
person_b=person_b)
# Use the calculated total generations as the relationship distance # Use the calculated total generations as the relationship distance
answer = str(total_generations) answer = str(total_generations)
@ -396,11 +419,14 @@ class NeedleBenchATCDataset(BaseDataset):
last_person = names[-1] last_person = names[-1]
if language == 'Chinese': if language == 'Chinese':
prompt = shuffled_story_with_prompt_zh_CN.format( prompt = shuffled_story_with_prompt_zh_CN.format(
shuffled_story=shuffled_story, last_person=last_person) shuffled_story=shuffled_story,
last_person=last_person)
else: else:
prompt = shuffled_story_with_prompt_en.format( prompt = shuffled_story_with_prompt_en.format(
shuffled_story=shuffled_story, last_person=last_person) shuffled_story=shuffled_story,
answer = names[0] # The first person is the eldest ancestor last_person=last_person)
answer = names[
0] # The first person is the eldest ancestor
data['prompt'].append(prompt) data['prompt'].append(prompt)
data['answer'].append(answer) data['answer'].append(answer)
@ -411,4 +437,4 @@ class NeedleBenchATCDataset(BaseDataset):
'answer': data['answer'], 'answer': data['answer'],
'question_type': data['question_type'], 'question_type': data['question_type'],
}) })
return dataset return dataset

View File

@ -4,14 +4,16 @@ import json
import os import os
import random import random
from datasets import Dataset
import numpy as np import numpy as np
from datasets import Dataset
from opencompass.registry import LOAD_DATASET from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path from opencompass.utils import get_data_path
from ..base import BaseDataset from ..base import BaseDataset
from .atc import relationship_terms_zh_CN, relationship_templates_zh_CN, relationship_terms_en, relationship_templates_en from .atc import (relationship_templates_en, relationship_templates_zh_CN,
relationship_terms_en, relationship_terms_zh_CN)
def get_number(options): def get_number(options):
result_string = '' result_string = ''
@ -173,10 +175,13 @@ Example 3: If Xiao Ming is Zhang Hong's great-granddaughter, Zhang Hong's grandm
) )
names.extend(additional_names) names.extend(additional_names)
num_samples = 3 num_samples = 3
if len(names) > 1: if len(names) > 1:
indices = np.linspace(1, len(names) - 1, num_samples, dtype=int) # Generate evenly spaced indices indices = np.linspace(
sampled_names = [names[i] for i in indices] # Select corresponding elements 1, len(names) - 1, num_samples,
dtype=int) # Generate evenly spaced indices
sampled_names = [names[i] for i in indices
] # Select corresponding elements
entry['options'] = names[:1] + sampled_names entry['options'] = names[:1] + sampled_names
else: else:
entry['options'] = names # Return directly if only one element entry['options'] = names # Return directly if only one element

View File

@ -3,14 +3,15 @@ import json
import os import os
import random import random
import re import re
from datasets import Dataset from datasets import Dataset
from opencompass.datasets.base import BaseDataset from opencompass.datasets.base import BaseDataset
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.utils import get_data_path
from opencompass.datasets.math import extract_boxed_answer from opencompass.datasets.math import extract_boxed_answer
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
TEXT_POSTPROCESSORS)
from opencompass.utils import get_data_path
relationship_templates_zh_CN = [ relationship_templates_zh_CN = [
'{A}{B}{relationship}', '{A}{B}{relationship}',
@ -57,7 +58,7 @@ relationship_templates_en = [
'{B} is the child of {A}.', '{B} is the child of {A}.',
('For {B}, {A} is not just a {relationship}, ' ('For {B}, {A} is not just a {relationship}, '
'but also a friend.'), 'but also a friend.'),
"For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.", 'For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.',
] ]
shuffled_story_with_prompt_zh_CN = """下面是对你的多步推理能力的测试,这个测试叫做祖先追溯测试,我们会模拟不同人的家庭亲属关系,你的任务是在其中不断推理,直到找到最年长的祖先。 shuffled_story_with_prompt_zh_CN = """下面是对你的多步推理能力的测试,这个测试叫做祖先追溯测试,我们会模拟不同人的家庭亲属关系,你的任务是在其中不断推理,直到找到最年长的祖先。
@ -125,7 +126,7 @@ class NeedleBenchATCDataset(BaseDataset):
# 使用固定种子来保持样本稳定性 # 使用固定种子来保持样本稳定性
seed = i seed = i
random.seed(seed) random.seed(seed)
names = random.sample(all_names, num_needles) names = random.sample(all_names, num_needles)
if language == 'Chinese': if language == 'Chinese':
relationship_terms = relationship_terms_zh_CN relationship_terms = relationship_terms_zh_CN
@ -163,9 +164,11 @@ class NeedleBenchATCDataset(BaseDataset):
# Generating the prompt based on the language # Generating the prompt based on the language
if language == 'Chinese': if language == 'Chinese':
shuffled_story_with_prompt = shuffled_story_with_prompt_zh_CN.format(shuffled_story=shuffled_story, last_person=last_person) shuffled_story_with_prompt = shuffled_story_with_prompt_zh_CN.format(
shuffled_story=shuffled_story, last_person=last_person)
elif language == 'English': elif language == 'English':
shuffled_story_with_prompt = shuffled_story_with_prompt_en.format(shuffled_story=shuffled_story, last_person=last_person) shuffled_story_with_prompt = shuffled_story_with_prompt_en.format(
shuffled_story=shuffled_story, last_person=last_person)
else: else:
prompt = 'Language not supported.' prompt = 'Language not supported.'
raise Exception('Unsupported language specified. ' raise Exception('Unsupported language specified. '
@ -182,46 +185,47 @@ class NeedleBenchATCDataset(BaseDataset):
def clean_atc_answer(text: str) -> str: def clean_atc_answer(text: str) -> str:
"""Clean answer format specifically for QwQ-32B-Preview model """Clean answer format specifically for QwQ-32B-Preview model.
Args: Args:
text: Raw prediction text text: Raw prediction text
Returns: Returns:
Standardized name format after cleaning Standardized name format after cleaning
""" """
if not text or text == "None": if not text or text == 'None':
return "None" return 'None'
# Remove LaTeX commands but keep content # Remove LaTeX commands but keep content
text = re.sub(r'\\text\{([^}]+)\}', r'\1', text) text = re.sub(r'\\text\{([^}]+)\}', r'\1', text)
text = re.sub(r'\\boxed\{([^}]+)\}', r'\1', text) text = re.sub(r'\\boxed\{([^}]+)\}', r'\1', text)
text = re.sub(r'\\[\[\]]', '', text) text = re.sub(r'\\[\[\]]', '', text)
# Remove extra backslashes # Remove extra backslashes
text = text.replace('\\\\', '').replace('\\', '') text = text.replace('\\\\', '').replace('\\', '')
# Handle extra spaces # Handle extra spaces
text = re.sub(r'\s+', ' ', text).strip() text = re.sub(r'\s+', ' ', text).strip()
# Remove quotes # Remove quotes
text = text.replace('"', '').replace("'", '') text = text.replace('"', '').replace("'", '')
# Remove tildes (波浪符号) # Remove tildes (波浪符号)
text = text.replace('~', ' ') text = text.replace('~', ' ')
return text return text
@TEXT_POSTPROCESSORS.register_module('needlebench_atc_postprocess_v2') @TEXT_POSTPROCESSORS.register_module('needlebench_atc_postprocess_v2')
def needlebench_atc_postprocess_v2(text: str) -> str: def needlebench_atc_postprocess_v2(text: str) -> str:
cand_ans = extract_boxed_answer(text, strip_double_curly_brace=True) cand_ans = extract_boxed_answer(text, strip_double_curly_brace=True)
if cand_ans: if cand_ans:
return clean_atc_answer(cand_ans) return clean_atc_answer(cand_ans)
return "None" return 'None'
@ICL_EVALUATORS.register_module("needlebench_atc_evaluator") @ICL_EVALUATORS.register_module('needlebench_atc_evaluator')
class NeedleBenchATCEvaluator(BaseEvaluator): class NeedleBenchATCEvaluator(BaseEvaluator):
def score(self, predictions, gold): def score(self, predictions, gold):
@ -230,12 +234,12 @@ class NeedleBenchATCEvaluator(BaseEvaluator):
correct_count = 0 correct_count = 0
details = [] details = []
for prediction, reference in zip(predictions, gold): for prediction, reference in zip(predictions, gold):
reference_name = reference reference_name = reference
if prediction.strip() == reference_name.strip(): if prediction.strip() == reference_name.strip():
correct_count += 1 correct_count += 1
detail = { detail = {
'pred': prediction, 'pred': prediction,
'answer': reference_name, 'answer': reference_name,
@ -243,6 +247,7 @@ class NeedleBenchATCEvaluator(BaseEvaluator):
} }
details.append(detail) details.append(detail)
accuracy = (correct_count / len(predictions)) * 100 if predictions else 0 accuracy = (correct_count /
len(predictions)) * 100 if predictions else 0
result = {'score': accuracy, 'details': details} result = {'score': accuracy, 'details': details}
return result return result

View File

@ -7,8 +7,12 @@ from datasets import Dataset
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from opencompass.datasets.base import BaseDataset from opencompass.datasets.base import BaseDataset
from opencompass.datasets.needlebench.atc import (relationship_templates_en,
relationship_templates_zh_CN,
relationship_terms_en,
relationship_terms_zh_CN)
from opencompass.registry import LOAD_DATASET from opencompass.registry import LOAD_DATASET
from opencompass.datasets.needlebench.atc import relationship_templates_zh_CN, relationship_terms_zh_CN, relationship_templates_en, relationship_terms_en
def get_random_needles(counter, file_path, num_needles, language): def get_random_needles(counter, file_path, num_needles, language):
with open(file_path, 'r', encoding='utf-8') as file: with open(file_path, 'r', encoding='utf-8') as file:
@ -33,18 +37,23 @@ def get_random_needles(counter, file_path, num_needles, language):
for i in range(len(names) - 1): for i in range(len(names) - 1):
template = random.choice(templates) template = random.choice(templates)
relation_term = random.choice(relationship_terms) relation_term = random.choice(relationship_terms)
relation = template.format(A=names[i], B=names[i + 1], relationship=relation_term) relation = template.format(A=names[i],
B=names[i + 1],
relationship=relation_term)
story += f'{relation}*' story += f'{relation}*'
return story return story
chain_story = generate_chain_family_story(names, relationship_templates, relationship_terms) chain_story = generate_chain_family_story(names, relationship_templates,
relationship_terms)
# Splitting the chain_story into a list of fragments # Splitting the chain_story into a list of fragments
family_story_fragments = chain_story.split('*') family_story_fragments = chain_story.split('*')
# Removing the empty string from the list # Removing the empty string from the list
family_story_fragments = [fragment for fragment in family_story_fragments if fragment] family_story_fragments = [
fragment for fragment in family_story_fragments if fragment
]
# Shuffling the list of fragments # Shuffling the list of fragments
random.shuffle(family_story_fragments) random.shuffle(family_story_fragments)
@ -55,7 +64,7 @@ def get_random_needles(counter, file_path, num_needles, language):
retrieval_question = f"在上面提供的文本中,'{last_person}'的能够向上追溯到的最年长的亲人是谁?" retrieval_question = f"在上面提供的文本中,'{last_person}'的能够向上追溯到的最年长的亲人是谁?"
elif language == 'English': elif language == 'English':
retrieval_question = f"Given the context described above, who is the eldest relative that '{last_person}' can trace back to in the context?" retrieval_question = f"Given the context described above, who is the eldest relative that '{last_person}' can trace back to in the context?"
# Returning the story, answer, and retrieval question # Returning the story, answer, and retrieval question
return { return {
'needles': family_story_fragments, 'needles': family_story_fragments,
@ -65,7 +74,6 @@ def get_random_needles(counter, file_path, num_needles, language):
} }
@LOAD_DATASET.register_module() @LOAD_DATASET.register_module()
class NeedleBenchMultiDataset(BaseDataset): class NeedleBenchMultiDataset(BaseDataset):
@ -216,8 +224,9 @@ The content of the long document is as follows
''' '''
else: else:
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. ' raise ValueError(
'Position must be "End" or "Start".') f'Unsupported quesiton_position {quesiton_position}. '
'Position must be "End" or "Start".')
else: else:
raise ValueError(f"Language '{language}' is not supported.") raise ValueError(f"Language '{language}' is not supported.")
@ -225,7 +234,7 @@ The content of the long document is as follows
repo_id = 'opencompass/NeedleBench' repo_id = 'opencompass/NeedleBench'
file_names = [ file_names = [
'PaulGrahamEssays.jsonl','names.json', 'zh_finance.jsonl', 'PaulGrahamEssays.jsonl', 'names.json', 'zh_finance.jsonl',
'zh_game.jsonl', 'zh_general.jsonl', 'zh_government.jsonl', 'zh_game.jsonl', 'zh_general.jsonl', 'zh_government.jsonl',
'zh_movie.jsonl', 'zh_tech.jsonl' 'zh_movie.jsonl', 'zh_tech.jsonl'
] ]
@ -250,7 +259,7 @@ The content of the long document is as follows
random.seed(counter) random.seed(counter)
random.shuffle(lines) random.shuffle(lines)
random_needle_data = get_random_needles( random_needle_data = get_random_needles(
counter, needle_file_path, num_needles+1, language) counter, needle_file_path, num_needles + 1, language)
last_person = random_needle_data['last_person'] last_person = random_needle_data['last_person']
needles = [ needles = [
'\n' + needle + '\n' '\n' + needle + '\n'
@ -278,7 +287,8 @@ The content of the long document is as follows
needles) needles)
processed_prompt = _generate_prompt(processed_text, processed_prompt = _generate_prompt(processed_text,
retrieval_question, last_person) retrieval_question,
last_person)
data['prompt'].append(processed_prompt) data['prompt'].append(processed_prompt)
data['answer'].append(keyword) data['answer'].append(keyword)
@ -287,4 +297,4 @@ The content of the long document is as follows
'prompt': data['prompt'], 'prompt': data['prompt'],
'answer': data['answer'], 'answer': data['answer'],
}) })
return dataset return dataset

View File

@ -114,8 +114,9 @@ The content of the long document is as follows
''' '''
else: else:
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. ' raise ValueError(
'Position must be "End" or "Start".') f'Unsupported quesiton_position {quesiton_position}. '
'Position must be "End" or "Start".')
else: else:
raise ValueError(f"Language '{language}' is not supported.") raise ValueError(f"Language '{language}' is not supported.")
@ -201,11 +202,7 @@ class NeedleBenchOriginEvaluator(BaseEvaluator):
else: else:
score = 0 score = 0
detail = { detail = {'pred': prediction, 'answer': reference, 'score': score}
'pred': prediction,
'answer': reference,
'score': score
}
total_score += score total_score += score
details.append(detail) details.append(detail)

View File

@ -158,8 +158,9 @@ class NeedleBenchParallelDataset(BaseDataset):
''' '''
else: else:
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. ' raise ValueError(
'Position must be "End" or "Start".') f'Unsupported quesiton_position {quesiton_position}. '
'Position must be "End" or "Start".')
elif language == 'English': elif language == 'English':
if quesiton_position == 'End': if quesiton_position == 'End':
prompt = f'''This is a test of long-text capability. You need to first read the long document below, and then answer the final questions one by one based on the information in the document. prompt = f'''This is a test of long-text capability. You need to first read the long document below, and then answer the final questions one by one based on the information in the document.
@ -183,8 +184,9 @@ The content of the long document is as follows
''' '''
else: else:
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. ' raise ValueError(
'Position must be "End" or "Start".') f'Unsupported quesiton_position {quesiton_position}. '
'Position must be "End" or "Start".')
else: else:
raise ValueError(f"Language '{language}' is not supported.") raise ValueError(f"Language '{language}' is not supported.")
@ -269,6 +271,7 @@ The content of the long document is as follows
class NeedleBenchParallelEvaluator(BaseEvaluator): class NeedleBenchParallelEvaluator(BaseEvaluator):
def score(self, predictions, gold): def score(self, predictions, gold):
if len(predictions) != len(gold): if len(predictions) != len(gold):
return {'error': 'predictions and gold have different lengths'} return {'error': 'predictions and gold have different lengths'}