mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
fix lint
This commit is contained in:
parent
081c185b8f
commit
5b1a5fa596
@ -4,11 +4,14 @@ import os
|
||||
import random
|
||||
import re
|
||||
from enum import Enum
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from opencompass.datasets.base import BaseDataset
|
||||
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS
|
||||
from opencompass.datasets.needlebench.atc_elder_only import clean_atc_answer, needlebench_atc_postprocess_v2, NeedleBenchATCEvaluator
|
||||
from opencompass.datasets.needlebench.atc_elder_only import (
|
||||
NeedleBenchATCEvaluator, clean_atc_answer, needlebench_atc_postprocess_v2)
|
||||
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
|
||||
TEXT_POSTPROCESSORS)
|
||||
from opencompass.utils import get_data_path
|
||||
|
||||
|
||||
@ -19,6 +22,7 @@ class QuestionType(Enum):
|
||||
NTH_DESCENDANT = 2 # N级子节点
|
||||
RELATIONSHIP_DISTANCE = 3 # 关系距离
|
||||
|
||||
|
||||
# 定义关系术语的代数映射(一代关系还是两代关系)
|
||||
relationship_generation_map_zh = {
|
||||
'父亲': 1,
|
||||
@ -92,7 +96,7 @@ relationship_templates_en = [
|
||||
"but also {B}'s guardian."),
|
||||
('For {B}, {A} is not just a {relationship}, '
|
||||
'but also a friend.'),
|
||||
"For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.",
|
||||
'For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.',
|
||||
]
|
||||
|
||||
# Eldest ancestor problem template
|
||||
@ -249,6 +253,7 @@ Now, the scrambled family relationships are provided below:
|
||||
Given the scrambled family relationships described above, what is the relationship distance between '{person_a}' and '{person_b}'?
|
||||
"""
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class NeedleBenchATCDataset(BaseDataset):
|
||||
|
||||
@ -305,10 +310,13 @@ class NeedleBenchATCDataset(BaseDataset):
|
||||
relationship_templates = relationship_templates_en
|
||||
relationship_map = relationship_generation_map_en
|
||||
else:
|
||||
raise ValueError('Unsupported language specified. '
|
||||
raise ValueError(
|
||||
'Unsupported language specified. '
|
||||
'Please choose either "Chinese" or "English".')
|
||||
|
||||
def generate_chain_family_story(names, templates, relationship_terms, relationship_map):
|
||||
def generate_chain_family_story(names, templates,
|
||||
relationship_terms,
|
||||
relationship_map):
|
||||
story = ''
|
||||
relationships = []
|
||||
total_generations = 0 # Track the total generational difference
|
||||
@ -322,20 +330,25 @@ class NeedleBenchATCDataset(BaseDataset):
|
||||
story += f'{relation}*'
|
||||
|
||||
# Get the generation difference for this relationship
|
||||
gen_diff = relationship_map.get(relation_term, 1) # Default to 1 generation
|
||||
gen_diff = relationship_map.get(
|
||||
relation_term, 1) # Default to 1 generation
|
||||
total_generations += gen_diff
|
||||
|
||||
# Record relationship information for later use
|
||||
relationships.append((names[i], names[i + 1], relation_term, gen_diff))
|
||||
relationships.append(
|
||||
(names[i], names[i + 1], relation_term, gen_diff))
|
||||
|
||||
return story, relationships, total_generations
|
||||
|
||||
chain_story, relationships, total_generations = generate_chain_family_story(
|
||||
names, relationship_templates, relationship_terms, relationship_map)
|
||||
names, relationship_templates, relationship_terms,
|
||||
relationship_map)
|
||||
|
||||
# Split the chain_story into a list of fragments
|
||||
family_story_fragments = chain_story.split('*')
|
||||
family_story_fragments = [f for f in family_story_fragments if f]
|
||||
family_story_fragments = [
|
||||
f for f in family_story_fragments if f
|
||||
]
|
||||
|
||||
# Shuffle the list of fragments
|
||||
random.shuffle(family_story_fragments)
|
||||
@ -348,15 +361,19 @@ class NeedleBenchATCDataset(BaseDataset):
|
||||
last_person = names[-1]
|
||||
if language == 'Chinese':
|
||||
prompt = shuffled_story_with_prompt_zh_CN.format(
|
||||
shuffled_story=shuffled_story, last_person=last_person)
|
||||
shuffled_story=shuffled_story,
|
||||
last_person=last_person)
|
||||
else:
|
||||
prompt = shuffled_story_with_prompt_en.format(
|
||||
shuffled_story=shuffled_story, last_person=last_person)
|
||||
answer = names[0] # The first person is the eldest ancestor
|
||||
shuffled_story=shuffled_story,
|
||||
last_person=last_person)
|
||||
answer = names[
|
||||
0] # The first person is the eldest ancestor
|
||||
|
||||
elif question_type == QuestionType.NTH_ANCESTOR:
|
||||
# Nth ancestor question - trace from the youngest person to the oldest
|
||||
person = names[-1] # The youngest person (end of the chain)
|
||||
person = names[
|
||||
-1] # The youngest person (end of the chain)
|
||||
n = total_generations # Use the calculated total generational difference
|
||||
if language == 'Chinese':
|
||||
prompt = nth_ancestor_prompt_zh_CN.format(
|
||||
@ -364,7 +381,8 @@ class NeedleBenchATCDataset(BaseDataset):
|
||||
else:
|
||||
prompt = nth_ancestor_prompt_en.format(
|
||||
shuffled_story=shuffled_story, person=person, n=n)
|
||||
answer = names[0] # The oldest person (start of the chain) is the nth ancestor
|
||||
answer = names[
|
||||
0] # The oldest person (start of the chain) is the nth ancestor
|
||||
|
||||
elif question_type == QuestionType.NTH_DESCENDANT:
|
||||
# Nth descendant question - trace from the oldest person to the youngest
|
||||
@ -376,7 +394,8 @@ class NeedleBenchATCDataset(BaseDataset):
|
||||
else:
|
||||
prompt = nth_descendant_prompt_en.format(
|
||||
shuffled_story=shuffled_story, person=person, n=n)
|
||||
answer = names[-1] # The youngest person (end of the chain) is the nth descendant
|
||||
answer = names[
|
||||
-1] # The youngest person (end of the chain) is the nth descendant
|
||||
|
||||
elif question_type == QuestionType.RELATIONSHIP_DISTANCE:
|
||||
# Relationship distance question - calculate the relationship distance between the two ends of the chain
|
||||
@ -384,10 +403,14 @@ class NeedleBenchATCDataset(BaseDataset):
|
||||
person_b = names[-1] # The youngest person
|
||||
if language == 'Chinese':
|
||||
prompt = relationship_distance_prompt_zh_CN.format(
|
||||
shuffled_story=shuffled_story, person_a=person_a, person_b=person_b)
|
||||
shuffled_story=shuffled_story,
|
||||
person_a=person_a,
|
||||
person_b=person_b)
|
||||
else:
|
||||
prompt = relationship_distance_prompt_en.format(
|
||||
shuffled_story=shuffled_story, person_a=person_a, person_b=person_b)
|
||||
shuffled_story=shuffled_story,
|
||||
person_a=person_a,
|
||||
person_b=person_b)
|
||||
# Use the calculated total generations as the relationship distance
|
||||
answer = str(total_generations)
|
||||
|
||||
@ -396,11 +419,14 @@ class NeedleBenchATCDataset(BaseDataset):
|
||||
last_person = names[-1]
|
||||
if language == 'Chinese':
|
||||
prompt = shuffled_story_with_prompt_zh_CN.format(
|
||||
shuffled_story=shuffled_story, last_person=last_person)
|
||||
shuffled_story=shuffled_story,
|
||||
last_person=last_person)
|
||||
else:
|
||||
prompt = shuffled_story_with_prompt_en.format(
|
||||
shuffled_story=shuffled_story, last_person=last_person)
|
||||
answer = names[0] # The first person is the eldest ancestor
|
||||
shuffled_story=shuffled_story,
|
||||
last_person=last_person)
|
||||
answer = names[
|
||||
0] # The first person is the eldest ancestor
|
||||
|
||||
data['prompt'].append(prompt)
|
||||
data['answer'].append(answer)
|
||||
|
@ -4,14 +4,16 @@ import json
|
||||
import os
|
||||
import random
|
||||
|
||||
from datasets import Dataset
|
||||
import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
from opencompass.utils import get_data_path
|
||||
|
||||
from ..base import BaseDataset
|
||||
from .atc import relationship_terms_zh_CN, relationship_templates_zh_CN, relationship_terms_en, relationship_templates_en
|
||||
from .atc import (relationship_templates_en, relationship_templates_zh_CN,
|
||||
relationship_terms_en, relationship_terms_zh_CN)
|
||||
|
||||
|
||||
def get_number(options):
|
||||
result_string = ''
|
||||
@ -175,8 +177,11 @@ Example 3: If Xiao Ming is Zhang Hong's great-granddaughter, Zhang Hong's grandm
|
||||
|
||||
num_samples = 3
|
||||
if len(names) > 1:
|
||||
indices = np.linspace(1, len(names) - 1, num_samples, dtype=int) # Generate evenly spaced indices
|
||||
sampled_names = [names[i] for i in indices] # Select corresponding elements
|
||||
indices = np.linspace(
|
||||
1, len(names) - 1, num_samples,
|
||||
dtype=int) # Generate evenly spaced indices
|
||||
sampled_names = [names[i] for i in indices
|
||||
] # Select corresponding elements
|
||||
entry['options'] = names[:1] + sampled_names
|
||||
else:
|
||||
entry['options'] = names # Return directly if only one element
|
||||
|
@ -3,14 +3,15 @@ import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from opencompass.datasets.base import BaseDataset
|
||||
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
|
||||
from opencompass.utils import get_data_path
|
||||
from opencompass.datasets.math import extract_boxed_answer
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
|
||||
TEXT_POSTPROCESSORS)
|
||||
from opencompass.utils import get_data_path
|
||||
|
||||
relationship_templates_zh_CN = [
|
||||
'{A}是{B}的{relationship}。',
|
||||
@ -57,7 +58,7 @@ relationship_templates_en = [
|
||||
'{B} is the child of {A}.',
|
||||
('For {B}, {A} is not just a {relationship}, '
|
||||
'but also a friend.'),
|
||||
"For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.",
|
||||
'For {B}, {A} is more than just a {relationship}; {A} is a lifelong mentor of {B}.',
|
||||
]
|
||||
|
||||
shuffled_story_with_prompt_zh_CN = """下面是对你的多步推理能力的测试,这个测试叫做祖先追溯测试,我们会模拟不同人的家庭亲属关系,你的任务是在其中不断推理,直到找到最年长的祖先。
|
||||
@ -163,9 +164,11 @@ class NeedleBenchATCDataset(BaseDataset):
|
||||
|
||||
# Generating the prompt based on the language
|
||||
if language == 'Chinese':
|
||||
shuffled_story_with_prompt = shuffled_story_with_prompt_zh_CN.format(shuffled_story=shuffled_story, last_person=last_person)
|
||||
shuffled_story_with_prompt = shuffled_story_with_prompt_zh_CN.format(
|
||||
shuffled_story=shuffled_story, last_person=last_person)
|
||||
elif language == 'English':
|
||||
shuffled_story_with_prompt = shuffled_story_with_prompt_en.format(shuffled_story=shuffled_story, last_person=last_person)
|
||||
shuffled_story_with_prompt = shuffled_story_with_prompt_en.format(
|
||||
shuffled_story=shuffled_story, last_person=last_person)
|
||||
else:
|
||||
prompt = 'Language not supported.'
|
||||
raise Exception('Unsupported language specified. '
|
||||
@ -182,7 +185,7 @@ class NeedleBenchATCDataset(BaseDataset):
|
||||
|
||||
|
||||
def clean_atc_answer(text: str) -> str:
|
||||
"""Clean answer format specifically for QwQ-32B-Preview model
|
||||
"""Clean answer format specifically for QwQ-32B-Preview model.
|
||||
|
||||
Args:
|
||||
text: Raw prediction text
|
||||
@ -190,8 +193,8 @@ def clean_atc_answer(text: str) -> str:
|
||||
Returns:
|
||||
Standardized name format after cleaning
|
||||
"""
|
||||
if not text or text == "None":
|
||||
return "None"
|
||||
if not text or text == 'None':
|
||||
return 'None'
|
||||
|
||||
# Remove LaTeX commands but keep content
|
||||
text = re.sub(r'\\text\{([^}]+)\}', r'\1', text)
|
||||
@ -211,6 +214,7 @@ def clean_atc_answer(text: str) -> str:
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@TEXT_POSTPROCESSORS.register_module('needlebench_atc_postprocess_v2')
|
||||
def needlebench_atc_postprocess_v2(text: str) -> str:
|
||||
|
||||
@ -218,10 +222,10 @@ def needlebench_atc_postprocess_v2(text: str) -> str:
|
||||
|
||||
if cand_ans:
|
||||
return clean_atc_answer(cand_ans)
|
||||
return "None"
|
||||
return 'None'
|
||||
|
||||
|
||||
@ICL_EVALUATORS.register_module("needlebench_atc_evaluator")
|
||||
@ICL_EVALUATORS.register_module('needlebench_atc_evaluator')
|
||||
class NeedleBenchATCEvaluator(BaseEvaluator):
|
||||
|
||||
def score(self, predictions, gold):
|
||||
@ -243,6 +247,7 @@ class NeedleBenchATCEvaluator(BaseEvaluator):
|
||||
}
|
||||
details.append(detail)
|
||||
|
||||
accuracy = (correct_count / len(predictions)) * 100 if predictions else 0
|
||||
accuracy = (correct_count /
|
||||
len(predictions)) * 100 if predictions else 0
|
||||
result = {'score': accuracy, 'details': details}
|
||||
return result
|
@ -7,8 +7,12 @@ from datasets import Dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from opencompass.datasets.base import BaseDataset
|
||||
from opencompass.datasets.needlebench.atc import (relationship_templates_en,
|
||||
relationship_templates_zh_CN,
|
||||
relationship_terms_en,
|
||||
relationship_terms_zh_CN)
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
from opencompass.datasets.needlebench.atc import relationship_templates_zh_CN, relationship_terms_zh_CN, relationship_templates_en, relationship_terms_en
|
||||
|
||||
|
||||
def get_random_needles(counter, file_path, num_needles, language):
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
@ -33,17 +37,22 @@ def get_random_needles(counter, file_path, num_needles, language):
|
||||
for i in range(len(names) - 1):
|
||||
template = random.choice(templates)
|
||||
relation_term = random.choice(relationship_terms)
|
||||
relation = template.format(A=names[i], B=names[i + 1], relationship=relation_term)
|
||||
relation = template.format(A=names[i],
|
||||
B=names[i + 1],
|
||||
relationship=relation_term)
|
||||
story += f'{relation}*'
|
||||
return story
|
||||
|
||||
chain_story = generate_chain_family_story(names, relationship_templates, relationship_terms)
|
||||
chain_story = generate_chain_family_story(names, relationship_templates,
|
||||
relationship_terms)
|
||||
|
||||
# Splitting the chain_story into a list of fragments
|
||||
family_story_fragments = chain_story.split('*')
|
||||
|
||||
# Removing the empty string from the list
|
||||
family_story_fragments = [fragment for fragment in family_story_fragments if fragment]
|
||||
family_story_fragments = [
|
||||
fragment for fragment in family_story_fragments if fragment
|
||||
]
|
||||
|
||||
# Shuffling the list of fragments
|
||||
random.shuffle(family_story_fragments)
|
||||
@ -65,7 +74,6 @@ def get_random_needles(counter, file_path, num_needles, language):
|
||||
}
|
||||
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class NeedleBenchMultiDataset(BaseDataset):
|
||||
|
||||
@ -216,7 +224,8 @@ The content of the long document is as follows
|
||||
|
||||
'''
|
||||
else:
|
||||
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. '
|
||||
raise ValueError(
|
||||
f'Unsupported quesiton_position {quesiton_position}. '
|
||||
'Position must be "End" or "Start".')
|
||||
else:
|
||||
raise ValueError(f"Language '{language}' is not supported.")
|
||||
@ -278,7 +287,8 @@ The content of the long document is as follows
|
||||
needles)
|
||||
|
||||
processed_prompt = _generate_prompt(processed_text,
|
||||
retrieval_question, last_person)
|
||||
retrieval_question,
|
||||
last_person)
|
||||
|
||||
data['prompt'].append(processed_prompt)
|
||||
data['answer'].append(keyword)
|
||||
|
@ -114,7 +114,8 @@ The content of the long document is as follows
|
||||
|
||||
'''
|
||||
else:
|
||||
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. '
|
||||
raise ValueError(
|
||||
f'Unsupported quesiton_position {quesiton_position}. '
|
||||
'Position must be "End" or "Start".')
|
||||
else:
|
||||
raise ValueError(f"Language '{language}' is not supported.")
|
||||
@ -201,11 +202,7 @@ class NeedleBenchOriginEvaluator(BaseEvaluator):
|
||||
else:
|
||||
score = 0
|
||||
|
||||
detail = {
|
||||
'pred': prediction,
|
||||
'answer': reference,
|
||||
'score': score
|
||||
}
|
||||
detail = {'pred': prediction, 'answer': reference, 'score': score}
|
||||
total_score += score
|
||||
details.append(detail)
|
||||
|
||||
|
@ -158,7 +158,8 @@ class NeedleBenchParallelDataset(BaseDataset):
|
||||
|
||||
'''
|
||||
else:
|
||||
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. '
|
||||
raise ValueError(
|
||||
f'Unsupported quesiton_position {quesiton_position}. '
|
||||
'Position must be "End" or "Start".')
|
||||
elif language == 'English':
|
||||
if quesiton_position == 'End':
|
||||
@ -183,7 +184,8 @@ The content of the long document is as follows
|
||||
|
||||
'''
|
||||
else:
|
||||
raise ValueError(f'Unsupported quesiton_position {quesiton_position}. '
|
||||
raise ValueError(
|
||||
f'Unsupported quesiton_position {quesiton_position}. '
|
||||
'Position must be "End" or "Start".')
|
||||
else:
|
||||
raise ValueError(f"Language '{language}' is not supported.")
|
||||
@ -269,6 +271,7 @@ The content of the long document is as follows
|
||||
|
||||
|
||||
class NeedleBenchParallelEvaluator(BaseEvaluator):
|
||||
|
||||
def score(self, predictions, gold):
|
||||
if len(predictions) != len(gold):
|
||||
return {'error': 'predictions and gold have different lengths'}
|
||||
|
Loading…
Reference in New Issue
Block a user