mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
fix lint
This commit is contained in:
parent
081c185b8f
commit
5b1a5fa596
@ -4,20 +4,24 @@ import os
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
|
||||||
from opencompass.datasets.base import BaseDataset
|
from opencompass.datasets.base import BaseDataset
|
||||||
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS
|
from opencompass.datasets.needlebench.atc_elder_only import (
|
||||||
from opencompass.datasets.needlebench.atc_elder_only import clean_atc_answer, needlebench_atc_postprocess_v2, NeedleBenchATCEvaluator
|
NeedleBenchATCEvaluator, clean_atc_answer, needlebench_atc_postprocess_v2)
|
||||||
|
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
|
||||||
|
TEXT_POSTPROCESSORS)
|
||||||
from opencompass.utils import get_data_path
|
from opencompass.utils import get_data_path
|
||||||
|
|
||||||
|
|
||||||
# 定义问题类型枚举
|
# 定义问题类型枚举
|
||||||
class QuestionType(Enum):
|
class QuestionType(Enum):
|
||||||
ELDEST_ANCESTOR = 0 # 最年长祖先
|
ELDEST_ANCESTOR = 0 # 最年长祖先
|
||||||
NTH_ANCESTOR = 1 # N级祖先
|
NTH_ANCESTOR = 1 # N级祖先
|
||||||
NTH_DESCENDANT = 2 # N级子节点
|
NTH_DESCENDANT = 2 # N级子节点
|
||||||
RELATIONSHIP_DISTANCE = 3 # 关系距离
|
RELATIONSHIP_DISTANCE = 3 # 关系距离
|
||||||
|
|
||||||
|
|
||||||
# 定义关系术语的代数映射(一代关系还是两代关系)
|
# 定义关系术语的代数映射(一代关系还是两代关系)
|
||||||
relationship_generation_map_zh = {
|
relationship_generation_map_zh = {
|
||||||
@ -92,7 +96,7 @@ relationship_templates_en = [
|
|||||||
"but also {B}'s guardian."),
|
"but also {B}'s guardian."),
|
||||||
('For {B}, {A} is not just a {relationship}, '
|
('For {B}, {A} is not just a {relationship}, '
|
||||||
'but also a friend.'),
|
'but also a friend.'),
|
||||||
"For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.",
|
'For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.',
|
||||||
]
|
]
|
||||||
|
|
||||||
# Eldest ancestor problem template
|
# Eldest ancestor problem template
|
||||||
@ -249,23 +253,24 @@ Now, the scrambled family relationships are provided below:
|
|||||||
Given the scrambled family relationships described above, what is the relationship distance between '{person_a}' and '{person_b}'?
|
Given the scrambled family relationships described above, what is the relationship distance between '{person_a}' and '{person_b}'?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@LOAD_DATASET.register_module()
|
@LOAD_DATASET.register_module()
|
||||||
class NeedleBenchATCDataset(BaseDataset):
|
class NeedleBenchATCDataset(BaseDataset):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(
|
def load(
|
||||||
path,
|
path,
|
||||||
file_name: str,
|
file_name: str,
|
||||||
num_needles: int,
|
num_needles: int,
|
||||||
language: str,
|
language: str,
|
||||||
repeats: int,
|
repeats: int,
|
||||||
# This parameter cannot be passed through mmengine because it is blocked as lazy
|
# This parameter cannot be passed through mmengine because it is blocked as lazy
|
||||||
question_types: list[QuestionType] = [
|
question_types: list[QuestionType] = [
|
||||||
QuestionType.ELDEST_ANCESTOR,
|
QuestionType.ELDEST_ANCESTOR,
|
||||||
QuestionType.NTH_ANCESTOR,
|
QuestionType.NTH_ANCESTOR,
|
||||||
QuestionType.NTH_DESCENDANT,
|
QuestionType.NTH_DESCENDANT,
|
||||||
QuestionType.RELATIONSHIP_DISTANCE,
|
QuestionType.RELATIONSHIP_DISTANCE,
|
||||||
], # Support specifying a list of question types
|
], # Support specifying a list of question types
|
||||||
):
|
):
|
||||||
data = {'prompt': [], 'answer': [], 'question_type': []}
|
data = {'prompt': [], 'answer': [], 'question_type': []}
|
||||||
path = get_data_path(path)
|
path = get_data_path(path)
|
||||||
@ -282,7 +287,7 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
# Ensure question_types is not empty
|
# Ensure question_types is not empty
|
||||||
if not question_types:
|
if not question_types:
|
||||||
raise ValueError('question_types cannot be empty')
|
raise ValueError('question_types cannot be empty')
|
||||||
|
|
||||||
for question_type in question_types:
|
for question_type in question_types:
|
||||||
# Generate the specified number of examples for each question type
|
# Generate the specified number of examples for each question type
|
||||||
for i in range(repeats):
|
for i in range(repeats):
|
||||||
@ -290,11 +295,11 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
# Use the enum value of the question type multiplied by 10000 as the base to ensure non-overlapping seed ranges
|
# Use the enum value of the question type multiplied by 10000 as the base to ensure non-overlapping seed ranges
|
||||||
seed = (i + 1) + (10000 * question_type.value)
|
seed = (i + 1) + (10000 * question_type.value)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
# Randomly select the specified number of names from all names
|
# Randomly select the specified number of names from all names
|
||||||
# The number of names is num_needles + 1
|
# The number of names is num_needles + 1
|
||||||
names = random.sample(all_names, num_needles+1)
|
names = random.sample(all_names, num_needles + 1)
|
||||||
|
|
||||||
# Select the corresponding relationship terms and templates according to the language
|
# Select the corresponding relationship terms and templates according to the language
|
||||||
if language == 'Chinese':
|
if language == 'Chinese':
|
||||||
relationship_terms = relationship_terms_zh_CN
|
relationship_terms = relationship_terms_zh_CN
|
||||||
@ -305,10 +310,13 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
relationship_templates = relationship_templates_en
|
relationship_templates = relationship_templates_en
|
||||||
relationship_map = relationship_generation_map_en
|
relationship_map = relationship_generation_map_en
|
||||||
else:
|
else:
|
||||||
raise ValueError('Unsupported language specified. '
|
raise ValueError(
|
||||||
'Please choose either "Chinese" or "English".')
|
'Unsupported language specified. '
|
||||||
|
'Please choose either "Chinese" or "English".')
|
||||||
|
|
||||||
def generate_chain_family_story(names, templates, relationship_terms, relationship_map):
|
def generate_chain_family_story(names, templates,
|
||||||
|
relationship_terms,
|
||||||
|
relationship_map):
|
||||||
story = ''
|
story = ''
|
||||||
relationships = []
|
relationships = []
|
||||||
total_generations = 0 # Track the total generational difference
|
total_generations = 0 # Track the total generational difference
|
||||||
@ -317,25 +325,30 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
template = random.choice(templates)
|
template = random.choice(templates)
|
||||||
relation_term = random.choice(relationship_terms)
|
relation_term = random.choice(relationship_terms)
|
||||||
relation = template.format(A=names[i],
|
relation = template.format(A=names[i],
|
||||||
B=names[i + 1],
|
B=names[i + 1],
|
||||||
relationship=relation_term)
|
relationship=relation_term)
|
||||||
story += f'{relation}*'
|
story += f'{relation}*'
|
||||||
|
|
||||||
# Get the generation difference for this relationship
|
# Get the generation difference for this relationship
|
||||||
gen_diff = relationship_map.get(relation_term, 1) # Default to 1 generation
|
gen_diff = relationship_map.get(
|
||||||
|
relation_term, 1) # Default to 1 generation
|
||||||
total_generations += gen_diff
|
total_generations += gen_diff
|
||||||
|
|
||||||
# Record relationship information for later use
|
# Record relationship information for later use
|
||||||
relationships.append((names[i], names[i + 1], relation_term, gen_diff))
|
relationships.append(
|
||||||
|
(names[i], names[i + 1], relation_term, gen_diff))
|
||||||
|
|
||||||
return story, relationships, total_generations
|
return story, relationships, total_generations
|
||||||
|
|
||||||
chain_story, relationships, total_generations = generate_chain_family_story(
|
chain_story, relationships, total_generations = generate_chain_family_story(
|
||||||
names, relationship_templates, relationship_terms, relationship_map)
|
names, relationship_templates, relationship_terms,
|
||||||
|
relationship_map)
|
||||||
|
|
||||||
# Split the chain_story into a list of fragments
|
# Split the chain_story into a list of fragments
|
||||||
family_story_fragments = chain_story.split('*')
|
family_story_fragments = chain_story.split('*')
|
||||||
family_story_fragments = [f for f in family_story_fragments if f]
|
family_story_fragments = [
|
||||||
|
f for f in family_story_fragments if f
|
||||||
|
]
|
||||||
|
|
||||||
# Shuffle the list of fragments
|
# Shuffle the list of fragments
|
||||||
random.shuffle(family_story_fragments)
|
random.shuffle(family_story_fragments)
|
||||||
@ -348,15 +361,19 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
last_person = names[-1]
|
last_person = names[-1]
|
||||||
if language == 'Chinese':
|
if language == 'Chinese':
|
||||||
prompt = shuffled_story_with_prompt_zh_CN.format(
|
prompt = shuffled_story_with_prompt_zh_CN.format(
|
||||||
shuffled_story=shuffled_story, last_person=last_person)
|
shuffled_story=shuffled_story,
|
||||||
|
last_person=last_person)
|
||||||
else:
|
else:
|
||||||
prompt = shuffled_story_with_prompt_en.format(
|
prompt = shuffled_story_with_prompt_en.format(
|
||||||
shuffled_story=shuffled_story, last_person=last_person)
|
shuffled_story=shuffled_story,
|
||||||
answer = names[0] # The first person is the eldest ancestor
|
last_person=last_person)
|
||||||
|
answer = names[
|
||||||
|
0] # The first person is the eldest ancestor
|
||||||
|
|
||||||
elif question_type == QuestionType.NTH_ANCESTOR:
|
elif question_type == QuestionType.NTH_ANCESTOR:
|
||||||
# Nth ancestor question - trace from the youngest person to the oldest
|
# Nth ancestor question - trace from the youngest person to the oldest
|
||||||
person = names[-1] # The youngest person (end of the chain)
|
person = names[
|
||||||
|
-1] # The youngest person (end of the chain)
|
||||||
n = total_generations # Use the calculated total generational difference
|
n = total_generations # Use the calculated total generational difference
|
||||||
if language == 'Chinese':
|
if language == 'Chinese':
|
||||||
prompt = nth_ancestor_prompt_zh_CN.format(
|
prompt = nth_ancestor_prompt_zh_CN.format(
|
||||||
@ -364,7 +381,8 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
else:
|
else:
|
||||||
prompt = nth_ancestor_prompt_en.format(
|
prompt = nth_ancestor_prompt_en.format(
|
||||||
shuffled_story=shuffled_story, person=person, n=n)
|
shuffled_story=shuffled_story, person=person, n=n)
|
||||||
answer = names[0] # The oldest person (start of the chain) is the nth ancestor
|
answer = names[
|
||||||
|
0] # The oldest person (start of the chain) is the nth ancestor
|
||||||
|
|
||||||
elif question_type == QuestionType.NTH_DESCENDANT:
|
elif question_type == QuestionType.NTH_DESCENDANT:
|
||||||
# Nth descendant question - trace from the oldest person to the youngest
|
# Nth descendant question - trace from the oldest person to the youngest
|
||||||
@ -376,7 +394,8 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
else:
|
else:
|
||||||
prompt = nth_descendant_prompt_en.format(
|
prompt = nth_descendant_prompt_en.format(
|
||||||
shuffled_story=shuffled_story, person=person, n=n)
|
shuffled_story=shuffled_story, person=person, n=n)
|
||||||
answer = names[-1] # The youngest person (end of the chain) is the nth descendant
|
answer = names[
|
||||||
|
-1] # The youngest person (end of the chain) is the nth descendant
|
||||||
|
|
||||||
elif question_type == QuestionType.RELATIONSHIP_DISTANCE:
|
elif question_type == QuestionType.RELATIONSHIP_DISTANCE:
|
||||||
# Relationship distance question - calculate the relationship distance between the two ends of the chain
|
# Relationship distance question - calculate the relationship distance between the two ends of the chain
|
||||||
@ -384,10 +403,14 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
person_b = names[-1] # The youngest person
|
person_b = names[-1] # The youngest person
|
||||||
if language == 'Chinese':
|
if language == 'Chinese':
|
||||||
prompt = relationship_distance_prompt_zh_CN.format(
|
prompt = relationship_distance_prompt_zh_CN.format(
|
||||||
shuffled_story=shuffled_story, person_a=person_a, person_b=person_b)
|
shuffled_story=shuffled_story,
|
||||||
|
person_a=person_a,
|
||||||
|
person_b=person_b)
|
||||||
else:
|
else:
|
||||||
prompt = relationship_distance_prompt_en.format(
|
prompt = relationship_distance_prompt_en.format(
|
||||||
shuffled_story=shuffled_story, person_a=person_a, person_b=person_b)
|
shuffled_story=shuffled_story,
|
||||||
|
person_a=person_a,
|
||||||
|
person_b=person_b)
|
||||||
# Use the calculated total generations as the relationship distance
|
# Use the calculated total generations as the relationship distance
|
||||||
answer = str(total_generations)
|
answer = str(total_generations)
|
||||||
|
|
||||||
@ -396,11 +419,14 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
last_person = names[-1]
|
last_person = names[-1]
|
||||||
if language == 'Chinese':
|
if language == 'Chinese':
|
||||||
prompt = shuffled_story_with_prompt_zh_CN.format(
|
prompt = shuffled_story_with_prompt_zh_CN.format(
|
||||||
shuffled_story=shuffled_story, last_person=last_person)
|
shuffled_story=shuffled_story,
|
||||||
|
last_person=last_person)
|
||||||
else:
|
else:
|
||||||
prompt = shuffled_story_with_prompt_en.format(
|
prompt = shuffled_story_with_prompt_en.format(
|
||||||
shuffled_story=shuffled_story, last_person=last_person)
|
shuffled_story=shuffled_story,
|
||||||
answer = names[0] # The first person is the eldest ancestor
|
last_person=last_person)
|
||||||
|
answer = names[
|
||||||
|
0] # The first person is the eldest ancestor
|
||||||
|
|
||||||
data['prompt'].append(prompt)
|
data['prompt'].append(prompt)
|
||||||
data['answer'].append(answer)
|
data['answer'].append(answer)
|
||||||
@ -411,4 +437,4 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
'answer': data['answer'],
|
'answer': data['answer'],
|
||||||
'question_type': data['question_type'],
|
'question_type': data['question_type'],
|
||||||
})
|
})
|
||||||
return dataset
|
return dataset
|
||||||
|
@ -4,14 +4,16 @@ import json
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from datasets import Dataset
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
from opencompass.registry import LOAD_DATASET
|
from opencompass.registry import LOAD_DATASET
|
||||||
from opencompass.utils import get_data_path
|
from opencompass.utils import get_data_path
|
||||||
|
|
||||||
from ..base import BaseDataset
|
from ..base import BaseDataset
|
||||||
from .atc import relationship_terms_zh_CN, relationship_templates_zh_CN, relationship_terms_en, relationship_templates_en
|
from .atc import (relationship_templates_en, relationship_templates_zh_CN,
|
||||||
|
relationship_terms_en, relationship_terms_zh_CN)
|
||||||
|
|
||||||
|
|
||||||
def get_number(options):
|
def get_number(options):
|
||||||
result_string = ''
|
result_string = ''
|
||||||
@ -173,10 +175,13 @@ Example 3: If Xiao Ming is Zhang Hong's great-granddaughter, Zhang Hong's grandm
|
|||||||
)
|
)
|
||||||
names.extend(additional_names)
|
names.extend(additional_names)
|
||||||
|
|
||||||
num_samples = 3
|
num_samples = 3
|
||||||
if len(names) > 1:
|
if len(names) > 1:
|
||||||
indices = np.linspace(1, len(names) - 1, num_samples, dtype=int) # Generate evenly spaced indices
|
indices = np.linspace(
|
||||||
sampled_names = [names[i] for i in indices] # Select corresponding elements
|
1, len(names) - 1, num_samples,
|
||||||
|
dtype=int) # Generate evenly spaced indices
|
||||||
|
sampled_names = [names[i] for i in indices
|
||||||
|
] # Select corresponding elements
|
||||||
entry['options'] = names[:1] + sampled_names
|
entry['options'] = names[:1] + sampled_names
|
||||||
else:
|
else:
|
||||||
entry['options'] = names # Return directly if only one element
|
entry['options'] = names # Return directly if only one element
|
||||||
|
@ -3,14 +3,15 @@ import json
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
|
||||||
from opencompass.datasets.base import BaseDataset
|
from opencompass.datasets.base import BaseDataset
|
||||||
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS
|
|
||||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
|
||||||
|
|
||||||
from opencompass.utils import get_data_path
|
|
||||||
from opencompass.datasets.math import extract_boxed_answer
|
from opencompass.datasets.math import extract_boxed_answer
|
||||||
|
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||||
|
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
|
||||||
|
TEXT_POSTPROCESSORS)
|
||||||
|
from opencompass.utils import get_data_path
|
||||||
|
|
||||||
relationship_templates_zh_CN = [
|
relationship_templates_zh_CN = [
|
||||||
'{A}是{B}的{relationship}。',
|
'{A}是{B}的{relationship}。',
|
||||||
@ -57,7 +58,7 @@ relationship_templates_en = [
|
|||||||
'{B} is the child of {A}.',
|
'{B} is the child of {A}.',
|
||||||
('For {B}, {A} is not just a {relationship}, '
|
('For {B}, {A} is not just a {relationship}, '
|
||||||
'but also a friend.'),
|
'but also a friend.'),
|
||||||
"For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.",
|
'For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.',
|
||||||
]
|
]
|
||||||
|
|
||||||
shuffled_story_with_prompt_zh_CN = """下面是对你的多步推理能力的测试,这个测试叫做祖先追溯测试,我们会模拟不同人的家庭亲属关系,你的任务是在其中不断推理,直到找到最年长的祖先。
|
shuffled_story_with_prompt_zh_CN = """下面是对你的多步推理能力的测试,这个测试叫做祖先追溯测试,我们会模拟不同人的家庭亲属关系,你的任务是在其中不断推理,直到找到最年长的祖先。
|
||||||
@ -125,7 +126,7 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
# 使用固定种子来保持样本稳定性
|
# 使用固定种子来保持样本稳定性
|
||||||
seed = i
|
seed = i
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
names = random.sample(all_names, num_needles)
|
names = random.sample(all_names, num_needles)
|
||||||
if language == 'Chinese':
|
if language == 'Chinese':
|
||||||
relationship_terms = relationship_terms_zh_CN
|
relationship_terms = relationship_terms_zh_CN
|
||||||
@ -163,9 +164,11 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
|
|
||||||
# Generating the prompt based on the language
|
# Generating the prompt based on the language
|
||||||
if language == 'Chinese':
|
if language == 'Chinese':
|
||||||
shuffled_story_with_prompt = shuffled_story_with_prompt_zh_CN.format(shuffled_story=shuffled_story, last_person=last_person)
|
shuffled_story_with_prompt = shuffled_story_with_prompt_zh_CN.format(
|
||||||
|
shuffled_story=shuffled_story, last_person=last_person)
|
||||||
elif language == 'English':
|
elif language == 'English':
|
||||||
shuffled_story_with_prompt = shuffled_story_with_prompt_en.format(shuffled_story=shuffled_story, last_person=last_person)
|
shuffled_story_with_prompt = shuffled_story_with_prompt_en.format(
|
||||||
|
shuffled_story=shuffled_story, last_person=last_person)
|
||||||
else:
|
else:
|
||||||
prompt = 'Language not supported.'
|
prompt = 'Language not supported.'
|
||||||
raise Exception('Unsupported language specified. '
|
raise Exception('Unsupported language specified. '
|
||||||
@ -182,46 +185,47 @@ class NeedleBenchATCDataset(BaseDataset):
|
|||||||
|
|
||||||
|
|
||||||
def clean_atc_answer(text: str) -> str:
|
def clean_atc_answer(text: str) -> str:
|
||||||
"""Clean answer format specifically for QwQ-32B-Preview model
|
"""Clean answer format specifically for QwQ-32B-Preview model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Raw prediction text
|
text: Raw prediction text
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Standardized name format after cleaning
|
Standardized name format after cleaning
|
||||||
"""
|
"""
|
||||||
if not text or text == "None":
|
if not text or text == 'None':
|
||||||
return "None"
|
return 'None'
|
||||||
|
|
||||||
# Remove LaTeX commands but keep content
|
# Remove LaTeX commands but keep content
|
||||||
text = re.sub(r'\\text\{([^}]+)\}', r'\1', text)
|
text = re.sub(r'\\text\{([^}]+)\}', r'\1', text)
|
||||||
text = re.sub(r'\\boxed\{([^}]+)\}', r'\1', text)
|
text = re.sub(r'\\boxed\{([^}]+)\}', r'\1', text)
|
||||||
text = re.sub(r'\\[\[\]]', '', text)
|
text = re.sub(r'\\[\[\]]', '', text)
|
||||||
|
|
||||||
# Remove extra backslashes
|
# Remove extra backslashes
|
||||||
text = text.replace('\\\\', '').replace('\\', '')
|
text = text.replace('\\\\', '').replace('\\', '')
|
||||||
|
|
||||||
# Handle extra spaces
|
# Handle extra spaces
|
||||||
text = re.sub(r'\s+', ' ', text).strip()
|
text = re.sub(r'\s+', ' ', text).strip()
|
||||||
|
|
||||||
# Remove quotes
|
# Remove quotes
|
||||||
text = text.replace('"', '').replace("'", '')
|
text = text.replace('"', '').replace("'", '')
|
||||||
# Remove tildes (波浪符号)
|
# Remove tildes (波浪符号)
|
||||||
text = text.replace('~', ' ')
|
text = text.replace('~', ' ')
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
@TEXT_POSTPROCESSORS.register_module('needlebench_atc_postprocess_v2')
|
@TEXT_POSTPROCESSORS.register_module('needlebench_atc_postprocess_v2')
|
||||||
def needlebench_atc_postprocess_v2(text: str) -> str:
|
def needlebench_atc_postprocess_v2(text: str) -> str:
|
||||||
|
|
||||||
cand_ans = extract_boxed_answer(text, strip_double_curly_brace=True)
|
cand_ans = extract_boxed_answer(text, strip_double_curly_brace=True)
|
||||||
|
|
||||||
if cand_ans:
|
if cand_ans:
|
||||||
return clean_atc_answer(cand_ans)
|
return clean_atc_answer(cand_ans)
|
||||||
return "None"
|
return 'None'
|
||||||
|
|
||||||
|
|
||||||
@ICL_EVALUATORS.register_module("needlebench_atc_evaluator")
|
@ICL_EVALUATORS.register_module('needlebench_atc_evaluator')
|
||||||
class NeedleBenchATCEvaluator(BaseEvaluator):
|
class NeedleBenchATCEvaluator(BaseEvaluator):
|
||||||
|
|
||||||
def score(self, predictions, gold):
|
def score(self, predictions, gold):
|
||||||
@ -230,12 +234,12 @@ class NeedleBenchATCEvaluator(BaseEvaluator):
|
|||||||
|
|
||||||
correct_count = 0
|
correct_count = 0
|
||||||
details = []
|
details = []
|
||||||
|
|
||||||
for prediction, reference in zip(predictions, gold):
|
for prediction, reference in zip(predictions, gold):
|
||||||
reference_name = reference
|
reference_name = reference
|
||||||
if prediction.strip() == reference_name.strip():
|
if prediction.strip() == reference_name.strip():
|
||||||
correct_count += 1
|
correct_count += 1
|
||||||
|
|
||||||
detail = {
|
detail = {
|
||||||
'pred': prediction,
|
'pred': prediction,
|
||||||
'answer': reference_name,
|
'answer': reference_name,
|
||||||
@ -243,6 +247,7 @@ class NeedleBenchATCEvaluator(BaseEvaluator):
|
|||||||
}
|
}
|
||||||
details.append(detail)
|
details.append(detail)
|
||||||
|
|
||||||
accuracy = (correct_count / len(predictions)) * 100 if predictions else 0
|
accuracy = (correct_count /
|
||||||
|
len(predictions)) * 100 if predictions else 0
|
||||||
result = {'score': accuracy, 'details': details}
|
result = {'score': accuracy, 'details': details}
|
||||||
return result
|
return result
|
||||||
|
@ -7,8 +7,12 @@ from datasets import Dataset
|
|||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
from opencompass.datasets.base import BaseDataset
|
from opencompass.datasets.base import BaseDataset
|
||||||
|
from opencompass.datasets.needlebench.atc import (relationship_templates_en,
|
||||||
|
relationship_templates_zh_CN,
|
||||||
|
relationship_terms_en,
|
||||||
|
relationship_terms_zh_CN)
|
||||||
from opencompass.registry import LOAD_DATASET
|
from opencompass.registry import LOAD_DATASET
|
||||||
from opencompass.datasets.needlebench.atc import relationship_templates_zh_CN, relationship_terms_zh_CN, relationship_templates_en, relationship_terms_en
|
|
||||||
|
|
||||||
def get_random_needles(counter, file_path, num_needles, language):
|
def get_random_needles(counter, file_path, num_needles, language):
|
||||||
with open(file_path, 'r', encoding='utf-8') as file:
|
with open(file_path, 'r', encoding='utf-8') as file:
|
||||||
@ -33,18 +37,23 @@ def get_random_needles(counter, file_path, num_needles, language):
|
|||||||
for i in range(len(names) - 1):
|
for i in range(len(names) - 1):
|
||||||
template = random.choice(templates)
|
template = random.choice(templates)
|
||||||
relation_term = random.choice(relationship_terms)
|
relation_term = random.choice(relationship_terms)
|
||||||
relation = template.format(A=names[i], B=names[i + 1], relationship=relation_term)
|
relation = template.format(A=names[i],
|
||||||
|
B=names[i + 1],
|
||||||
|
relationship=relation_term)
|
||||||
story += f'{relation}*'
|
story += f'{relation}*'
|
||||||
return story
|
return story
|
||||||
|
|
||||||
chain_story = generate_chain_family_story(names, relationship_templates, relationship_terms)
|
chain_story = generate_chain_family_story(names, relationship_templates,
|
||||||
|
relationship_terms)
|
||||||
|
|
||||||
# Splitting the chain_story into a list of fragments
|
# Splitting the chain_story into a list of fragments
|
||||||
family_story_fragments = chain_story.split('*')
|
family_story_fragments = chain_story.split('*')
|
||||||
|
|
||||||
# Removing the empty string from the list
|
# Removing the empty string from the list
|
||||||
family_story_fragments = [fragment for fragment in family_story_fragments if fragment]
|
family_story_fragments = [
|
||||||
|
fragment for fragment in family_story_fragments if fragment
|
||||||
|
]
|
||||||
|
|
||||||
# Shuffling the list of fragments
|
# Shuffling the list of fragments
|
||||||
random.shuffle(family_story_fragments)
|
random.shuffle(family_story_fragments)
|
||||||
|
|
||||||
@ -55,7 +64,7 @@ def get_random_needles(counter, file_path, num_needles, language):
|
|||||||
retrieval_question = f"在上面提供的文本中,'{last_person}'的能够向上追溯到的最年长的亲人是谁?"
|
retrieval_question = f"在上面提供的文本中,'{last_person}'的能够向上追溯到的最年长的亲人是谁?"
|
||||||
elif language == 'English':
|
elif language == 'English':
|
||||||
retrieval_question = f"Given the context described above, who is the eldest relative that '{last_person}' can trace back to in the context?"
|
retrieval_question = f"Given the context described above, who is the eldest relative that '{last_person}' can trace back to in the context?"
|
||||||
|
|
||||||
# Returning the story, answer, and retrieval question
|
# Returning the story, answer, and retrieval question
|
||||||
return {
|
return {
|
||||||
'needles': family_story_fragments,
|
'needles': family_story_fragments,
|
||||||
@ -65,7 +74,6 @@ def get_random_needles(counter, file_path, num_needles, language):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@LOAD_DATASET.register_module()
|
@LOAD_DATASET.register_module()
|
||||||
class NeedleBenchMultiDataset(BaseDataset):
|
class NeedleBenchMultiDataset(BaseDataset):
|
||||||
|
|
||||||
@ -216,8 +224,9 @@ The content of the long document is as follows
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. '
|
raise ValueError(
|
||||||
'Position must be "End" or "Start".')
|
f'Unsupported quesiton_position {quesiton_position}. '
|
||||||
|
'Position must be "End" or "Start".')
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Language '{language}' is not supported.")
|
raise ValueError(f"Language '{language}' is not supported.")
|
||||||
|
|
||||||
@ -225,7 +234,7 @@ The content of the long document is as follows
|
|||||||
|
|
||||||
repo_id = 'opencompass/NeedleBench'
|
repo_id = 'opencompass/NeedleBench'
|
||||||
file_names = [
|
file_names = [
|
||||||
'PaulGrahamEssays.jsonl','names.json', 'zh_finance.jsonl',
|
'PaulGrahamEssays.jsonl', 'names.json', 'zh_finance.jsonl',
|
||||||
'zh_game.jsonl', 'zh_general.jsonl', 'zh_government.jsonl',
|
'zh_game.jsonl', 'zh_general.jsonl', 'zh_government.jsonl',
|
||||||
'zh_movie.jsonl', 'zh_tech.jsonl'
|
'zh_movie.jsonl', 'zh_tech.jsonl'
|
||||||
]
|
]
|
||||||
@ -250,7 +259,7 @@ The content of the long document is as follows
|
|||||||
random.seed(counter)
|
random.seed(counter)
|
||||||
random.shuffle(lines)
|
random.shuffle(lines)
|
||||||
random_needle_data = get_random_needles(
|
random_needle_data = get_random_needles(
|
||||||
counter, needle_file_path, num_needles+1, language)
|
counter, needle_file_path, num_needles + 1, language)
|
||||||
last_person = random_needle_data['last_person']
|
last_person = random_needle_data['last_person']
|
||||||
needles = [
|
needles = [
|
||||||
'\n' + needle + '\n'
|
'\n' + needle + '\n'
|
||||||
@ -278,7 +287,8 @@ The content of the long document is as follows
|
|||||||
needles)
|
needles)
|
||||||
|
|
||||||
processed_prompt = _generate_prompt(processed_text,
|
processed_prompt = _generate_prompt(processed_text,
|
||||||
retrieval_question, last_person)
|
retrieval_question,
|
||||||
|
last_person)
|
||||||
|
|
||||||
data['prompt'].append(processed_prompt)
|
data['prompt'].append(processed_prompt)
|
||||||
data['answer'].append(keyword)
|
data['answer'].append(keyword)
|
||||||
@ -287,4 +297,4 @@ The content of the long document is as follows
|
|||||||
'prompt': data['prompt'],
|
'prompt': data['prompt'],
|
||||||
'answer': data['answer'],
|
'answer': data['answer'],
|
||||||
})
|
})
|
||||||
return dataset
|
return dataset
|
||||||
|
@ -114,8 +114,9 @@ The content of the long document is as follows
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. '
|
raise ValueError(
|
||||||
'Position must be "End" or "Start".')
|
f'Unsupported quesiton_position {quesiton_position}. '
|
||||||
|
'Position must be "End" or "Start".')
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Language '{language}' is not supported.")
|
raise ValueError(f"Language '{language}' is not supported.")
|
||||||
|
|
||||||
@ -201,11 +202,7 @@ class NeedleBenchOriginEvaluator(BaseEvaluator):
|
|||||||
else:
|
else:
|
||||||
score = 0
|
score = 0
|
||||||
|
|
||||||
detail = {
|
detail = {'pred': prediction, 'answer': reference, 'score': score}
|
||||||
'pred': prediction,
|
|
||||||
'answer': reference,
|
|
||||||
'score': score
|
|
||||||
}
|
|
||||||
total_score += score
|
total_score += score
|
||||||
details.append(detail)
|
details.append(detail)
|
||||||
|
|
||||||
|
@ -158,8 +158,9 @@ class NeedleBenchParallelDataset(BaseDataset):
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. '
|
raise ValueError(
|
||||||
'Position must be "End" or "Start".')
|
f'Unsupported quesiton_position {quesiton_position}. '
|
||||||
|
'Position must be "End" or "Start".')
|
||||||
elif language == 'English':
|
elif language == 'English':
|
||||||
if quesiton_position == 'End':
|
if quesiton_position == 'End':
|
||||||
prompt = f'''This is a test of long-text capability. You need to first read the long document below, and then answer the final questions one by one based on the information in the document.
|
prompt = f'''This is a test of long-text capability. You need to first read the long document below, and then answer the final questions one by one based on the information in the document.
|
||||||
@ -183,8 +184,9 @@ The content of the long document is as follows
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. '
|
raise ValueError(
|
||||||
'Position must be "End" or "Start".')
|
f'Unsupported quesiton_position {quesiton_position}. '
|
||||||
|
'Position must be "End" or "Start".')
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Language '{language}' is not supported.")
|
raise ValueError(f"Language '{language}' is not supported.")
|
||||||
|
|
||||||
@ -269,6 +271,7 @@ The content of the long document is as follows
|
|||||||
|
|
||||||
|
|
||||||
class NeedleBenchParallelEvaluator(BaseEvaluator):
|
class NeedleBenchParallelEvaluator(BaseEvaluator):
|
||||||
|
|
||||||
def score(self, predictions, gold):
|
def score(self, predictions, gold):
|
||||||
if len(predictions) != len(gold):
|
if len(predictions) != len(gold):
|
||||||
return {'error': 'predictions and gold have different lengths'}
|
return {'error': 'predictions and gold have different lengths'}
|
||||||
|
Loading…
Reference in New Issue
Block a user