This commit is contained in:
Mor-Li 2025-04-25 17:51:03 +08:00
parent 081c185b8f
commit 5b1a5fa596
6 changed files with 145 additions and 99 deletions

View File

@ -4,11 +4,14 @@ import os
import random
import re
from enum import Enum
from datasets import Dataset
from opencompass.datasets.base import BaseDataset
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.datasets.needlebench.atc_elder_only import clean_atc_answer, needlebench_atc_postprocess_v2, NeedleBenchATCEvaluator
from opencompass.datasets.needlebench.atc_elder_only import (
NeedleBenchATCEvaluator, clean_atc_answer, needlebench_atc_postprocess_v2)
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
TEXT_POSTPROCESSORS)
from opencompass.utils import get_data_path
@ -19,6 +22,7 @@ class QuestionType(Enum):
NTH_DESCENDANT = 2 # N级子节点
RELATIONSHIP_DISTANCE = 3 # 关系距离
# 定义关系术语的代数映射(一代关系还是两代关系)
relationship_generation_map_zh = {
'父亲': 1,
@ -92,7 +96,7 @@ relationship_templates_en = [
"but also {B}'s guardian."),
('For {B}, {A} is not just a {relationship}, '
'but also a friend.'),
"For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.",
'For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.',
]
# Eldest ancestor problem template
@ -249,6 +253,7 @@ Now, the scrambled family relationships are provided below:
Given the scrambled family relationships described above, what is the relationship distance between '{person_a}' and '{person_b}'?
"""
@LOAD_DATASET.register_module()
class NeedleBenchATCDataset(BaseDataset):
@ -293,7 +298,7 @@ class NeedleBenchATCDataset(BaseDataset):
# Randomly select the specified number of names from all names
# The number of names is num_needles + 1
names = random.sample(all_names, num_needles+1)
names = random.sample(all_names, num_needles + 1)
# Select the corresponding relationship terms and templates according to the language
if language == 'Chinese':
@ -305,10 +310,13 @@ class NeedleBenchATCDataset(BaseDataset):
relationship_templates = relationship_templates_en
relationship_map = relationship_generation_map_en
else:
raise ValueError('Unsupported language specified. '
raise ValueError(
'Unsupported language specified. '
'Please choose either "Chinese" or "English".')
def generate_chain_family_story(names, templates, relationship_terms, relationship_map):
def generate_chain_family_story(names, templates,
relationship_terms,
relationship_map):
story = ''
relationships = []
total_generations = 0 # Track the total generational difference
@ -322,20 +330,25 @@ class NeedleBenchATCDataset(BaseDataset):
story += f'{relation}*'
# Get the generation difference for this relationship
gen_diff = relationship_map.get(relation_term, 1) # Default to 1 generation
gen_diff = relationship_map.get(
relation_term, 1) # Default to 1 generation
total_generations += gen_diff
# Record relationship information for later use
relationships.append((names[i], names[i + 1], relation_term, gen_diff))
relationships.append(
(names[i], names[i + 1], relation_term, gen_diff))
return story, relationships, total_generations
chain_story, relationships, total_generations = generate_chain_family_story(
names, relationship_templates, relationship_terms, relationship_map)
names, relationship_templates, relationship_terms,
relationship_map)
# Split the chain_story into a list of fragments
family_story_fragments = chain_story.split('*')
family_story_fragments = [f for f in family_story_fragments if f]
family_story_fragments = [
f for f in family_story_fragments if f
]
# Shuffle the list of fragments
random.shuffle(family_story_fragments)
@ -348,15 +361,19 @@ class NeedleBenchATCDataset(BaseDataset):
last_person = names[-1]
if language == 'Chinese':
prompt = shuffled_story_with_prompt_zh_CN.format(
shuffled_story=shuffled_story, last_person=last_person)
shuffled_story=shuffled_story,
last_person=last_person)
else:
prompt = shuffled_story_with_prompt_en.format(
shuffled_story=shuffled_story, last_person=last_person)
answer = names[0] # The first person is the eldest ancestor
shuffled_story=shuffled_story,
last_person=last_person)
answer = names[
0] # The first person is the eldest ancestor
elif question_type == QuestionType.NTH_ANCESTOR:
# Nth ancestor question - trace from the youngest person to the oldest
person = names[-1] # The youngest person (end of the chain)
person = names[
-1] # The youngest person (end of the chain)
n = total_generations # Use the calculated total generational difference
if language == 'Chinese':
prompt = nth_ancestor_prompt_zh_CN.format(
@ -364,7 +381,8 @@ class NeedleBenchATCDataset(BaseDataset):
else:
prompt = nth_ancestor_prompt_en.format(
shuffled_story=shuffled_story, person=person, n=n)
answer = names[0] # The oldest person (start of the chain) is the nth ancestor
answer = names[
0] # The oldest person (start of the chain) is the nth ancestor
elif question_type == QuestionType.NTH_DESCENDANT:
# Nth descendant question - trace from the oldest person to the youngest
@ -376,7 +394,8 @@ class NeedleBenchATCDataset(BaseDataset):
else:
prompt = nth_descendant_prompt_en.format(
shuffled_story=shuffled_story, person=person, n=n)
answer = names[-1] # The youngest person (end of the chain) is the nth descendant
answer = names[
-1] # The youngest person (end of the chain) is the nth descendant
elif question_type == QuestionType.RELATIONSHIP_DISTANCE:
# Relationship distance question - calculate the relationship distance between the two ends of the chain
@ -384,10 +403,14 @@ class NeedleBenchATCDataset(BaseDataset):
person_b = names[-1] # The youngest person
if language == 'Chinese':
prompt = relationship_distance_prompt_zh_CN.format(
shuffled_story=shuffled_story, person_a=person_a, person_b=person_b)
shuffled_story=shuffled_story,
person_a=person_a,
person_b=person_b)
else:
prompt = relationship_distance_prompt_en.format(
shuffled_story=shuffled_story, person_a=person_a, person_b=person_b)
shuffled_story=shuffled_story,
person_a=person_a,
person_b=person_b)
# Use the calculated total generations as the relationship distance
answer = str(total_generations)
@ -396,11 +419,14 @@ class NeedleBenchATCDataset(BaseDataset):
last_person = names[-1]
if language == 'Chinese':
prompt = shuffled_story_with_prompt_zh_CN.format(
shuffled_story=shuffled_story, last_person=last_person)
shuffled_story=shuffled_story,
last_person=last_person)
else:
prompt = shuffled_story_with_prompt_en.format(
shuffled_story=shuffled_story, last_person=last_person)
answer = names[0] # The first person is the eldest ancestor
shuffled_story=shuffled_story,
last_person=last_person)
answer = names[
0] # The first person is the eldest ancestor
data['prompt'].append(prompt)
data['answer'].append(answer)

View File

@ -4,14 +4,16 @@ import json
import os
import random
from datasets import Dataset
import numpy as np
from datasets import Dataset
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path
from ..base import BaseDataset
from .atc import relationship_terms_zh_CN, relationship_templates_zh_CN, relationship_terms_en, relationship_templates_en
from .atc import (relationship_templates_en, relationship_templates_zh_CN,
relationship_terms_en, relationship_terms_zh_CN)
def get_number(options):
result_string = ''
@ -175,8 +177,11 @@ Example 3: If Xiao Ming is Zhang Hong's great-granddaughter, Zhang Hong's grandm
num_samples = 3
if len(names) > 1:
indices = np.linspace(1, len(names) - 1, num_samples, dtype=int) # Generate evenly spaced indices
sampled_names = [names[i] for i in indices] # Select corresponding elements
indices = np.linspace(
1, len(names) - 1, num_samples,
dtype=int) # Generate evenly spaced indices
sampled_names = [names[i] for i in indices
] # Select corresponding elements
entry['options'] = names[:1] + sampled_names
else:
entry['options'] = names # Return directly if only one element

View File

@ -3,14 +3,15 @@ import json
import os
import random
import re
from datasets import Dataset
from opencompass.datasets.base import BaseDataset
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.utils import get_data_path
from opencompass.datasets.math import extract_boxed_answer
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
TEXT_POSTPROCESSORS)
from opencompass.utils import get_data_path
relationship_templates_zh_CN = [
'{A}{B}{relationship}',
@ -57,7 +58,7 @@ relationship_templates_en = [
'{B} is the child of {A}.',
('For {B}, {A} is not just a {relationship}, '
'but also a friend.'),
"For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.",
'For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.',
]
shuffled_story_with_prompt_zh_CN = """下面是对你的多步推理能力的测试,这个测试叫做祖先追溯测试,我们会模拟不同人的家庭亲属关系,你的任务是在其中不断推理,直到找到最年长的祖先。
@ -163,9 +164,11 @@ class NeedleBenchATCDataset(BaseDataset):
# Generating the prompt based on the language
if language == 'Chinese':
shuffled_story_with_prompt = shuffled_story_with_prompt_zh_CN.format(shuffled_story=shuffled_story, last_person=last_person)
shuffled_story_with_prompt = shuffled_story_with_prompt_zh_CN.format(
shuffled_story=shuffled_story, last_person=last_person)
elif language == 'English':
shuffled_story_with_prompt = shuffled_story_with_prompt_en.format(shuffled_story=shuffled_story, last_person=last_person)
shuffled_story_with_prompt = shuffled_story_with_prompt_en.format(
shuffled_story=shuffled_story, last_person=last_person)
else:
prompt = 'Language not supported.'
raise Exception('Unsupported language specified. '
@ -182,7 +185,7 @@ class NeedleBenchATCDataset(BaseDataset):
def clean_atc_answer(text: str) -> str:
"""Clean answer format specifically for QwQ-32B-Preview model
"""Clean answer format specifically for QwQ-32B-Preview model.
Args:
text: Raw prediction text
@ -190,8 +193,8 @@ def clean_atc_answer(text: str) -> str:
Returns:
Standardized name format after cleaning
"""
if not text or text == "None":
return "None"
if not text or text == 'None':
return 'None'
# Remove LaTeX commands but keep content
text = re.sub(r'\\text\{([^}]+)\}', r'\1', text)
@ -211,6 +214,7 @@ def clean_atc_answer(text: str) -> str:
return text
@TEXT_POSTPROCESSORS.register_module('needlebench_atc_postprocess_v2')
def needlebench_atc_postprocess_v2(text: str) -> str:
@ -218,10 +222,10 @@ def needlebench_atc_postprocess_v2(text: str) -> str:
if cand_ans:
return clean_atc_answer(cand_ans)
return "None"
return 'None'
@ICL_EVALUATORS.register_module("needlebench_atc_evaluator")
@ICL_EVALUATORS.register_module('needlebench_atc_evaluator')
class NeedleBenchATCEvaluator(BaseEvaluator):
def score(self, predictions, gold):
@ -243,6 +247,7 @@ class NeedleBenchATCEvaluator(BaseEvaluator):
}
details.append(detail)
accuracy = (correct_count / len(predictions)) * 100 if predictions else 0
accuracy = (correct_count /
len(predictions)) * 100 if predictions else 0
result = {'score': accuracy, 'details': details}
return result

View File

@ -7,8 +7,12 @@ from datasets import Dataset
from huggingface_hub import hf_hub_download
from opencompass.datasets.base import BaseDataset
from opencompass.datasets.needlebench.atc import (relationship_templates_en,
relationship_templates_zh_CN,
relationship_terms_en,
relationship_terms_zh_CN)
from opencompass.registry import LOAD_DATASET
from opencompass.datasets.needlebench.atc import relationship_templates_zh_CN, relationship_terms_zh_CN, relationship_templates_en, relationship_terms_en
def get_random_needles(counter, file_path, num_needles, language):
with open(file_path, 'r', encoding='utf-8') as file:
@ -33,17 +37,22 @@ def get_random_needles(counter, file_path, num_needles, language):
for i in range(len(names) - 1):
template = random.choice(templates)
relation_term = random.choice(relationship_terms)
relation = template.format(A=names[i], B=names[i + 1], relationship=relation_term)
relation = template.format(A=names[i],
B=names[i + 1],
relationship=relation_term)
story += f'{relation}*'
return story
chain_story = generate_chain_family_story(names, relationship_templates, relationship_terms)
chain_story = generate_chain_family_story(names, relationship_templates,
relationship_terms)
# Splitting the chain_story into a list of fragments
family_story_fragments = chain_story.split('*')
# Removing the empty string from the list
family_story_fragments = [fragment for fragment in family_story_fragments if fragment]
family_story_fragments = [
fragment for fragment in family_story_fragments if fragment
]
# Shuffling the list of fragments
random.shuffle(family_story_fragments)
@ -65,7 +74,6 @@ def get_random_needles(counter, file_path, num_needles, language):
}
@LOAD_DATASET.register_module()
class NeedleBenchMultiDataset(BaseDataset):
@ -216,7 +224,8 @@ The content of the long document is as follows
'''
else:
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. '
raise ValueError(
f'Unsupported quesiton_position {quesiton_position}. '
'Position must be "End" or "Start".')
else:
raise ValueError(f"Language '{language}' is not supported.")
@ -225,7 +234,7 @@ The content of the long document is as follows
repo_id = 'opencompass/NeedleBench'
file_names = [
'PaulGrahamEssays.jsonl','names.json', 'zh_finance.jsonl',
'PaulGrahamEssays.jsonl', 'names.json', 'zh_finance.jsonl',
'zh_game.jsonl', 'zh_general.jsonl', 'zh_government.jsonl',
'zh_movie.jsonl', 'zh_tech.jsonl'
]
@ -250,7 +259,7 @@ The content of the long document is as follows
random.seed(counter)
random.shuffle(lines)
random_needle_data = get_random_needles(
counter, needle_file_path, num_needles+1, language)
counter, needle_file_path, num_needles + 1, language)
last_person = random_needle_data['last_person']
needles = [
'\n' + needle + '\n'
@ -278,7 +287,8 @@ The content of the long document is as follows
needles)
processed_prompt = _generate_prompt(processed_text,
retrieval_question, last_person)
retrieval_question,
last_person)
data['prompt'].append(processed_prompt)
data['answer'].append(keyword)

View File

@ -114,7 +114,8 @@ The content of the long document is as follows
'''
else:
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. '
raise ValueError(
f'Unsupported quesiton_position {quesiton_position}. '
'Position must be "End" or "Start".')
else:
raise ValueError(f"Language '{language}' is not supported.")
@ -201,11 +202,7 @@ class NeedleBenchOriginEvaluator(BaseEvaluator):
else:
score = 0
detail = {
'pred': prediction,
'answer': reference,
'score': score
}
detail = {'pred': prediction, 'answer': reference, 'score': score}
total_score += score
details.append(detail)

View File

@ -158,7 +158,8 @@ class NeedleBenchParallelDataset(BaseDataset):
'''
else:
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. '
raise ValueError(
f'Unsupported quesiton_position {quesiton_position}. '
'Position must be "End" or "Start".')
elif language == 'English':
if quesiton_position == 'End':
@ -183,7 +184,8 @@ The content of the long document is as follows
'''
else:
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. '
raise ValueError(
f'Unsupported quesiton_position {quesiton_position}. '
'Position must be "End" or "Start".')
else:
raise ValueError(f"Language '{language}' is not supported.")
@ -269,6 +271,7 @@ The content of the long document is as follows
class NeedleBenchParallelEvaluator(BaseEvaluator):
def score(self, predictions, gold):
if len(predictions) != len(gold):
return {'error': 'predictions and gold have different lengths'}