OpenCompass/opencompass/datasets/ruler/ruler_vt.py
Linchen Xiao · a4b54048ae · [Feature] Add Ruler datasets (#1310) · 2024-08-20 11:40:11 +08:00
Co-authored-by: tonysy <sy.zhangbuaa@gmail.com>


# flake8: noqa: F401, E501
import random
import string

import numpy as np
import tiktoken
from datasets import Dataset
from transformers import AutoTokenizer

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET


@LOAD_DATASET.register_module()
class RulerVtDataset(BaseDataset):
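    """RULER variable-tracking (VT) dataset.

    Hides chains of variable assignments (``VAR A = 12345``,
    ``VAR B = VAR A``, ...) inside repeated filler text; the task is to
    list every variable that ultimately holds the queried value. With
    ``tokenizer_model='gpt-4'`` prompt lengths are measured with
    tiktoken; any other name is loaded as a HuggingFace tokenizer.
    """
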
    @staticmethod
    def load(
        max_seq_length: int = 4096,
        tokenizer_model: str = 'gpt-4',
        template:
        str = 'Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above. Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assigned the value {query}, they are: ',
        tokens_to_generate: int = 30,
        num_chains: int = 1,
        num_hops: int = 4,
        num_samples: int = 500,
        random_seed: int = 42,
        remove_newline_tab: bool = False,
    ) -> Dataset:
        if tokenizer_model == 'gpt-4':
            tokenizer = tiktoken.encoding_for_model(tokenizer_model)
        else:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_model,
                                                      trust_remote_code=True)
        # Seed both RNGs so generated samples are reproducible
        random.seed(random_seed)
        np.random.seed(random_seed)
        def _generate_chains(num_chains, num_hops, is_icl=False):
            k = 5 if not is_icl else 3
            num_hops = num_hops if not is_icl else min(10, num_hops)
            # Random uppercase variable names, one per hop per chain
            vars_all = [
                ''.join(random.choices(string.ascii_uppercase, k=k))
                for _ in range((num_hops + 1) * num_chains)
            ]
            # Top up until there are enough distinct names
            while len(set(vars_all)) < num_chains * (num_hops + 1):
                vars_all.append(''.join(
                    random.choices(string.ascii_uppercase, k=k)))

            vars_ret = []
            chains_ret = []
            for i in range(0, len(vars_all), num_hops + 1):
                this_vars = vars_all[i:i + num_hops + 1]
                vars_ret.append(this_vars)
                # The chain root gets a concrete value; later hops alias it
                this_chain = [
                    f'VAR {this_vars[0]} = {np.random.randint(10000, 99999)}'
                ]
                for j in range(num_hops):
                    this_chain.append(
                        f'VAR {this_vars[j + 1]} = VAR {this_vars[j]} ')
                chains_ret.append(this_chain)
            return vars_ret, chains_ret
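
        # Illustrative shape (names and values are random) for num_chains=1,
        # num_hops=2:
        #   vars_ret   == [['ABCDE', 'FGHIJ', 'KLMNO']]
        #   chains_ret == [['VAR ABCDE = 12345',
        #                   'VAR FGHIJ = VAR ABCDE ',
        #                   'VAR KLMNO = VAR FGHIJ ']]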
        def _generate_input_output(num_noises,
                                   num_chains,
                                   num_hops,
                                   is_icl=False):
            chain_vars, chains = _generate_chains(num_chains,
                                                  num_hops,
                                                  is_icl=is_icl)

            noise = ('The grass is green. The sky is blue. The sun is yellow. '
                     'Here we go. There and back again.\n')

            # Create a list of the repeated noise
            sentences = [noise] * num_noises
            if len(sentences) <= len(chains[0]):
                # Too few slots: split the noise into individual sentences
                sentences = [
                    n + '.' if len(n.strip()) > 0 else n for n in
                    [x for noise in sentences for x in noise.split('.')]
                ]
                try:
                    assert len(sentences) > len(
                        chains[0]), 'Noises too short, unable to generate data'
                except AssertionError:
                    print('Not enough noise sentences; truncating chains')
                    chains = [chain[:len(sentences) - 1] for chain in chains]

            for chain_i in chains:
                # Sample random positions (sorted) to insert the assignments
                positions = sorted(
                    random.sample(range(len(sentences)), len(chain_i)))
                for insert_pi, j in zip(positions, range(len(chain_i))):
                    sentences.insert(insert_pi + j, chain_i[j])

            # Join everything into the final context
            context = ' '.join(sentences)
            context = context.replace('. \n', '.\n')

            # if is_icl:
            #     # remove model template
            #     cutoff = template.index(template[:20])
            #     cutoff_ans = template.index(answer_prefix[:10])
            #     template = ' '.join(template[cutoff:cutoff_ans].split()[:-1]) + template[cutoff_ans:]

            value = chains[0][0].split('=')[-1].strip()
            input_text = template.format(context=context,
                                         query=value,
                                         num_v=num_hops + 1)
            return input_text, chain_vars[0]
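
        # Illustrative return values: input_text is the filled template, and
        # chain_vars[0] is the gold answer, e.g. ['ABCDE', 'FGHIJ', 'KLMNO']
        # when the query is the integer assigned to ABCDE.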
        def _sys_vartrack_w_noise_random(
            num_samples: int,
            max_seq_length: int,
            incremental: int = 10,
            num_chains: int = 1,
            num_hops: int = 4,
        ):
            data = {'prompt': [], 'answer': []}

            # Calibrate num_noises: grow it until the prompt fills max_seq_length
            num_noises = incremental
            total_tokens = 0  # Tokens used by the current candidate prompt
            example_tokens = 0  # Few-shot example tokens (always 0 here)
            while total_tokens + tokens_to_generate + example_tokens < max_seq_length:
                input_text, answer = _generate_input_output(
                    num_noises, num_chains, num_hops)
                # Count the tokens of the prompt plus its answer
                total_tokens = len(tokenizer.encode(input_text + f' {answer}'))
                print(
                    f'Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate + example_tokens} | Noises: {num_noises}'
                )
                if total_tokens + tokens_to_generate + example_tokens > max_seq_length:
                    num_noises -= incremental
                    break
                num_noises += incremental
            print('Num noises:', num_noises)
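
            # The calibrated num_noises is only a starting point: per-sample
            # generation below re-checks the encoded length and retries with
            # fewer noise sentences on overflow.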
            # Generate samples
            for _ in range(num_samples):
                used_noises = num_noises
                while True:
                    try:
                        input_text, answer = _generate_input_output(
                            used_noises, num_chains, num_hops)
                        length = (len(tokenizer.encode(input_text)) +
                                  tokens_to_generate + example_tokens)
                        assert (length <= max_seq_length
                                ), f'{length} exceeds max_seq_length.'
                        break
                    except Exception:
                        if used_noises > incremental:
                            used_noises -= incremental

                if remove_newline_tab:
                    input_text = ' '.join(
                        input_text.replace('\n',
                                           ' ').replace('\t',
                                                        ' ').strip().split())

                data['prompt'].append(input_text)
                data['answer'].append(answer)

            return data
        # Generate data
        data = _sys_vartrack_w_noise_random(
            num_samples=num_samples,
            max_seq_length=max_seq_length,
            num_chains=num_chains,
            num_hops=num_hops,
        )
        dataset = Dataset.from_dict(data)
        return dataset


class RulerVtEvaluator(BaseEvaluator):
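    """Average, per sample, the fraction of gold variable names found in
    the prediction (case-insensitive substring match), scaled to 0-100."""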

    def score(self, predictions, gold):
        if len(predictions) != len(gold):
            return {'error': 'predictions and gold have different lengths'}
        score = (sum([
            sum([1.0 if r.lower() in pred.lower() else 0.0
                 for r in ref]) / len(ref)
            for pred, ref in zip(predictions, gold)
        ]) / len(predictions) * 100)
        result = {'score': round(score, 2)}
        return result
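

# A minimal usage sketch (illustrative, not part of the original module):
# build a tiny split and score perfect predictions. Assumes tiktoken can
# fetch the 'gpt-4' encoding; pass any HuggingFace model id as
# `tokenizer_model` otherwise.
if __name__ == '__main__':
    dataset = RulerVtDataset.load(max_seq_length=1024, num_samples=2)
    evaluator = RulerVtEvaluator()
    # Feeding the gold variables back as predictions should score 100.0
    preds = [' '.join(ans) for ans in dataset['answer']]
    print(evaluator.score(preds, dataset['answer']))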