OpenCompass/opencompass/datasets/needlebench/atc_choice.py

170 lines
6.4 KiB
Python
Raw Normal View History

[Feature] Add ATC Choice Version (#1019) * Squashed commit of the following: commit c48ad194c3976dc63d1b60d8c8ab2d5ff9e1cbfe Author: DseidLi <2568818204@qq.com> Date: Tue Apr 2 16:57:43 2024 +0800 add atc_choice commit 3ac6efea29619573e6fac8fa3cce464853dcead0 Merge: 2d4e559 8e3a9c3 Author: DseidLi <2568818204@qq.com> Date: Tue Apr 2 16:41:38 2024 +0800 Merge branch 'atc_choice' into atc_add_choice commit 8e3a9c396a3e5546d3faf584183f6fd60b974d5e Merge: 150a036 0a6a03f Author: DseidLi <2568818204@qq.com> Date: Tue Mar 26 04:47:07 2024 +0800 Merge branch 'main' into atc_choice Conflicts: configs/summarizers/needlebench.py opencompass/datasets/needlebench/multi.py opencompass/datasets/needlebench/origin.py opencompass/datasets/needlebench/parallel.py commit 150a036d6d990f26a57c974d1af83d88c31a0f9d Merge: 8d6ac9a 940dd18 Author: DseidLi <2568818204@qq.com> Date: Wed Mar 20 03:49:08 2024 +0800 Merge branch 'needlebench_fix' into atc_choice commit 8d6ac9a1a43b1c9d0f0ea27e7d58968a203ea898 Author: DseidLi <2568818204@qq.com> Date: Wed Mar 20 03:41:49 2024 +0800 optimize needlebench code commit 940dd18a4270f24bc69edd2a780182c68918e1a9 Author: DseidLi <2568818204@qq.com> Date: Wed Mar 20 03:39:46 2024 +0800 fix vllm commit d8be6877bc41051f3edcc0421c462c834c0f1c9a Merge: ecad78a 2527fda Author: DseidLi <2568818204@qq.com> Date: Tue Mar 19 21:07:08 2024 +0800 Merge remote-tracking branch 'origin/add_1M_dataset' into atc_choice commit 2527fda8a546595bcaea1e5261367bc1097faec8 Author: DseidLi <2568818204@qq.com> Date: Tue Mar 19 16:03:40 2024 +0800 add model configs commit 75425acdf80d6d25ee24bb0aa60ac48539262e76 Author: DseidLi <2568818204@qq.com> Date: Tue Mar 19 16:02:15 2024 +0800 add prompt postion args commit 367ba1ba612a8cec5df1f80d5e5ae4e285baf38b Author: DseidLi <2568818204@qq.com> Date: Wed Feb 28 21:40:00 2024 +0800 add Needlebench-1000K configs commit ecad78af14c4bb00fe325779114b384c57ab30bf Author: DseidLi <2568818204@qq.com> Date: Thu Mar 14 22:08:32 2024 +0800 fix 
atc commit 08772c0787b18872abadc9ffec3223941a5ee0c2 Merge: 9f3f8cf caf1cf8 Author: DseidLi <2568818204@qq.com> Date: Thu Mar 14 22:07:28 2024 +0800 Merge branch 'main' into atc_choice Conflicts: configs/datasets/needlebench/readme.md configs/datasets/needlebench/readme_zh-CN.md configs/summarizers/needlebench.py opencompass/datasets/needlebench/atc.py opencompass/summarizers/needlebench.py commit 9f3f8cfb4452722734d334114ac1d14110e57406 Author: DseidLi <2568818204@qq.com> Date: Thu Mar 14 21:35:53 2024 +0800 add atc-choice test commit 52be7c1202376b4e09821188b826f1a805328129 Author: DseidLi <2568818204@qq.com> Date: Wed Mar 6 02:54:15 2024 +0800 update needlebench randomseed and add vllm qwen14b commit fc1effce596ae2e5ece4933e8cd34aef8e64a6f9 Merge: 4e747ed caf1cf8 Author: DseidLi <2568818204@qq.com> Date: Wed Mar 6 02:51:14 2024 +0800 Merge branch 'main' into add_model_configs commit 31834f9b23af3354ac3581ec86d693d0f05cdd1c Merge: 7dabc82 120bf8b Author: DseidLi <2568818204@qq.com> Date: Sun Mar 3 23:29:42 2024 +0800 Merge branch 'main' of https://github.com/open-compass/opencompass into atc_choice commit 4e747ed1988ddbcfcc7fff334601259ade72d363 Author: DseidLi <2568818204@qq.com> Date: Sun Mar 3 22:15:25 2024 +0800 add internlm2-lmdeploy model and gemma configs commit 7dabc828123d711c8cf834d6aab4137bb55e85ed Author: DseidLi <2568818204@qq.com> Date: Sat Mar 2 17:26:15 2024 +0800 add atc choice version -ZH commit 996f8ae43d3f946a052f736717ead139d153e2dd Author: DseidLi <2568818204@qq.com> Date: Wed Feb 28 16:58:56 2024 +0800 update readme for needlebench commit f7266e873cb34ccf18a7f20b2c5821af8416a14f Author: DseidLi <2568818204@qq.com> Date: Wed Feb 28 16:44:53 2024 +0800 move readme.md commit 1c7375681dea13996802e45b878dc4929ea8fa65 Author: DseidLi <2568818204@qq.com> Date: Wed Feb 28 16:38:31 2024 +0800 fix linting error commit b6524f3ebfb8a3a12a5ad3e3fa7a8a0921fcb6c1 Author: DseidLi <2568818204@qq.com> Date: Wed Feb 28 16:33:51 2024 +0800 lint summarizer 
commit c0d1190e39d3b6724f677346df2572df9af59f25 Author: DseidLi <2568818204@qq.com> Date: Wed Feb 28 16:29:03 2024 +0800 add needlebench intro, fix summarizer commit 0965baf78588e29d813b61d73f0ebd868a0ce3d0 Author: DseidLi <2568818204@qq.com> Date: Mon Feb 26 13:31:26 2024 +0800 fix bug in needlebench summarizer commit 5d32b31eb85382026935f356190ad92b103afd98 Author: DseidLi <2568818204@qq.com> Date: Sat Feb 24 03:19:08 2024 +0800 update act prompt commit af82a7f085e394d83aa84043e2881dd50115942c Merge: 32bf9fe 53fe788 Author: DseidLi <2568818204@qq.com> Date: Fri Feb 23 17:50:32 2024 +0800 Merge remote-tracking branch 'upstream/main' into needlebench commit 32bf9fe802eaf8e8e5b33ff17b2a897058f8b66b Author: DseidLi <2568818204@qq.com> Date: Fri Feb 23 17:31:32 2024 +0800 simplify needlebench 32k, 128k, 200k for eval commit a7cb025e05a48449de9839005fada02bd5bff15a Author: DseidLi <2568818204@qq.com> Date: Fri Feb 23 14:48:58 2024 +0800 add needlebench * fix summarizer * remove repeated code * remove chinese comments
2024-04-07 15:46:20 +08:00
# flake8: noqa
import copy
import json
import random
from datasets import Dataset
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
def get_number(options):
    """Render ``options`` as lettered lines.

    Each option becomes ``'<letter>. <option>\\n'``, with letters assigned
    consecutively starting from ``'A'``.

    Args:
        options (Iterable): Option texts, in display order.

    Returns:
        str: The concatenated lettered lines; empty when there are no
        options.
    """
    lettered_lines = (f'{chr(code_point)}. {text}\n'
                      for code_point, text in enumerate(options,
                                                        start=ord('A')))
    return ''.join(lettered_lines)
def get_circular_example(entry, id):
    """Expand one four-option question into its four circular rotations.

    For each rotation of ``'ABCD'`` a deep copy of ``entry`` is made, its
    options are reordered accordingly, the answer letter is remapped to the
    option's new position, and the lettered options are appended to the
    question text.

    Args:
        entry (dict): Must provide ``'question'`` (str), ``'options'`` (at
            least four items, only the first four are used) and ``'answer'``
            (one of ``'A'``-``'D'``). It is not modified.
        id: Identifier of the example; embedded in every generated answer.

    Returns:
        list[dict]: Four entries whose ``'answer'`` has the form
        ``'<id>--<letter>--<rotation>'``.
    """
    # Only 4 options are supported by the current circular evaluation.
    rotations = ('ABCD', 'BCDA', 'CDAB', 'DABC')
    base_code = ord('A')
    variants = []
    for rotation in rotations:
        variant = copy.deepcopy(entry)
        # Position k of the rotated list shows the option originally
        # labelled rotation[k].
        variant['options'] = [
            variant['options'][ord(letter) - base_code]
            for letter in rotation
        ]
        # Map the original answer letter to the slot it now occupies.
        new_letter = dict(zip(rotation, 'ABCD'))[variant['answer']]
        variant['answer'] = '--'.join((str(id), new_letter, rotation))
        stem = variant['question'].strip()
        variant['question'] = stem + '\n' + get_number(variant['options'])
        variants.append(variant)
    return variants
@LOAD_DATASET.register_module()
class NeedleBenchATCDataset(BaseDataset):
    """Multiple-choice ancestral-trace dataset.

    Every example is a shuffled chain of family-relationship sentences
    followed by a question asking for the eldest relative of the last person
    in the chain, offered as a four-option choice.
    """

    # Kinship words substituted for ``{relationship}`` in the templates.
    _RELATIONSHIP_TERMS = {
        'Chinese': [
            '父亲', '母亲', '爸爸', '妈妈', '爷爷', '奶奶', '姥姥', '姥爷', '外公', '外婆'
        ],
        'English': [
            'father', 'mother', 'dad', 'mom', 'grandfather',
            'grandmother', 'maternal grandmother',
            'maternal grandfather', 'paternal grandfather',
            'paternal grandmother'
        ],
    }

    # Sentence templates; each one states that ``{A}`` is an elder relative
    # of ``{B}``. The Chinese list mirrors the English one entry by entry
    # and, like it, every sentence carries its own final punctuation because
    # the sentences are later joined without a separator.
    _RELATIONSHIP_TEMPLATES = {
        'Chinese': [
            '{A}是{B}的{relationship}。',
            '{B}的{relationship}是{A}。',
            '{A}作为{B}的{relationship},对{B}的成长有重要影响。',
            '{A}不仅是{B}的{relationship},还是{B}的榜样。',
            '{B}是{A}所生的孩子。',
            '{A}对{B}来说,不只是一个{relationship},还是一个朋友。',
            '{A}在{B}的生命中扮演着{relationship}的角色。',
            '{B}把{A}视为其{relationship}。',
        ],
        'English': [
            "{A} is {B}'s {relationship}.",
            "{B}'s {relationship} is {A}.",
            ("{A}, as {B}'s {relationship}, "
             "has a significant impact on {B}'s upbringing."),
            ("{A} is not only {B}'s {relationship} "
             "but also {B}'s role model."),
            '{B} is the child of {A}.',
            ('For {B}, {A} is not just a {relationship}, '
             'but also a friend.'),
            ("{A} plays the role of {B}'s {relationship} "
             "in {B}'s life."),
            '{B} considers {A} as their {relationship}.',
        ],
    }

    # Question appended after the story; starts with a newline.
    _PROMPT_TEMPLATES = {
        'Chinese': ("\n在上面提供的打乱的家族关系文本中'{last_person}'"
                    '的能够向上追溯到的最年长的亲人是谁'),
        'English': ('\nGiven the scrambled family relationships described '
                    'above, who is the eldest relative that '
                    "'{last_person}' can trace back to in the context?"),
    }

    @staticmethod
    def load(path: str,
             num_needles: int,
             language: str,
             repeats: int,
             with_circular: bool = True):
        """Build the NeedleBench ATC multiple-choice dataset.

        Args:
            path (str): Path of a JSON file that maps a language name to a
                comma-separated string of person names.
            num_needles (int): Number of people in each family chain.
            language (str): ``'Chinese'`` or ``'English'``.
            repeats (int): Number of stories to generate. The repeat index
                is used both as the random seed and as the example id.
            with_circular (bool): Whether to create circular dataset for
                single choice question. Defaults to True.
                NOTE(review): this flag is currently not consulted; circular
                examples are always generated.

        Returns:
            Dataset: Built from the rows that ``get_circular_example``
            returns for every generated story.

        Raises:
            ValueError: If ``language`` is not supported.
        """
        cls = NeedleBenchATCDataset
        # Validate up front: without this an unsupported language would only
        # surface later as an unrelated lookup error.
        if language not in cls._RELATIONSHIP_TEMPLATES:
            raise ValueError('Unsupported language specified. '
                             "Please choose either 'Chinese' or 'English'.")
        relationship_templates = cls._RELATIONSHIP_TEMPLATES[language]
        relationship_terms = cls._RELATIONSHIP_TERMS[language]
        prompt_template = cls._PROMPT_TEMPLATES[language]

        with open(path, 'r', encoding='utf-8') as file:
            names_data = json.load(file)
        all_names = names_data[language].split(',')

        data = []
        for repeat_id in range(repeats):
            # Re-seeding the module-level generator makes every repeat
            # reproducible independently of the others.
            random.seed(repeat_id)
            names = random.sample(all_names, num_needles)

            # One sentence per adjacent pair: names[i] is the elder of
            # names[i + 1], so names[0] is the eldest of the whole chain.
            fragments = []
            for elder, younger in zip(names, names[1:]):
                template = random.choice(relationship_templates)
                relation_term = random.choice(relationship_terms)
                fragments.append(
                    template.format(A=elder,
                                    B=younger,
                                    relationship=relation_term))
            # The trailing empty fragment is intentional: it keeps the
            # shuffled list the same length as before, so the shuffle for a
            # given seed (and hence the generated data) stays unchanged.
            fragments.append('')
            random.shuffle(fragments)
            shuffled_story = ''.join(fragments)

            prompt = prompt_template.format(last_person=names[-1])

            # Always offer four options; short chains are padded with names
            # that do not occur in the story.
            options = names[:4]
            if len(options) < 4:
                used_names = set(names)
                unused_names = [
                    name for name in all_names if name not in used_names
                ]
                options.extend(
                    random.sample(unused_names, 4 - len(options)))

            entry = {
                'question': shuffled_story + ' ' + prompt,
                'options': options,
                # names[0], the eldest person, is always listed first.
                'answer': 'A',
            }
            data.extend(get_circular_example(entry, repeat_id))

        dataset = Dataset.from_list(data)
        return dataset