[Feature] Add Ruler datasets (#1310)

* [Feature] Add Ruler datasets

* pre-commit fixed

* Add model specific tokenizer to dataset

* pre-commit modified

* remove unused import

* fix linting

* add trust_remote to tokenizer load

* lint fix

* comments resolved

* fix lint

* Add readme

* Fix lint

* ruler refactorize

* fix lint

* lint fix

* updated

* lint fix

* fix wonderwords import issue

* prompt modified

* update

* readme updated

* update

* ruler dataset added

* Update

---------

Co-authored-by: tonysy <sy.zhangbuaa@gmail.com>
Linchen Xiao 2024-08-20 11:40:11 +08:00 committed by GitHub
parent 99b5122ed5
commit a4b54048ae
44 changed files with 2520 additions and 251 deletions

@ -70,9 +70,10 @@ Just like a compass guides us on our journey, OpenCompass will guide you through
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
- **\[2024.08.16\]** OpenCompass now supports the brand-new long-context language model evaluation benchmark [RULER](https://arxiv.org/pdf/2404.06654). Through flexible configurations, RULER evaluates long-context capabilities including retrieval, multi-hop tracing, aggregation, and question answering. Check out the [RULER](configs/datasets/ruler/README.md) evaluation config now! 🔥🔥🔥
- **\[2024.08.09\]** We have released the example data and configuration for the CompassBench-202408, welcome to [CompassBench](https://opencompass.readthedocs.io/zh-cn/latest/advanced_guides/compassbench_intro.html) for more details. 🔥🔥🔥
- **\[2024.08.01\]** We supported the [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) models. Welcome to try! 🔥🔥🔥
- **\[2024.07.23\]** We have supported the [ModelScope](www.modelscope.cn) datasets; you can load them on demand without downloading all the data to your local disk. Welcome to try! 🔥🔥🔥
- **\[2024.07.17\]** We have released the example data and configuration for the CompassBench-202408, welcome to [CompassBench](https://opencompass.readthedocs.io/zh-cn/latest/advanced_guides/compassbench_intro.html) for more details. 🔥🔥🔥
- **\[2024.07.17\]** We are excited to announce the release of NeedleBench's [technical report](http://arxiv.org/abs/2407.11963). We invite you to visit our [support documentation](https://opencompass.readthedocs.io/en/latest/advanced_guides/needleinahaystack_eval.html) for detailed evaluation guidelines. 🔥🔥🔥
- **\[2024.07.04\]** OpenCompass now supports InternLM2.5, which has **outstanding reasoning capability**, a **1M context window**, and **stronger tool use**. You can try the models in [OpenCompass Config](https://github.com/open-compass/opencompass/tree/main/configs/models/hf_internlm) and [InternLM](https://github.com/InternLM/InternLM). 🔥🔥🔥
- **\[2024.06.20\]** OpenCompass now supports one-click switching between inference acceleration backends, enhancing the efficiency of the evaluation process. In addition to the default HuggingFace inference backend, it now also supports the popular backends [LMDeploy](https://github.com/InternLM/lmdeploy) and [vLLM](https://github.com/vllm-project/vllm). This feature is available via a simple command-line switch and through deployment APIs. For detailed usage, see the [documentation](docs/en/advanced_guides/accelerator_intro.md). 🔥🔥🔥
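
As a quick illustration of the backend switch described above, the launch looks roughly like the following. The `-a`/`--accelerator` flag and the `eval_demo.py` config name are assumptions here; check the linked accelerator documentation for the exact usage.

```bash
# Hedged example of switching the inference backend at launch time. The flag and
# config name are assumptions; see docs/en/advanced_guides/accelerator_intro.md.
python run.py configs/eval_demo.py -a lmdeploy   # run with the LMDeploy backend
python run.py configs/eval_demo.py -a vllm       # or with the vLLM backend
```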


@ -69,6 +69,7 @@
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
- **\[2024.08.16\]** OpenCompass now supports the brand-new long-context language model evaluation benchmark [RULER](https://arxiv.org/pdf/2404.06654). Through flexible configurations, RULER evaluates long-context capabilities including retrieval, multi-hop tracing, aggregation, and question answering. See [RULER](configs/datasets/ruler/README.md) for details. 🔥🔥🔥
- **\[2024.07.23\]** We have supported the [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) models. Welcome to try! 🔥🔥🔥
- **\[2024.07.23\]** We have supported the [ModelScope](www.modelscope.cn) datasets; you can load them on demand without downloading all the data to your local disk. Welcome to try! 🔥🔥🔥
- **\[2024.07.17\]** We have released the example data and evaluation rules for the CompassBench-202408 leaderboard. Please visit [CompassBench](https://opencompass.readthedocs.io/zh-cn/latest/advanced_guides/compassbench_intro.html) for more details. 🔥🔥🔥


@ -0,0 +1,14 @@
# Ruler
OpenCompass now supports [RULER](https://arxiv.org/pdf/2404.06654), a long-context language model evaluation benchmark. Through flexible configurations, RULER evaluates long-context capabilities including retrieval, multi-hop tracing, aggregation, and question answering.
OpenCompass provides two evaluation demos that use different tokenizers.
To evaluate all models with a single fixed tokenizer (typically GPT-4), follow the demo in `configs/eval_ruler_fix_tokenizer.py`, where most of the settings are already defined.
To evaluate with each model's own tokenizer, you have to build the settings yourself, since the demo cannot know in advance which model you want to evaluate: create a new evaluation script following the example in `configs/eval_ruler.py`, then change the context window sizes or add models to match your setup. A minimal sketch of such a script is shown after the commands below.
```bash
python run.py configs/eval_ruler_fix_tokenizer.py # For evaluation with GPT-4 tokenizer
python run.py configs/eval_ruler.py # For evaluation with model's tokenizer
```
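
For the second case, a minimal sketch of a custom script is given below, following the pattern of `configs/eval_ruler.py` added in this PR. The imported model, the tokenizer path, and the context lengths are examples only and should be replaced with your own settings; the relative import paths assume the script lives in `configs/`, like `eval_ruler.py`.

```python
# Minimal sketch of a custom RULER config using a model's own tokenizer,
# mirroring configs/eval_ruler.py from this PR. The model, tokenizer path and
# context lengths below are examples only.
from mmengine.config import read_base

with read_base():
    # Only two subsets are imported here to keep the sketch short.
    from .datasets.ruler.ruler_niah_gen import niah_datasets  # NIAH subsets
    from .datasets.ruler.ruler_vt_gen import vt_datasets  # variable tracking
    from ..configs.models.qwen.lmdeploy_qwen2_7b_instruct import (
        models as qwen2_7b_instruct_model,
    )
    from ..configs.summarizers.groups.ruler import ruler_summary_groups

import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

NUM_SAMPLES = 100  # samples generated per subset
max_seq_lens = [1024 * 4, 1024 * 8]  # context lengths to test
abbr_suffixs = ['4k', '8k']
tokenizer_path = 'Qwen/Qwen2-7B-Instruct'  # tokenizer matching the model above

datasets = []
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
    for dataset in import_datasets:
        tmp_dataset = dataset.deepcopy()
        tmp_dataset['tokenizer_model'] = tokenizer_path
        tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
        tmp_dataset['num_samples'] = NUM_SAMPLES
        tmp_dataset['max_seq_length'] = max_seq_len
        datasets.append(tmp_dataset)

models = qwen2_7b_instruct_model
work_dir = './outputs/ruler'

summarizer = dict(
    dataset_abbrs=['ruler_4k', 'ruler_8k'],
    summary_groups=sum(
        [v for k, v in locals().items() if k.endswith('_summary_groups')], []
    ),
)
```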


@ -0,0 +1,28 @@
from mmengine.config import read_base
with read_base():
from .ruler_niah_gen import niah_datasets # Niah
from .ruler_vt_gen import vt_datasets # VT
from .ruler_fwe_gen import fwe_datasets # FWE
from .ruler_cwe_gen import cwe_datasets # CWE
from .ruler_qa_gen import qa_datasets # QA
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 100 # Change to the number of samples you need
# Change the context lengths to be tested
max_seq_lens = [1024 * 128]
abbr_suffixs = ['128k']
ruler_datasets = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
ruler_datasets.append(tmp_dataset)


@ -0,0 +1,29 @@
from mmengine.config import read_base
with read_base():
from .ruler_niah_gen import niah_datasets # Niah
from .ruler_vt_gen import vt_datasets # VT
from .ruler_fwe_gen import fwe_datasets # FWE
from .ruler_cwe_gen import cwe_datasets # CWE
from .ruler_qa_gen import qa_datasets # QA
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 100 # Change to the number of samples you need
# Change the context lengths to be tested
max_seq_lens = [1024 * 16]
abbr_suffixs = ['16k']
ruler_datasets = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
ruler_datasets.append(tmp_dataset)


@ -0,0 +1,29 @@
from mmengine.config import read_base
with read_base():
from .ruler_niah_gen import niah_datasets # Niah
from .ruler_vt_gen import vt_datasets # VT
from .ruler_fwe_gen import fwe_datasets # FWE
from .ruler_cwe_gen import cwe_datasets # CWE
from .ruler_qa_gen import qa_datasets # QA
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 100 # Change to the number of samples you need
# Change the context lengths to be tested
max_seq_lens = [1024 * 1024]
abbr_suffixs = ['1m']
ruler_datasets = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
ruler_datasets.append(tmp_dataset)


@ -0,0 +1,29 @@
from mmengine.config import read_base
with read_base():
from .ruler_niah_gen import niah_datasets # Niah
from .ruler_vt_gen import vt_datasets # VT
from .ruler_fwe_gen import fwe_datasets # FWE
from .ruler_cwe_gen import cwe_datasets # CWE
from .ruler_qa_gen import qa_datasets # QA
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 100 # Change to the number of samples you need
# Change the context lengths to be tested
max_seq_lens = [1024 * 32]
abbr_suffixs = ['32k']
ruler_datasets = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
ruler_datasets.append(tmp_dataset)


@ -0,0 +1,28 @@
from mmengine.config import read_base
with read_base():
from .ruler_niah_gen import niah_datasets # Niah
from .ruler_vt_gen import vt_datasets # VT
from .ruler_fwe_gen import fwe_datasets # FWE
from .ruler_cwe_gen import cwe_datasets # CWE
from .ruler_qa_gen import qa_datasets # QA
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 100 # Change to the number of samples you need
# Change the context lengths to be tested
max_seq_lens = [1024 * 4]
abbr_suffixs = ['4k']
ruler_datasets = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
ruler_datasets.append(tmp_dataset)


@ -0,0 +1,29 @@
from mmengine.config import read_base
with read_base():
from .ruler_niah_gen import niah_datasets # Niah
from .ruler_vt_gen import vt_datasets # VT
from .ruler_fwe_gen import fwe_datasets # FWE
from .ruler_cwe_gen import cwe_datasets # CWE
from .ruler_qa_gen import qa_datasets # QA
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 100 # Change to the number of samples you need
# Change the context lengths to be tested
max_seq_lens = [1024 * 8]
abbr_suffixs = ['8k']
ruler_datasets = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
ruler_datasets.append(tmp_dataset)


@ -0,0 +1,13 @@
from mmengine.config import read_base
with read_base():
from .ruler_4k_gen import ruler_datasets as ruler_4k_datasets
from .ruler_8k_gen import ruler_datasets as ruler_8k_datasets
from .ruler_16k_gen import ruler_datasets as ruler_16k_datasets
from .ruler_32k_gen import ruler_datasets as ruler_32k_datasets
from .ruler_128k_gen import ruler_datasets as ruler_128k_datasets
from .ruler_1m_gen import ruler_datasets as ruler_1m_datasets
ruler_combined_datasets = sum(
(v for k, v in locals().items() if k.endswith('_datasets')), []
)


@ -0,0 +1,34 @@
from opencompass.datasets.ruler.ruler_cwe import RulerCweDataset
from opencompass.datasets.ruler.ruler_cwe import RulerCweEvaluator
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
# CWE Dataset
cwe_datasets = [
{
'abbr': 'ruler_cwe',
'type': RulerCweDataset,
'freq_cw': 30,
'freq_ucw': 3,
'num_cw': 10,
'tokens_to_generate': 120,
'reader_cfg': dict(input_columns=['prompt'], output_column='answer'),
'infer_cfg': dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{prompt}'),
dict(role='BOT', prompt='{answer}\n'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
),
'eval_cfg': dict(
evaluator=dict(type=RulerCweEvaluator),
),
}
]


@ -0,0 +1,33 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.ruler.ruler_fwe import RulerFweDataset
from opencompass.datasets.ruler.ruler_fwe import RulerFweEvaluator
# FWE Dataset
fwe_datasets = [
{
'abbr': 'ruler_fwe',
'type': RulerFweDataset,
'tokens_to_generate': 50,
'alpha': 2.0,
'coded_wordlen': 6,
'reader_cfg': dict(input_columns=['prompt'], output_column='answer'),
'infer_cfg': dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{prompt}'),
dict(role='BOT', prompt='{answer}\n'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
),
'eval_cfg': dict(
evaluator=dict(type=RulerFweEvaluator),
),
}
]


@ -0,0 +1,123 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.ruler.ruler_niah import RulerNiahDataset
from opencompass.datasets.ruler.ruler_niah import RulerNiahEvaluator
# Ruler Dataset settings
niah_configurations = [
{
'abbr': 'single_1',
'type_haystack': 'repeat',
'type_needle_k': 'words',
'type_needle_v': 'numbers',
'num_needle_k': 1,
'num_needle_v': 1,
'num_needle_q': 1,
},
{
'abbr': 'single_2',
'type_haystack': 'essay',
'type_needle_k': 'words',
'type_needle_v': 'numbers',
'num_needle_k': 1,
'num_needle_v': 1,
'num_needle_q': 1,
},
{
'abbr': 'single_3',
'type_haystack': 'essay',
'type_needle_k': 'words',
'type_needle_v': 'uuids',
'num_needle_k': 1,
'num_needle_v': 1,
'num_needle_q': 1,
},
{
'abbr': 'multikey_1',
'type_haystack': 'essay',
'type_needle_k': 'words',
'type_needle_v': 'numbers',
'num_needle_k': 4,
'num_needle_v': 1,
'num_needle_q': 1,
},
{
'abbr': 'multikey_2',
'type_haystack': 'needle',
'type_needle_k': 'words',
'type_needle_v': 'numbers',
'num_needle_k': 1,
'num_needle_v': 1,
'num_needle_q': 1,
},
{
'abbr': 'multikey_3',
'type_haystack': 'needle',
'type_needle_k': 'uuids',
'type_needle_v': 'uuids',
'num_needle_k': 1,
'num_needle_v': 1,
'num_needle_q': 1,
},
{
'abbr': 'multivalue',
'type_haystack': 'essay',
'type_needle_k': 'words',
'type_needle_v': 'numbers',
'num_needle_k': 1,
'num_needle_v': 4,
'num_needle_q': 1,
},
{
'abbr': 'multiquery',
'type_haystack': 'essay',
'type_needle_k': 'words',
'type_needle_v': 'numbers',
'num_needle_k': 1,
'num_needle_v': 1,
'num_needle_q': 4,
},
]
niah_datasets = []
# NIAH Dataset
base_path = './data/ruler'
file_path = 'PaulGrahamEssays.jsonl'
for index, config in enumerate(niah_configurations):
dataset_dict = {
'abbr': f'ruler_niah_{config["abbr"]}',
'type': RulerNiahDataset,
'base_path': base_path,
'file_path': file_path,
# 'tokenizer_model': model_path,
'tokens_to_generate': 128,
# 'max_seq_length': max_seq_len,
# 'num_samples': NUM_SAMPLES,
'type_haystack': config['type_haystack'],
'type_needle_k': config['type_needle_k'],
'type_needle_v': config['type_needle_v'],
'num_needle_k': config['num_needle_k'],
'num_needle_v': config['num_needle_v'],
'num_needle_q': config['num_needle_q'],
'reader_cfg': dict(input_columns=['prompt'], output_column='answer'),
'infer_cfg': dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{prompt}'),
dict(role='BOT', prompt='{answer}\n'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
),
'eval_cfg': dict(
evaluator=dict(type=RulerNiahEvaluator),
),
}
niah_datasets.append(dataset_dict)


@ -0,0 +1,38 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.ruler.ruler_qa import RulerQaDataset
from opencompass.datasets.ruler.ruler_qa import RulerQaEvaluator
qa_configurations = [
{'dataset': 'squad', 'path': './data/ruler/dev-v2.0.json'},
{'dataset': 'hotpotqa', 'path': './data/ruler/hotpotqa.json'},
]
qa_datasets = []
for index, config in enumerate(qa_configurations):
dataset_dict = {
'abbr': f'ruler_qa_{config["dataset"]}',
'dataset': config['dataset'],
'path': config['path'],
'type': RulerQaDataset,
'tokens_to_generate': 50,
'reader_cfg': dict(input_columns=['prompt'], output_column='answer'),
'infer_cfg': dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{prompt}'),
dict(role='BOT', prompt='{answer}\n'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
),
'eval_cfg': dict(
evaluator=dict(type=RulerQaEvaluator),
),
}
qa_datasets.append(dataset_dict)


@ -0,0 +1,32 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.ruler.ruler_vt import RulerVtDataset
from opencompass.datasets.ruler.ruler_vt import RulerVtEvaluator
# VT Dataset
vt_datasets = [
{
'abbr': 'ruler_vt',
'type': RulerVtDataset,
'num_chains': 1,
'num_hops': 4,
'reader_cfg': dict(input_columns=['prompt'], output_column='answer'),
'infer_cfg': dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{prompt}'),
dict(role='BOT', prompt='{answer}\n'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
),
'eval_cfg': dict(
evaluator=dict(type=RulerVtEvaluator),
),
}
]

configs/eval_ruler.py

@ -0,0 +1,100 @@
from opencompass.partitioners import (
NaivePartitioner,
NumWorkerPartitioner,
)
from mmengine.config import read_base
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask
with read_base():
from ..configs.models.qwen.lmdeploy_qwen2_7b_instruct import (
models as qwen2_7b_instruct_model,
)
from ..configs.models.hf_llama.lmdeploy_llama3_8b_instruct import (
models as llama3_8b_instruct_model,
)
from ..configs.models.hf_internlm.lmdeploy_internlm2_5_7b_chat_1m import (
models as internlm2_5_7b_chat_1m,
)
from .datasets.ruler.ruler_niah_gen import niah_datasets # Niah
from .datasets.ruler.ruler_vt_gen import vt_datasets # VT
from .datasets.ruler.ruler_fwe_gen import fwe_datasets # FWE
from .datasets.ruler.ruler_cwe_gen import cwe_datasets # CWE
from .datasets.ruler.ruler_qa_gen import qa_datasets # QA
from ..configs.summarizers.groups.ruler import ruler_summary_groups
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 500
# Change the context lengths to be tested
max_seq_lens = [1024 * 4, 1024 * 8, 1024 * 16, 1024 * 32]
abbr_suffixs = ['4k', '8k', '16k', '32k']
work_dir = './outputs/ruler'
# Model Settings
qwen2_7b_instruct_model[0]['max_seq_len'] = 33792
qwen2_7b_instruct_model[0]['engine_config']['session_len'] = 33792
qwen2_7b_instruct_model[0]['engine_config']['tp'] = 2
qwen2_7b_instruct_model[0]['run_cfg']['num_gpus'] = 2
llama3_8b_instruct_model[0]['max_seq_len'] = 33792
llama3_8b_instruct_model[0]['engine_config']['session_len'] = 33792
llama3_8b_instruct_model[0]['engine_config']['tp'] = 2
llama3_8b_instruct_model[0]['run_cfg']['num_gpus'] = 2
model_settings = [
[qwen2_7b_instruct_model[0], 'Qwen/Qwen2-7B-Instruct'],
[llama3_8b_instruct_model[0], 'meta-llama/Meta-Llama-3-8B-Instruct'],
[internlm2_5_7b_chat_1m[0], 'internlm/internlm2_5-7b-chat-1m'],
]
# Dataset Model Combination
datasets = []
models = []
model_dataset_combinations = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for model, model_path in model_settings:
_tmp_datasets = []
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['tokenizer_model'] = model_path
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
_tmp_datasets.append(tmp_dataset)
model_dataset_combinations.append(dict(models=[model], datasets=_tmp_datasets))
models.append(model)
datasets.extend(_tmp_datasets)
infer = dict(
partitioner=dict(type=NumWorkerPartitioner),
runner=dict(
type=LocalRunner, max_num_workers=16, task=dict(type=OpenICLInferTask), retry=5
),
)
eval = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(type=LocalRunner, max_num_workers=32, task=dict(type=OpenICLEvalTask)),
)
summarizer = dict(
dataset_abbrs=abbr_suffixs,
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# dataset version metric mode qwen2-7b-instruct-turbomind llama-3-8b-instruct-turbomind internlm2_5-7b-chat-1m-turbomind
# --------- --------- ------------- ------ ----------------------------- ------------------------------- ----------------------------------
# 4k - naive_average gen 93.66 93.48 91.20
# 8k - naive_average gen 88.38 89.95 89.07
# 16k - naive_average gen 84.27 0.14 87.61
# 32k - naive_average gen 81.36 0.00 84.59
# $$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$


@ -0,0 +1,38 @@
from opencompass.partitioners import (
NaivePartitioner,
NumWorkerPartitioner,
)
from mmengine.config import read_base
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask
with read_base():
from ..configs.models.hf_internlm.lmdeploy_internlm2_5_7b_chat_1m import (
models as internlm2_5_7b_chat_1m,
)
from .datasets.ruler.ruler_combined_gen import ruler_combined_datasets
from ..configs.summarizers.groups.ruler import ruler_summary_groups
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
models = internlm2_5_7b_chat_1m
work_dir = './outputs/ruler'
infer = dict(
partitioner=dict(type=NumWorkerPartitioner, num_worker=2),
runner=dict(
type=LocalRunner, max_num_workers=16, task=dict(type=OpenICLInferTask), retry=5
),
)
eval = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(type=LocalRunner, max_num_workers=32, task=dict(type=OpenICLEvalTask)),
)
summarizer = dict(
dataset_abbrs=['ruler_4k', 'ruler_8k', 'ruler_16k', 'ruler_32k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)


@ -0,0 +1,27 @@
default_ruler_tasks = [
'ruler_niah_single_1',
'ruler_niah_single_2',
'ruler_niah_single_3',
'ruler_niah_multikey_1',
'ruler_niah_multikey_2',
'ruler_niah_multikey_3',
'ruler_niah_multivalue',
'ruler_niah_multiquery',
'ruler_vt',
'ruler_fwe',
'ruler_cwe',
'ruler_qa_squad',
'ruler_qa_hotpotqa',
]
context_window_sizes = ['4k', '8k', '16k', '32k', '128k', '1m']
ruler_summary_groups = []
for context_window_size in context_window_sizes:
ruler_summary_groups.append(
{
'name': f'ruler_{context_window_size}',
'subsets': [
f'{task}_{context_window_size}' for task in default_ruler_tasks
],
}
)


@ -0,0 +1,65 @@
from mmengine.config import read_base
with read_base():
from .groups.ruler import ruler_summary_groups
ruler_4k_summarizer = dict(
dataset_abbrs=['ruler_4k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_4k_summarizer = dict(
dataset_abbrs=['ruler_4k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_8k_summarizer = dict(
dataset_abbrs=['ruler_8k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_16k_summarizer = dict(
dataset_abbrs=['ruler_16k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_32k_summarizer = dict(
dataset_abbrs=['ruler_32k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_128k_summarizer = dict(
dataset_abbrs=['ruler_128k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_1m_summarizer = dict(
dataset_abbrs=['ruler_1m'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_combined_summarizer = dict(
dataset_abbrs=[
'ruler_4k',
'ruler_8k',
'ruler_16k',
'ruler_32k',
'ruler_128k',
'ruler_1m',
],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)


@ -0,0 +1,14 @@
# Ruler
OpenCompass now supports [RULER](https://arxiv.org/pdf/2404.06654), a long-context language model evaluation benchmark. Through flexible configurations, RULER evaluates long-context capabilities including retrieval, multi-hop tracing, aggregation, and question answering.
OpenCompass provides two evaluation demos that use different tokenizers.
To evaluate all models with a single fixed tokenizer (typically GPT-4), follow the demo in `configs/eval_ruler_fix_tokenizer.py`, where most of the settings are already defined (a minimal sketch of such a config is shown after the commands below).
To evaluate with each model's own tokenizer, you have to build the settings yourself, since the demo cannot know in advance which model you want to evaluate: create a new evaluation script following the example in `configs/eval_ruler.py`, then change the context window sizes or add models to match your setup.
```bash
python run.py configs/eval_ruler_fix_tokenizer.py # For evaluation with GPT-4 tokenizer
python run.py configs/eval_ruler.py # For evaluation with model's tokenizer
```
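
For the fixed-tokenizer case, a minimal sketch is given below, following `configs/eval_ruler_fix_tokenizer.py` added in this PR. The imported model is only an example and can be swapped for any model config shipped with OpenCompass; the relative import paths assume the config sits in `configs/` like the original file.

```python
# Minimal sketch of a fixed-tokenizer (GPT-4 / tiktoken) RULER config, mirroring
# configs/eval_ruler_fix_tokenizer.py from this PR; swap the model import as needed.
from mmengine.config import read_base

with read_base():
    # All RULER subsets pre-built for 4k/8k/16k/32k/128k/1m with the GPT-4 tokenizer
    from .datasets.ruler.ruler_combined_gen import ruler_combined_datasets
    from ..configs.summarizers.groups.ruler import ruler_summary_groups
    # Example model; replace with your own model config
    from ..configs.models.hf_internlm.lmdeploy_internlm2_5_7b_chat_1m import (
        models as internlm2_5_7b_chat_1m,
    )

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
models = internlm2_5_7b_chat_1m
work_dir = './outputs/ruler'

summarizer = dict(
    dataset_abbrs=['ruler_4k', 'ruler_8k', 'ruler_16k', 'ruler_32k'],
    summary_groups=sum(
        [v for k, v in locals().items() if k.endswith('_summary_groups')], []
    ),
)
```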


@ -0,0 +1,28 @@
from mmengine.config import read_base
with read_base():
from .ruler_niah_gen import niah_datasets # Niah
from .ruler_vt_gen import vt_datasets # VT
from .ruler_fwe_gen import fwe_datasets # FWE
from .ruler_cwe_gen import cwe_datasets # CWE
from .ruler_qa_gen import qa_datasets # QA
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 100 # Change to the number of samples you need
# Change the context lengths to be tested
max_seq_lens = [1024 * 128]
abbr_suffixs = ['128k']
ruler_datasets = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
ruler_datasets.append(tmp_dataset)


@ -0,0 +1,29 @@
from mmengine.config import read_base
with read_base():
from .ruler_niah_gen import niah_datasets # Niah
from .ruler_vt_gen import vt_datasets # VT
from .ruler_fwe_gen import fwe_datasets # FWE
from .ruler_cwe_gen import cwe_datasets # CWE
from .ruler_qa_gen import qa_datasets # QA
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 100 # Change to the number of samples you need
# Change the context lengths to be tested
max_seq_lens = [1024 * 16]
abbr_suffixs = ['16k']
ruler_datasets = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
ruler_datasets.append(tmp_dataset)


@ -0,0 +1,29 @@
from mmengine.config import read_base
with read_base():
from .ruler_niah_gen import niah_datasets # Niah
from .ruler_vt_gen import vt_datasets # VT
from .ruler_fwe_gen import fwe_datasets # FWE
from .ruler_cwe_gen import cwe_datasets # CWE
from .ruler_qa_gen import qa_datasets # QA
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 100 # Change to the number of samples you need
# Change the context lengths to be tested
max_seq_lens = [1024 * 1024]
abbr_suffixs = ['1m']
ruler_datasets = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
ruler_datasets.append(tmp_dataset)


@ -0,0 +1,29 @@
from mmengine.config import read_base
with read_base():
from .ruler_niah_gen import niah_datasets # Niah
from .ruler_vt_gen import vt_datasets # VT
from .ruler_fwe_gen import fwe_datasets # FWE
from .ruler_cwe_gen import cwe_datasets # CWE
from .ruler_qa_gen import qa_datasets # QA
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 100 # Change to the number of samples you need
# Change the context lengths to be tested
max_seq_lens = [1024 * 32]
abbr_suffixs = ['32k']
ruler_datasets = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
ruler_datasets.append(tmp_dataset)


@ -0,0 +1,28 @@
from mmengine.config import read_base
with read_base():
from .ruler_niah_gen import niah_datasets # Niah
from .ruler_vt_gen import vt_datasets # VT
from .ruler_fwe_gen import fwe_datasets # FWE
from .ruler_cwe_gen import cwe_datasets # CWE
from .ruler_qa_gen import qa_datasets # QA
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 100 # Change to the number of samples you need
# Change the context lengths to be tested
max_seq_lens = [1024 * 4]
abbr_suffixs = ['4k']
ruler_datasets = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
ruler_datasets.append(tmp_dataset)


@ -0,0 +1,29 @@
from mmengine.config import read_base
with read_base():
from .ruler_niah_gen import niah_datasets # Niah
from .ruler_vt_gen import vt_datasets # VT
from .ruler_fwe_gen import fwe_datasets # FWE
from .ruler_cwe_gen import cwe_datasets # CWE
from .ruler_qa_gen import qa_datasets # QA
import_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# Evaluation config
NUM_SAMPLES = 100 # Change to the number of samples you need
# Change the context lengths to be tested
max_seq_lens = [1024 * 8]
abbr_suffixs = ['8k']
ruler_datasets = []
# Different seq length
for max_seq_len, abbr_suffix in zip(max_seq_lens, abbr_suffixs):
for dataset in import_datasets:
tmp_dataset = dataset.deepcopy()
tmp_dataset['abbr'] = tmp_dataset['abbr'] + '_' + abbr_suffix
tmp_dataset['num_samples'] = NUM_SAMPLES
tmp_dataset['max_seq_length'] = max_seq_len
ruler_datasets.append(tmp_dataset)


@ -0,0 +1,13 @@
from mmengine.config import read_base
with read_base():
from .ruler_4k_gen import ruler_datasets as ruler_4k_datasets
from .ruler_8k_gen import ruler_datasets as ruler_8k_datasets
from .ruler_16k_gen import ruler_datasets as ruler_16k_datasets
from .ruler_32k_gen import ruler_datasets as ruler_32k_datasets
from .ruler_128k_gen import ruler_datasets as ruler_128k_datasets
from .ruler_1m_gen import ruler_datasets as ruler_1m_datasets
ruler_combined_datasets = sum(
(v for k, v in locals().items() if k.endswith('_datasets')), []
)


@ -0,0 +1,34 @@
from opencompass.datasets.ruler.ruler_cwe import RulerCweDataset
from opencompass.datasets.ruler.ruler_cwe import RulerCweEvaluator
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
# CWE Dataset
cwe_datasets = [
{
'abbr': 'ruler_cwe',
'type': RulerCweDataset,
'freq_cw': 30,
'freq_ucw': 3,
'num_cw': 10,
'tokens_to_generate': 120,
'reader_cfg': dict(input_columns=['prompt'], output_column='answer'),
'infer_cfg': dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{prompt}'),
dict(role='BOT', prompt='{answer}\n'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
),
'eval_cfg': dict(
evaluator=dict(type=RulerCweEvaluator),
),
}
]


@ -0,0 +1,33 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.ruler.ruler_fwe import RulerFweDataset
from opencompass.datasets.ruler.ruler_fwe import RulerFweEvaluator
# FWE Dataset
fwe_datasets = [
{
'abbr': 'ruler_fwe',
'type': RulerFweDataset,
'tokens_to_generate': 50,
'alpha': 2.0,
'coded_wordlen': 6,
'reader_cfg': dict(input_columns=['prompt'], output_column='answer'),
'infer_cfg': dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{prompt}'),
dict(role='BOT', prompt='{answer}\n'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
),
'eval_cfg': dict(
evaluator=dict(type=RulerFweEvaluator),
),
}
]


@ -0,0 +1,123 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.ruler.ruler_niah import RulerNiahDataset
from opencompass.datasets.ruler.ruler_niah import RulerNiahEvaluator
# Ruler Dataset settings
niah_configurations = [
{
'abbr': 'single_1',
'type_haystack': 'repeat',
'type_needle_k': 'words',
'type_needle_v': 'numbers',
'num_needle_k': 1,
'num_needle_v': 1,
'num_needle_q': 1,
},
{
'abbr': 'single_2',
'type_haystack': 'essay',
'type_needle_k': 'words',
'type_needle_v': 'numbers',
'num_needle_k': 1,
'num_needle_v': 1,
'num_needle_q': 1,
},
{
'abbr': 'single_3',
'type_haystack': 'essay',
'type_needle_k': 'words',
'type_needle_v': 'uuids',
'num_needle_k': 1,
'num_needle_v': 1,
'num_needle_q': 1,
},
{
'abbr': 'multikey_1',
'type_haystack': 'essay',
'type_needle_k': 'words',
'type_needle_v': 'numbers',
'num_needle_k': 4,
'num_needle_v': 1,
'num_needle_q': 1,
},
{
'abbr': 'multikey_2',
'type_haystack': 'needle',
'type_needle_k': 'words',
'type_needle_v': 'numbers',
'num_needle_k': 1,
'num_needle_v': 1,
'num_needle_q': 1,
},
{
'abbr': 'multikey_3',
'type_haystack': 'needle',
'type_needle_k': 'uuids',
'type_needle_v': 'uuids',
'num_needle_k': 1,
'num_needle_v': 1,
'num_needle_q': 1,
},
{
'abbr': 'multivalue',
'type_haystack': 'essay',
'type_needle_k': 'words',
'type_needle_v': 'numbers',
'num_needle_k': 1,
'num_needle_v': 4,
'num_needle_q': 1,
},
{
'abbr': 'multiquery',
'type_haystack': 'essay',
'type_needle_k': 'words',
'type_needle_v': 'numbers',
'num_needle_k': 1,
'num_needle_v': 1,
'num_needle_q': 4,
},
]
niah_datasets = []
# NIAH Dataset
base_path = './data/ruler'
file_path = 'PaulGrahamEssays.jsonl'
for index, config in enumerate(niah_configurations):
dataset_dict = {
'abbr': f'ruler_niah_{config["abbr"]}',
'type': RulerNiahDataset,
'base_path': base_path,
'file_path': file_path,
# 'tokenizer_model': model_path,
'tokens_to_generate': 128,
# 'max_seq_length': max_seq_len,
# 'num_samples': NUM_SAMPLES,
'type_haystack': config['type_haystack'],
'type_needle_k': config['type_needle_k'],
'type_needle_v': config['type_needle_v'],
'num_needle_k': config['num_needle_k'],
'num_needle_v': config['num_needle_v'],
'num_needle_q': config['num_needle_q'],
'reader_cfg': dict(input_columns=['prompt'], output_column='answer'),
'infer_cfg': dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{prompt}'),
dict(role='BOT', prompt='{answer}\n'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
),
'eval_cfg': dict(
evaluator=dict(type=RulerNiahEvaluator),
),
}
niah_datasets.append(dataset_dict)


@ -0,0 +1,38 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.ruler.ruler_qa import RulerQaDataset
from opencompass.datasets.ruler.ruler_qa import RulerQaEvaluator
qa_configurations = [
{'dataset': 'squad', 'path': './data/ruler/dev-v2.0.json'},
{'dataset': 'hotpotqa', 'path': './data/ruler/hotpotqa.json'},
]
qa_datasets = []
for index, config in enumerate(qa_configurations):
dataset_dict = {
'abbr': f'ruler_qa_{config["dataset"]}',
'dataset': config['dataset'],
'path': config['path'],
'type': RulerQaDataset,
'tokens_to_generate': 50,
'reader_cfg': dict(input_columns=['prompt'], output_column='answer'),
'infer_cfg': dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{prompt}'),
dict(role='BOT', prompt='{answer}\n'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
),
'eval_cfg': dict(
evaluator=dict(type=RulerQaEvaluator),
),
}
qa_datasets.append(dataset_dict)


@ -0,0 +1,32 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.ruler.ruler_vt import RulerVtDataset
from opencompass.datasets.ruler.ruler_vt import RulerVtEvaluator
# VT Dataset
vt_datasets = [
{
'abbr': 'ruler_vt',
'type': RulerVtDataset,
'num_chains': 1,
'num_hops': 4,
'reader_cfg': dict(input_columns=['prompt'], output_column='answer'),
'infer_cfg': dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{prompt}'),
dict(role='BOT', prompt='{answer}\n'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
),
'eval_cfg': dict(
evaluator=dict(type=RulerVtEvaluator),
),
}
]


@ -0,0 +1,27 @@
default_ruler_tasks = [
'ruler_niah_single_1',
'ruler_niah_single_2',
'ruler_niah_single_3',
'ruler_niah_multikey_1',
'ruler_niah_multikey_2',
'ruler_niah_multikey_3',
'ruler_niah_multivalue',
'ruler_niah_multiquery',
'ruler_vt',
'ruler_fwe',
'ruler_cwe',
'ruler_qa_squad',
'ruler_qa_hotpotqa',
]
context_window_sizes = ['4k', '8k', '16k', '32k', '128k', '1m']
ruler_summary_groups = []
for context_window_size in context_window_sizes:
ruler_summary_groups.append(
{
'name': f'ruler_{context_window_size}',
'subsets': [
f'{task}_{context_window_size}' for task in default_ruler_tasks
],
}
)


@ -0,0 +1,65 @@
from mmengine.config import read_base
with read_base():
from .groups.ruler import ruler_summary_groups
ruler_4k_summarizer = dict(
dataset_abbrs=['ruler_4k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_4k_summarizer = dict(
dataset_abbrs=['ruler_4k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_8k_summarizer = dict(
dataset_abbrs=['ruler_8k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_16k_summarizer = dict(
dataset_abbrs=['ruler_16k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_32k_summarizer = dict(
dataset_abbrs=['ruler_32k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_128k_summarizer = dict(
dataset_abbrs=['ruler_128k'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_1m_summarizer = dict(
dataset_abbrs=['ruler_1m'],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)
ruler_combined_summarizer = dict(
dataset_abbrs=[
'ruler_4k',
'ruler_8k',
'ruler_16k',
'ruler_32k',
'ruler_128k',
'ruler_1m',
],
summary_groups=sum(
[v for k, v in locals().items() if k.endswith('_summary_groups')], []
),
)


@ -96,6 +96,7 @@ from .race import * # noqa: F401, F403
from .realtoxicprompts import * # noqa: F401, F403
from .reasonbench import ReasonBenchDataset # noqa: F401, F403
from .record import * # noqa: F401, F403
from .ruler import * # noqa: F401, F403
from .safety import * # noqa: F401, F403
from .scibench import ScibenchDataset, scibench_postprocess # noqa: F401, F403
from .siqa import * # noqa: F401, F403


@ -0,0 +1,171 @@
# flake8: noqa: F401, E501
import random
import numpy as np
import tiktoken
from datasets import Dataset
from transformers import AutoTokenizer
from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET
@LOAD_DATASET.register_module()
class RulerCweDataset(BaseDataset):
@staticmethod
def load(
max_seq_length: int = 4096,
tokenizer_model: str = 'gpt-4',
template:
str = 'Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list? Answer: The top 10 words that appear most often in the list are:',
tokens_to_generate: int = 120,
freq_cw: int = 30,
freq_ucw: int = 3,
num_cw: int = 10,
num_samples: int = 500,
random_seed: int = 42,
remove_newline_tab: str = '',
) -> Dataset:
if tokenizer_model == 'gpt-4':
tokenizer = tiktoken.encoding_for_model(tokenizer_model)
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model,
trust_remote_code=True)
random.seed(random_seed)
np.random.seed(random_seed)
try:
import wonderwords
except ImportError:
raise ImportError('''Please install wonderwords by:
pip install wonderwords''')
nouns = wonderwords.random_word._get_words_from_text_file(
'nounlist.txt')
adjs = wonderwords.random_word._get_words_from_text_file(
'adjectivelist.txt')
verbs = wonderwords.random_word._get_words_from_text_file(
'verblist.txt')
words = nouns + adjs + verbs
words = sorted(list(set(words)))
random.Random(random_seed).shuffle(words)
def _get_example(num_words,
common_repeats=30,
uncommon_repeats=3,
common_nums=10):
word_list_full = random.sample(words, num_words)
common, uncommon = (
word_list_full[:common_nums],
word_list_full[common_nums:],
)
word_list = common * int(common_repeats) + uncommon * int(
uncommon_repeats)
random.Random(random_seed).shuffle(word_list)
# Formatting the word list as "1. word1 2. word2 3. word3 ..."
context = ' '.join(
[f'{i + 1}. {word}' for i, word in enumerate(word_list)])
return context, common
def _generate_input_output(num_words):
if max_seq_length < 4096:
context_example, answer_example = _get_example(
20, 3, 1, num_cw)
context, answer = _get_example(num_words, 6, 1, num_cw)
else:
context_example, answer_example = _get_example(
40, 10, 3, num_cw)
context, answer = _get_example(num_words, freq_cw, freq_ucw,
num_cw)
input_example = template.format(
context=context_example,
query='',
) + ' '.join(
[f'{i + 1}. {word}' for i, word in enumerate(answer_example)])
input_text = template.format(
context=context,
query='',
)
return input_example + '\n' + input_text, answer
def _sys_word_pair_random(num_samples: int,
max_seq_length: int,
incremental: int = 10):
data = {'prompt': [], 'answer': []}
# Find the perfect num_words
num_words = incremental
total_tokens = 0
while total_tokens + tokens_to_generate < max_seq_length:
input_text, answer = _generate_input_output(num_words)
# Calculate the number of tokens in the example
total_tokens = len(
tokenizer.encode(input_text + ' ' + ' '.join(
[f'{i + 1}. {word}'
for i, word in enumerate(answer)])))
print(
f'Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Words: {num_words}'
)
if total_tokens + tokens_to_generate > max_seq_length:
num_words -= incremental
break
num_words += incremental
if num_words > len(words):
num_words = len(words)
break
print('num_words:', num_words)
# Generate samples
for index in range(num_samples):
used_words = num_words
while True:
try:
input_text, answer = _generate_input_output(used_words)
length = len(
tokenizer.encode(input_text)) + tokens_to_generate
assert (length <= max_seq_length
), f'{length} exceeds max_seq_length.'
break
except:
if used_words > incremental:
used_words -= incremental
if remove_newline_tab:
input_text = ' '.join(
input_text.replace('\n',
' ').replace('\t',
' ').strip().split())
data['prompt'].append(input_text)
data['answer'].append(answer)
return data
# Generate Data
data = _sys_word_pair_random(num_samples=num_samples,
max_seq_length=max_seq_length)
dataset = Dataset.from_dict(data)
return dataset
class RulerCweEvaluator(BaseEvaluator):
def score(self, predictions, gold):
score = (sum([
sum([1.0 if r.lower() in pred.lower() else 0.0
for r in ref]) / len(ref)
for pred, ref in zip(predictions, gold)
]) / len(predictions) * 100)
result = {'score': round(score, 2)}
return result


@ -0,0 +1,161 @@
# flake8: noqa: F401, E501
import random
import string
import numpy as np
import tiktoken
from datasets import Dataset
from scipy.special import zeta
from transformers import AutoTokenizer
from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET
@LOAD_DATASET.register_module()
class RulerFweDataset(BaseDataset):
@staticmethod
def load(
max_seq_length: int = 4096,
tokenizer_model: str = 'gpt-4',
template:
str = "Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text? Answer: According to the coded text above, the three most frequently appeared words are:",
tokens_to_generate: int = 50,
alpha: float = 2.0,
coded_wordlen: int = 6,
num_samples: int = 500,
random_seed: int = 42,
remove_newline_tab: str = '',
vocab_size: int = -1,
) -> Dataset:
if tokenizer_model == 'gpt-4':
tokenizer = tiktoken.encoding_for_model(tokenizer_model)
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model,
trust_remote_code=True)
random.seed(random_seed)
np.random.seed(random_seed)
def _generate_input_output(
max_len,
num_words=-1,
coded_wordlen=6,
vocab_size=2000,
incremental=10,
alpha=2.0,
):
# generate vocab
vocab = [
''.join(random.choices(string.ascii_lowercase,
k=coded_wordlen))
for _ in range(vocab_size)
]
while len(set(vocab)) < vocab_size:
vocab.append(''.join(
random.choices(string.ascii_lowercase, k=coded_wordlen)))
vocab = sorted(list(set(vocab)))
random.Random(random_seed).shuffle(vocab)
vocab[0] = '...' # treat the top ranked as noise
# sample words
def gen_text(num_words):
k = np.arange(1, len(vocab) + 1)
sampled_cnt = num_words * (k**-alpha) / zeta(alpha)
sampled_words = [
[w] * zi for w, zi in zip(vocab, sampled_cnt.astype(int))
]
sampled_words = [x for wlst in sampled_words for x in wlst]
random.Random(random_seed).shuffle(sampled_words)
return (
template.format(context=' '.join(sampled_words), query=''),
vocab[1:4],
)
if num_words > 0:
num_words = num_words
text, answer = gen_text(num_words)
while len(tokenizer.encode(text)) > max_len:
num_words -= incremental
text, answer = gen_text(num_words)
else:
num_words = max_len // coded_wordlen # init
text, answer = gen_text(num_words)
while len(tokenizer.encode(text)) < max_len:
num_words += incremental
text, answer = gen_text(num_words)
num_words -= incremental
text, answer = gen_text(num_words)
return text, answer, num_words
def _sys_kwext(
num_samples: int,
max_seq_length: int,
vocab_size: int = -1,
incremental: int = 10,
):
data = {'prompt': [], 'answer': []}
vocab_size = max_seq_length // 50 if vocab_size == -1 else vocab_size
# get number of words
input_max_len = max_seq_length
_, _, num_example_words = _generate_input_output(
input_max_len,
coded_wordlen=coded_wordlen,
vocab_size=vocab_size,
incremental=input_max_len // 32,
alpha=alpha,
)
print('num_example_words:', num_example_words)
# Generate samples
for index in range(num_samples):
# construct input
input_max_len = max_seq_length
input_text, answer, _ = _generate_input_output(
input_max_len,
num_words=num_example_words,
coded_wordlen=coded_wordlen,
vocab_size=vocab_size,
incremental=input_max_len // 32,
alpha=alpha,
)
length = len(tokenizer.encode(input_text)) + tokens_to_generate
if remove_newline_tab:
input_text = ' '.join(
input_text.replace('\n',
' ').replace('\t',
' ').strip().split())
data['prompt'].append(input_text)
data['answer'].append(answer)
return data
# Generate Data
data = _sys_kwext(
num_samples=num_samples,
max_seq_length=max_seq_length,
vocab_size=vocab_size,
incremental=10,
)
dataset = Dataset.from_dict(data)
return dataset
class RulerFweEvaluator(BaseEvaluator):
def score(self, predictions, gold):
score = (sum([
sum([1.0 if r.lower() in pred.lower() else 0.0
for r in ref]) / len(ref)
for pred, ref in zip(predictions, gold)
]) / len(predictions) * 100)
result = {'score': round(score, 2)}
return result


@ -0,0 +1,265 @@
# flake8: noqa: F401, E501
import json
import os
import random
import re
import uuid
from pathlib import Path
import numpy as np
import tiktoken
from datasets import Dataset
from transformers import AutoTokenizer
from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path
@LOAD_DATASET.register_module()
class RulerNiahDataset(BaseDataset):
@staticmethod
def load(
base_path: str,
file_path: str,
tokens_to_generate: int = 128,
max_seq_length: int = 4096,
tokenizer_model: str = 'gpt-4',
num_samples: int = 500,
random_seed: int = 42,
template:
str = 'Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text? The special magic {type_needle_v} for {query} mentioned in the provided text are',
num_needle_k: int = 1,
num_needle_v: int = 1,
num_needle_q: int = 1,
type_haystack: str = 'essay',
type_needle_k: str = 'words',
type_needle_v: str = 'numbers',
remove_newline_tab: str = '',
) -> Dataset:
data = {'prompt': [], 'answer': []}
if tokenizer_model == 'gpt-4':
tokenizer = tiktoken.encoding_for_model(tokenizer_model)
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model,
trust_remote_code=True)
random.seed(random_seed)
np.random.seed(random_seed)
num_needle_k = max(num_needle_k, num_needle_q)
# Define Needle/Haystack Format
needle = 'One of the special magic {type_needle_v} for {key} is: {value}.'
if type_haystack == 'essay':
essay = os.path.join(base_path, file_path)
essay = get_data_path(essay, local_mode=True)
# essay = json.load(open(essay))['text']
combined_essay = ''
with open(essay, 'r', encoding='utf-8') as f:
for line in f:
line_text = json.loads(line.strip()).get('text',
'').strip()
combined_essay += line_text + ' '
haystack = re.sub(r'\s+', ' ', combined_essay).split(' ')
elif type_haystack == 'repeat':
haystack = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.'
elif type_haystack == 'needle':
haystack = needle
else:
raise NotImplementedError(f'{type_haystack} is not implemented.')
try:
import wonderwords
except ImportError:
raise ImportError('''Please install wonderwords by:
pip install wonderwords''')
# Words
nouns = wonderwords.random_word._get_words_from_text_file(
'nounlist.txt')
adjs = wonderwords.random_word._get_words_from_text_file(
'adjectivelist.txt')
# verbs = wonderwords.random_word._get_words_from_text_file("verblist.txt")
words = [f'{adj}-{noun}' for adj in adjs for noun in nouns]
words = sorted(list(set(words)))
# Positions
DEPTHS = list(
np.round(np.linspace(0, 100, num=40, endpoint=True)).astype(int))
def _generate_random_number(num_digits=7):
lower_bound = 10**(num_digits - 1)
upper_bound = 10**num_digits - 1
return str(random.randint(lower_bound, upper_bound))
def _generate_random_word():
word = random.choice(words)
return word
def _generate_random_uuid():
return str(uuid.UUID(int=random.getrandbits(128), version=4))
def _generate_random(type_needle: str):
if type_needle == 'numbers':
return _generate_random_number()
elif type_needle == 'words':
return _generate_random_word()
elif type_needle == 'uuids':
return _generate_random_uuid()
else:
raise NotImplementedError(f'{type_needle} is not implemented.')
def _generate_input_output(num_haystack, type_needle_v, template):
keys, values, needles = [], [], []
for _ in range(num_needle_k):
keys.append(_generate_random(type_needle_k))
value = []
for _ in range(num_needle_v):
value.append(_generate_random(type_needle_v))
needles.append(
needle.format(
type_needle_v=type_needle_v,
key=keys[-1],
value=value[-1],
))
values.append(value)
random.Random(random_seed).shuffle(needles)
# Context
if type_haystack == 'essay':
text = ' '.join(haystack[:num_haystack])
# document_sents = sent_tokenize(text.strip())
document_sents = text.split('. ')
document_sents = [
sentence.strip() for sentence in document_sents if sentence
] # remove possible whitespace
insertion_positions = ([0] + sorted([
int(len(document_sents) * (depth / 100))
for depth in random.sample(DEPTHS, len(needles))
]) + [len(document_sents)])
document_sents_list = []
for i in range(1, len(insertion_positions)):
last_pos = insertion_positions[i - 1]
next_pos = insertion_positions[i]
document_sents_list.append(' '.join(
document_sents[last_pos:next_pos]))
if i - 1 < len(needles):
document_sents_list.append(needles[i - 1])
context = ' '.join(document_sents_list)
else:
if type_haystack == 'repeat':
sentences = [haystack] * num_haystack
elif type_haystack == 'needle':
sentences = [
haystack.format(
type_needle_v=type_needle_v,
key=_generate_random(type_needle_k),
value=_generate_random(type_needle_v),
) for _ in range(num_haystack)
]
indexes = sorted(random.sample(range(num_haystack),
len(needles)),
reverse=True)
for index, element in zip(indexes, needles):
sentences.insert(index, element)
context = '\n'.join(sentences)
## Query and Answer
indices = random.sample(range(num_needle_k), num_needle_q)
queries = [keys[i] for i in indices]
answers = [a for i in indices for a in values[i]]
query = (', '.join(queries[:-1]) + ', and ' +
queries[-1] if len(queries) > 1 else queries[0])
if num_needle_q * num_needle_v == 1:
template = template.replace('Some', 'A')
template = template.replace('are all', 'is')
template = template.replace('are', 'is')
template = template.replace('answers', 'answer')
type_needle_v = type_needle_v[:-1] # remove "s"
input_text = template.format(
type_needle_v=type_needle_v,
context=context,
query=query,
)
return input_text, answers
# Generate Samples
if type_haystack == 'essay':
incremental = 500
elif type_haystack == 'repeat':
incremental = 25
elif type_haystack == 'needle':
incremental = 25
if type_haystack != 'essay' and max_seq_length < 4096:
incremental = 5
num_haystack = incremental
print('Num haystack:', num_haystack)
total_tokens = 0 # Track the total tokens generated for the first example
while total_tokens + tokens_to_generate < max_seq_length:
input_text, answer = _generate_input_output(
num_haystack, type_needle_v, template)
# Calculate the number of tokens in the example
total_tokens = len(tokenizer.encode(input_text + ' '.join(answer)))
print(
f'Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Haystack: {num_haystack}'
)
if total_tokens + tokens_to_generate > max_seq_length:
num_haystack -= incremental
break
if type_haystack == 'essay' and num_haystack > len(haystack):
num_haystack = len(haystack)
break
num_haystack += incremental
# Generate samples
for index in range(num_samples):
used_haystack = num_haystack
while True:
try:
input_text, answer = _generate_input_output(
used_haystack, type_needle_v, template)
length = len(
tokenizer.encode(input_text)) + tokens_to_generate
assert length <= max_seq_length, f'{length} exceeds max_seq_length.'
break
except:
if used_haystack > incremental:
used_haystack -= incremental
if remove_newline_tab:
input_text = ' '.join(
input_text.replace('\n', ' ').replace('\t',
' ').strip().split())
data['prompt'].append(input_text)
data['answer'].append(answer)
dataset = Dataset.from_dict({
'prompt': data['prompt'],
'answer': data['answer'],
})
return dataset
class RulerNiahEvaluator(BaseEvaluator):
def score(self, predictions, gold):
score = (sum([
sum([1.0 if r.lower() in pred.lower() else 0.0
for r in ref]) / len(ref)
for pred, ref in zip(predictions, gold)
]) / len(predictions) * 100)
result = {'score': round(score, 2)}
return result
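For reference, a minimal sketch (with made-up predictions and references, not taken from the dataset) of how the recall-style scoring above behaves; each reference list contributes the fraction of its items found in the corresponding prediction:

# Toy inputs, purely illustrative.
predictions = ['The special magic numbers are 31337 and 42.', 'No idea.']
gold = [['31337', '42'], ['777']]

score = (sum([
    sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref)
    for pred, ref in zip(predictions, gold)
]) / len(predictions) * 100)
print(round(score, 2))  # 50.0 -- first sample scores 1.0, second 0.0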

View File

@ -0,0 +1,231 @@
# flake8: noqa: F401, E501
import json
import os
import random
import numpy as np
import requests
import tiktoken
from datasets import Dataset
from transformers import AutoTokenizer
from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path
@LOAD_DATASET.register_module()
class RulerQaDataset(BaseDataset):
@staticmethod
def load(
path: str,
dataset: str = 'squad',
max_seq_length: int = 4096,
tokenizer_model: str = 'gpt-4',
template:
str = 'Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query} Answer:',
tokens_to_generate: int = 32,
num_samples: int = 500,
pre_samples: int = 0,
random_seed: int = 42,
remove_newline_tab: str = '',
) -> Dataset:
if tokenizer_model == 'gpt-4':
tokenizer = tiktoken.encoding_for_model(tokenizer_model)
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model,
trust_remote_code=True)
random.seed(random_seed)
np.random.seed(random_seed)
# Read SQuAD QA dataset
def _read_squad(file):
file = get_data_path(file, local_mode=True)
with open(file) as f:
data = json.load(f)
total_docs = [
p['context'] for d in data['data'] for p in d['paragraphs']
]
total_docs = sorted(list(set(total_docs)))
total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}
total_qas = []
for d in data['data']:
more_docs = [
total_docs_dict[p['context']] for p in d['paragraphs']
]
for p in d['paragraphs']:
for qas in p['qas']:
if not qas['is_impossible']:
total_qas.append({
'query':
qas['question'],
'outputs': [a['text'] for a in qas['answers']],
'context': [total_docs_dict[p['context']]],
'more_context': [
idx for idx in more_docs
if idx != total_docs_dict[p['context']]
],
})
return total_qas, total_docs
# Read Hotpot QA dataset
def _read_hotpotqa(file_path):
# url = (
# 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json'
# )
# if not os.path.exists(os.path.dirname(file_path)):
# os.makedirs(os.path.dirname(file_path))
# if not os.path.exists(file_path):
# response = requests.get(url)
# if response.status_code == 200:
# with open(file_path, 'wb') as file:
# file.write(response.content)
# else:
# print(
# f'Failed to download file. Status code: {response.status_code}'
# )
# else:
# print('File already exists.')
file_path = get_data_path(file_path, local_mode=True)
with open(file_path) as f:
data = json.load(f)
total_docs = [
f"{t}\n{''.join(p)}" for d in data for t, p in d['context']
]
total_docs = sorted(list(set(total_docs)))
total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}
total_qas = []
for d in data:
total_qas.append({
'query':
d['question'],
'outputs': [d['answer']],
'context': [
total_docs_dict[f"{t}\n{''.join(p)}"]
for t, p in d['context']
],
})
return total_qas, total_docs
DOCUMENT_PROMPT = 'Document {i}:\n{document}'
if dataset == 'squad':
QAS, DOCS = _read_squad(path)
elif dataset == 'hotpotqa':
QAS, DOCS = _read_hotpotqa(path)
else:
raise NotImplementedError(f'{dataset} is not implemented.')
def generate_input_output(index, num_docs):
curr_q = QAS[index]['query']
curr_a = QAS[index]['outputs']
curr_docs = QAS[index]['context']
curr_more = QAS[index].get('more_context', [])
if num_docs < len(DOCS):
if (num_docs - len(curr_docs)) > len(curr_more):
addition_docs = [
i for i, d in enumerate(DOCS)
if i not in curr_docs + curr_more
]
all_docs = (curr_docs + curr_more + random.sample(
addition_docs,
max(0, num_docs - len(curr_docs) - len(curr_more)),
))
else:
all_docs = curr_docs + random.sample(
curr_more, num_docs - len(curr_docs))
all_docs = [DOCS[idx] for idx in all_docs]
else:
all_docs = DOCS
random.Random(random_seed).shuffle(all_docs)
context = '\n\n'.join([
DOCUMENT_PROMPT.format(i=i + 1, document=d)
for i, d in enumerate(all_docs)
])
input_text = template.format(context=context, query=curr_q)
return input_text, curr_a
def _generate_samples(num_samples: int,
max_seq_length: int,
incremental: int = 10):
data = {'prompt': [], 'answer': []}
# Find the perfect num_docs
num_docs = incremental
total_tokens = 0 # Track the total tokens generated for this example
while total_tokens + tokens_to_generate < max_seq_length:
input_text, answer = generate_input_output(0, num_docs)
# Calculate the number of tokens in the example
total_tokens = len(tokenizer.encode(input_text + f' {answer}'))
print(
f'Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Docs: {num_docs}'
)
if total_tokens + tokens_to_generate > max_seq_length:
num_docs -= incremental
break
num_docs += incremental
if num_docs > len(DOCS):
num_docs = len(DOCS)
break
print('Number of documents:', num_docs)
# Generate samples
for index in range(num_samples):
used_docs = num_docs
while True:
try:
input_text, answer = generate_input_output(
index + pre_samples, used_docs)
length = len(
tokenizer.encode(input_text)) + tokens_to_generate
assert (length <= max_seq_length
), f'{length} exceeds max_seq_length.'
break
                    except AssertionError:
                        # Prompt too long: drop documents and retry, or stop.
                        if used_docs > incremental:
                            used_docs -= incremental
                        else:
                            break
if remove_newline_tab:
input_text = ' '.join(
input_text.replace('\n',
' ').replace('\t',
' ').strip().split())
data['prompt'].append(input_text)
data['answer'].append(answer)
return data
# Generate Data
data = _generate_samples(num_samples=num_samples,
max_seq_length=max_seq_length)
dataset = Dataset.from_dict(data)
return dataset
class RulerQaEvaluator(BaseEvaluator):
def score(self, predictions, gold):
score = (sum([
max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref])
for pred, ref in zip(predictions, gold)
]) / len(predictions) * 100)
result = {'score': round(score, 2)}
return result
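A minimal, hypothetical invocation of the loader defined above, assuming the class is importable from the opencompass datasets package and that a local SQuAD-format JSON file is available (both the import path and the file path below are placeholders):

from opencompass.datasets.ruler.ruler_qa import RulerQaDataset  # import path assumed

ds = RulerQaDataset.load(
    path='./data/squad/dev-v2.0.json',  # placeholder path to a SQuAD-format file
    dataset='squad',
    max_seq_length=4096,
    tokenizer_model='gpt-4',
    num_samples=10,
)
print(ds[0]['prompt'][:200])
print(ds[0]['answer'])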

View File

@ -0,0 +1,193 @@
# flake8: noqa: F401, E501
import random
import string
import numpy as np
import tiktoken
from datasets import Dataset
from transformers import AutoTokenizer
from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET
@LOAD_DATASET.register_module()
class RulerVtDataset(BaseDataset):
@staticmethod
def load(
max_seq_length: int = 4096,
tokenizer_model: str = 'gpt-4',
template:
str = 'Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above. Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assigned the value {query}, they are: ',
tokens_to_generate: int = 30,
num_chains: int = 1,
num_hops: int = 4,
num_samples: int = 500,
random_seed: int = 42,
remove_newline_tab: str = '',
) -> Dataset:
if tokenizer_model == 'gpt-4':
tokenizer = tiktoken.encoding_for_model(tokenizer_model)
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model,
trust_remote_code=True)
random.seed(random_seed)
np.random.seed(random_seed)
def _generate_chains(num_chains, num_hops, is_icl=False):
vars_all = []
k = 5 if not is_icl else 3
num_hops = num_hops if not is_icl else min(10, num_hops)
vars_all = [
''.join(random.choices(string.ascii_uppercase, k=k)).upper()
for _ in range((num_hops + 1) * num_chains)
]
while len(set(vars_all)) < num_chains * (num_hops + 1):
vars_all.append(''.join(
random.choices(string.ascii_uppercase, k=k)).upper())
vars_ret = []
chains_ret = []
for i in range(0, len(vars_all), num_hops + 1):
this_vars = vars_all[i:i + num_hops + 1]
vars_ret.append(this_vars)
this_chain = [
f'VAR {this_vars[0]} = {np.random.randint(10000, 99999)}'
]
for j in range(num_hops):
this_chain.append(
f'VAR {this_vars[j+1]} = VAR {this_vars[j]} ')
chains_ret.append(this_chain)
return vars_ret, chains_ret
def _generate_input_output(num_noises,
num_chains,
num_hops,
is_icl=False):
vars, chains = _generate_chains(num_chains,
num_hops,
is_icl=is_icl)
noise = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n'
# Create a list of the repeated noise
sentences = [noise] * num_noises
if len(sentences) <= len(chains[0]):
sentences = [
n + '.' if len(n.strip()) > 0 else n for n in
[x for noise in sentences for x in noise.split('.')]
]
                if len(sentences) <= len(chains[0]):
                    print('Reducing chain length: not enough noise sentences')
                    chains = [chain[:len(sentences) - 1] for chain in chains]
# sample random positions to insert variable assignment
for chain_i in chains:
# sample random positions (sorted) to insert variable assignment
positions = list(
sorted(random.sample(range(len(sentences)), len(chain_i))))
for insert_pi, j in zip(positions, range(len(chain_i))):
sentences.insert(insert_pi + j, chain_i[j])
# Insert the passkey sentence at the random position
context = ' '.join(sentences)
context = context.replace('. \n', '.\n')
# if is_icl:
# # remove model template
# cutoff = template.index(template[:20])
# cutoff_ans = template.index(answer_prefix[:10])
# template = ' '.join(template[cutoff:cutoff_ans].split()[:-1]) + template[cutoff_ans:]
value = chains[0][0].split('=')[-1].strip()
input_text = template.format(context=context,
query=value,
num_v=num_hops + 1)
return input_text, vars[0]
def _sys_vartrack_w_noise_random(
num_samples: int,
max_seq_length: int,
incremental: int = 10,
num_chains: int = 1,
num_hops: int = 4,
):
data = {'prompt': [], 'answer': []}
# Find the perfect num_noises
num_noises = incremental
total_tokens = 0 # Track the total tokens generated for this example
example_tokens = 0
while total_tokens + tokens_to_generate + example_tokens < max_seq_length:
input_text, answer = _generate_input_output(
num_noises, num_chains, num_hops)
# Calculate the number of tokens in the example
total_tokens = len(tokenizer.encode(input_text + f' {answer}'))
print(
f'Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate + example_tokens} | Noises: {num_noises}'
)
if total_tokens + tokens_to_generate + example_tokens > max_seq_length:
num_noises -= incremental
break
num_noises += incremental
print('Num noises:', num_noises)
# Generate samples
for index in range(num_samples):
used_noises = num_noises
while True:
try:
                        input_text, answer = _generate_input_output(
                            used_noises, num_chains, num_hops)
length = (len(tokenizer.encode(input_text)) +
tokens_to_generate + example_tokens)
assert (length <= max_seq_length
), f'{length} exceeds max_seq_length.'
break
                    except AssertionError:
                        # Prompt too long: reduce the noise and retry, or stop.
                        if used_noises > incremental:
                            used_noises -= incremental
                        else:
                            break
if remove_newline_tab:
input_text = ' '.join(
input_text.replace('\n',
' ').replace('\t',
' ').strip().split())
data['prompt'].append(input_text)
data['answer'].append(answer)
return data
# Generate Data
data = _sys_vartrack_w_noise_random(
num_samples=num_samples,
max_seq_length=max_seq_length,
num_chains=num_chains,
num_hops=num_hops,
)
dataset = Dataset.from_dict(data)
return dataset
class RulerVtEvaluator(BaseEvaluator):
def score(self, predictions, gold):
score = (sum([
sum([1.0 if r.lower() in pred.lower() else 0.0
for r in ref]) / len(ref)
for pred, ref in zip(predictions, gold)
]) / len(predictions) * 100)
result = {'score': round(score, 2)}
return result
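A minimal sketch of loading the variable-tracking split defined above (the import path is assumed, and the parameter values are illustrative only):

from opencompass.datasets.ruler.ruler_vt import RulerVtDataset  # import path assumed

ds = RulerVtDataset.load(
    max_seq_length=4096,
    tokenizer_model='gpt-4',
    num_chains=1,
    num_hops=4,
    num_samples=5,
)
# Each prompt hides an assignment chain such as
#   VAR ABCDE = 12345 ... VAR FGHIJ = VAR ABCDE ... VAR KLMNO = VAR FGHIJ ...
# and the expected answer is the list of variable names along that chain.
print(ds[0]['answer'])  # e.g. ['ABCDE', 'FGHIJ', 'KLMNO', ...]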

View File

@ -55,17 +55,19 @@ def get_data_path(dataset_id: str, local_mode: bool = False):
def download_dataset(data_path, cache_dir, remove_finished=True):
    get_logger().info(f'{data_path} does not exist! '
                      'Starting to download the data automatically. '
                      'If you have downloaded the data before, '
                      'you can specify `COMPASS_DATA_CACHE` '
                      'to avoid downloading again.')
# Try to load from default cache folder
try_default_path = os.path.join(DEFAULT_DATA_FOLDER, data_path)
if os.path.exists(try_default_path):
get_logger().info(f"Try to load the data from {try_default_path}")
return try_default_path
    get_logger().info(f'{data_path} does not exist! '
                      'Starting to download the data automatically. '
                      'If you have downloaded the data before, '
                      'you can specify `COMPASS_DATA_CACHE` '
                      'to avoid downloading again.')
# Cannot find data from default cache folder, download data.
# Update DATASET_URL for internal dataset
try:
@ -76,7 +78,7 @@ def download_dataset(data_path, cache_dir, remove_finished=True):
with open(file_path, 'r') as f:
internal_datasets_info = json.load(f)
DATASETS_URL.update(internal_datasets_info)
get_logger().info("Load internal dataset from: {file_path}")
get_logger().info(f"Load internal dataset from: {file_path}")
except Exception as e: # noqa
pass
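The message above points users to the `COMPASS_DATA_CACHE` environment variable; a minimal sketch of pointing the resolver at an existing local copy so the automatic download can be skipped (the cache path below is a placeholder):

import os

# Placeholder directory that already contains the downloaded datasets.
os.environ['COMPASS_DATA_CACHE'] = '/path/to/compass_data'

from opencompass.utils import get_data_path

local_path = get_data_path('opencompass/mmlu')  # key from DATASETS_MAPPING below
print(local_path)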

View File

@ -1,357 +1,359 @@
DATASETS_MAPPING = {
# ADVGLUE Datasets
'opencompass/advglue-dev': {
'ms_id': None,
'hf_id': None,
'local': './data/adv_glue/dev_ann.json',
"opencompass/advglue-dev": {
"ms_id": None,
"hf_id": None,
"local": "./data/adv_glue/dev_ann.json",
},
# AGIEval Datasets
'opencompass/agieval': {
'ms_id': 'opencompass/agieval',
'hf_id': 'opencompass/agieval',
'local': './data/AGIEval/data/v1/',
"opencompass/agieval": {
"ms_id": "opencompass/agieval",
"hf_id": "opencompass/agieval",
"local": "./data/AGIEval/data/v1/",
},
# ARC Datasets(Test)
'opencompass/ai2_arc-test': {
'ms_id': 'opencompass/ai2_arc',
'hf_id': 'opencompass/ai2_arc',
'local': './data/ARC/ARC-c/ARC-Challenge-Test.jsonl',
"opencompass/ai2_arc-test": {
"ms_id": "opencompass/ai2_arc",
"hf_id": "opencompass/ai2_arc",
"local": "./data/ARC/ARC-c/ARC-Challenge-Test.jsonl",
},
'opencompass/ai2_arc-dev': {
'ms_id': 'opencompass/ai2_arc',
'hf_id': 'opencompass/ai2_arc',
'local': './data/ARC/ARC-c/ARC-Challenge-Dev.jsonl',
"opencompass/ai2_arc-dev": {
"ms_id": "opencompass/ai2_arc",
"hf_id": "opencompass/ai2_arc",
"local": "./data/ARC/ARC-c/ARC-Challenge-Dev.jsonl",
},
'opencompass/ai2_arc-easy-dev': {
'ms_id': 'opencompass/ai2_arc',
'hf_id': 'opencompass/ai2_arc',
'local': './data/ARC/ARC-e/ARC-Easy-Dev.jsonl',
"opencompass/ai2_arc-easy-dev": {
"ms_id": "opencompass/ai2_arc",
"hf_id": "opencompass/ai2_arc",
"local": "./data/ARC/ARC-e/ARC-Easy-Dev.jsonl",
},
# BBH
'opencompass/bbh': {
'ms_id': 'opencompass/bbh',
'hf_id': 'opencompass/bbh',
'local': './data/BBH/data',
"opencompass/bbh": {
"ms_id": "opencompass/bbh",
"hf_id": "opencompass/bbh",
"local": "./data/BBH/data",
},
# C-Eval
'opencompass/ceval-exam': {
'ms_id': 'opencompass/ceval-exam',
'hf_id': 'opencompass/ceval-exam',
'local': './data/ceval/formal_ceval',
"opencompass/ceval-exam": {
"ms_id": "opencompass/ceval-exam",
"hf_id": "opencompass/ceval-exam",
"local": "./data/ceval/formal_ceval",
},
# AFQMC
'opencompass/afqmc-dev': {
'ms_id': 'opencompass/afqmc',
'hf_id': 'opencompass/afqmc',
'local': './data/CLUE/AFQMC/dev.json',
"opencompass/afqmc-dev": {
"ms_id": "opencompass/afqmc",
"hf_id": "opencompass/afqmc",
"local": "./data/CLUE/AFQMC/dev.json",
},
# CMNLI
'opencompass/cmnli-dev': {
'ms_id': 'opencompass/cmnli',
'hf_id': 'opencompass/cmnli',
'local': './data/CLUE/cmnli/cmnli_public/dev.json',
"opencompass/cmnli-dev": {
"ms_id": "opencompass/cmnli",
"hf_id": "opencompass/cmnli",
"local": "./data/CLUE/cmnli/cmnli_public/dev.json",
},
# OCNLI
'opencompass/OCNLI-dev': {
'ms_id': 'opencompass/OCNLI',
'hf_id': 'opencompass/OCNLI',
'local': './data/CLUE/OCNLI/dev.json',
"opencompass/OCNLI-dev": {
"ms_id": "opencompass/OCNLI",
"hf_id": "opencompass/OCNLI",
"local": "./data/CLUE/OCNLI/dev.json",
},
# ChemBench
'opencompass/ChemBench': {
'ms_id': 'opencompass/ChemBench',
'hf_id': 'opencompass/ChemBench',
'local': './data/ChemBench/',
"opencompass/ChemBench": {
"ms_id": "opencompass/ChemBench",
"hf_id": "opencompass/ChemBench",
"local": "./data/ChemBench/",
},
# CMMLU
'opencompass/cmmlu': {
'ms_id': 'opencompass/cmmlu',
'hf_id': 'opencompass/cmmlu',
'local': './data/cmmlu/',
"opencompass/cmmlu": {
"ms_id": "opencompass/cmmlu",
"hf_id": "opencompass/cmmlu",
"local": "./data/cmmlu/",
},
# CommonsenseQA
'opencompass/commonsense_qa': {
'ms_id': 'opencompass/commonsense_qa',
'hf_id': 'opencompass/commonsense_qa',
'local': './data/commonsenseqa',
"opencompass/commonsense_qa": {
"ms_id": "opencompass/commonsense_qa",
"hf_id": "opencompass/commonsense_qa",
"local": "./data/commonsenseqa",
},
# CMRC
'opencompass/cmrc_dev': {
'ms_id': 'opencompass/cmrc_dev',
'hf_id': 'opencompass/cmrc_dev',
'local': './data/CLUE/CMRC/dev.json'
"opencompass/cmrc_dev": {
"ms_id": "opencompass/cmrc_dev",
"hf_id": "opencompass/cmrc_dev",
"local": "./data/CLUE/CMRC/dev.json",
},
# DRCD_dev
'opencompass/drcd_dev': {
'ms_id': 'opencompass/drcd_dev',
'hf_id': 'opencompass/drcd_dev',
'local': './data/CLUE/DRCD/dev.json'
"opencompass/drcd_dev": {
"ms_id": "opencompass/drcd_dev",
"hf_id": "opencompass/drcd_dev",
"local": "./data/CLUE/DRCD/dev.json",
},
# clozeTest_maxmin
'opencompass/clozeTest_maxmin': {
'ms_id': None,
'hf_id': None,
'local': './data/clozeTest-maxmin/python/clozeTest.json',
"opencompass/clozeTest_maxmin": {
"ms_id": None,
"hf_id": None,
"local": "./data/clozeTest-maxmin/python/clozeTest.json",
},
# clozeTest_maxmin
'opencompass/clozeTest_maxmin_answers': {
'ms_id': None,
'hf_id': None,
'local': './data/clozeTest-maxmin/python/answers.txt',
"opencompass/clozeTest_maxmin_answers": {
"ms_id": None,
"hf_id": None,
"local": "./data/clozeTest-maxmin/python/answers.txt",
},
# Flores
'opencompass/flores': {
'ms_id': 'opencompass/flores',
'hf_id': 'opencompass/flores',
'local': './data/flores_first100',
"opencompass/flores": {
"ms_id": "opencompass/flores",
"hf_id": "opencompass/flores",
"local": "./data/flores_first100",
},
# MBPP
'opencompass/mbpp': {
'ms_id': 'opencompass/mbpp',
'hf_id': 'opencompass/mbpp',
'local': './data/mbpp/mbpp.jsonl',
"opencompass/mbpp": {
"ms_id": "opencompass/mbpp",
"hf_id": "opencompass/mbpp",
"local": "./data/mbpp/mbpp.jsonl",
},
# 'opencompass/mbpp': {
# 'ms_id': 'opencompass/mbpp',
# 'hf_id': 'opencompass/mbpp',
# 'local': './data/mbpp/mbpp.jsonl',
# },
'opencompass/sanitized_mbpp': {
'ms_id': 'opencompass/mbpp',
'hf_id': 'opencompass/mbpp',
'local': './data/mbpp/sanitized-mbpp.jsonl',
"opencompass/sanitized_mbpp": {
"ms_id": "opencompass/mbpp",
"hf_id": "opencompass/mbpp",
"local": "./data/mbpp/sanitized-mbpp.jsonl",
},
# GSM
'opencompass/gsm8k': {
'ms_id': 'opencompass/gsm8k',
'hf_id': 'opencompass/gsm8k',
'local': './data/gsm8k/',
"opencompass/gsm8k": {
"ms_id": "opencompass/gsm8k",
"hf_id": "opencompass/gsm8k",
"local": "./data/gsm8k/",
},
# HellaSwag
'opencompass/hellaswag': {
'ms_id': 'opencompass/hellaswag',
'hf_id': 'opencompass/hellaswag',
'local': './data/hellaswag/hellaswag.jsonl',
"opencompass/hellaswag": {
"ms_id": "opencompass/hellaswag",
"hf_id": "opencompass/hellaswag",
"local": "./data/hellaswag/hellaswag.jsonl",
},
# HellaSwagICE
'opencompass/hellaswag_ice': {
'ms_id': 'opencompass/hellaswag',
'hf_id': 'opencompass/hellaswag',
'local': './data/hellaswag/',
"opencompass/hellaswag_ice": {
"ms_id": "opencompass/hellaswag",
"hf_id": "opencompass/hellaswag",
"local": "./data/hellaswag/",
},
# HumanEval
'opencompass/humaneval': {
'ms_id': 'opencompass/humaneval',
'hf_id': 'opencompass/humaneval',
'local': './data/humaneval/human-eval-v2-20210705.jsonl',
"opencompass/humaneval": {
"ms_id": "opencompass/humaneval",
"hf_id": "opencompass/humaneval",
"local": "./data/humaneval/human-eval-v2-20210705.jsonl",
},
# HumanEvalCN
'opencompass/humaneval_cn': {
'ms_id': 'opencompass/humaneval',
'hf_id': 'opencompass/humaneval',
'local': './data/humaneval_cn/human-eval-cn-v2-20210705.jsonl',
"opencompass/humaneval_cn": {
"ms_id": "opencompass/humaneval",
"hf_id": "opencompass/humaneval",
"local": "./data/humaneval_cn/human-eval-cn-v2-20210705.jsonl",
},
# Lambada
'opencompass/lambada': {
'ms_id': 'opencompass/lambada',
'hf_id': 'opencompass/lambada',
'local': './data/lambada/test.jsonl',
"opencompass/lambada": {
"ms_id": "opencompass/lambada",
"hf_id": "opencompass/lambada",
"local": "./data/lambada/test.jsonl",
},
# LCSTS
'opencompass/LCSTS': {
'ms_id': 'opencompass/LCSTS',
'hf_id': 'opencompass/LCSTS',
'local': './data/LCSTS',
"opencompass/LCSTS": {
"ms_id": "opencompass/LCSTS",
"hf_id": "opencompass/LCSTS",
"local": "./data/LCSTS",
},
# MATH
'opencompass/math': {
'ms_id': 'opencompass/math',
'hf_id': 'opencompass/math',
'local': './data/math/math.json',
"opencompass/math": {
"ms_id": "opencompass/math",
"hf_id": "opencompass/math",
"local": "./data/math/math.json",
},
# MMLU
'opencompass/mmlu': {
'ms_id': 'opencompass/mmlu',
'hf_id': 'opencompass/mmlu',
'local': './data/mmlu/',
"opencompass/mmlu": {
"ms_id": "opencompass/mmlu",
"hf_id": "opencompass/mmlu",
"local": "./data/mmlu/",
},
# NQ
'opencompass/natural_question': {
'ms_id': 'opencompass/natural_question',
'hf_id': 'opencompass/natural_question',
'local': './data/nq/',
"opencompass/natural_question": {
"ms_id": "opencompass/natural_question",
"hf_id": "opencompass/natural_question",
"local": "./data/nq/",
},
# OpenBook QA-test
'opencompass/openbookqa_test': {
'ms_id': 'opencompass/openbookqa',
'hf_id': 'opencompass/openbookqa',
'local': './data/openbookqa/Main/test.jsonl',
"opencompass/openbookqa_test": {
"ms_id": "opencompass/openbookqa",
"hf_id": "opencompass/openbookqa",
"local": "./data/openbookqa/Main/test.jsonl",
},
# OpenBook QA-fact
'opencompass/openbookqa_fact': {
'ms_id': 'opencompass/openbookqa',
'hf_id': 'opencompass/openbookqa',
'local': './data/openbookqa/Additional/test_complete.jsonl',
"opencompass/openbookqa_fact": {
"ms_id": "opencompass/openbookqa",
"hf_id": "opencompass/openbookqa",
"local": "./data/openbookqa/Additional/test_complete.jsonl",
},
# PIQA
'opencompass/piqa': {
'ms_id': 'opencompass/piqa',
'hf_id': 'opencompass/piqa',
'local': './data/piqa',
"opencompass/piqa": {
"ms_id": "opencompass/piqa",
"hf_id": "opencompass/piqa",
"local": "./data/piqa",
},
# RACE
'opencompass/race': {
'ms_id': 'opencompass/race',
'hf_id': 'opencompass/race',
'local': './data/race',
"opencompass/race": {
"ms_id": "opencompass/race",
"hf_id": "opencompass/race",
"local": "./data/race",
},
# SIQA
'opencompass/siqa': {
'ms_id': 'opencompass/siqa',
'hf_id': 'opencompass/siqa',
'local': './data/siqa',
"opencompass/siqa": {
"ms_id": "opencompass/siqa",
"hf_id": "opencompass/siqa",
"local": "./data/siqa",
},
# XStoryCloze
'opencompass/xstory_cloze': {
'ms_id': 'opencompass/xstory_cloze',
'hf_id': 'opencompass/xstory_cloze',
'local': './data/xstory_cloze',
"opencompass/xstory_cloze": {
"ms_id": "opencompass/xstory_cloze",
"hf_id": "opencompass/xstory_cloze",
"local": "./data/xstory_cloze",
},
# StrategyQA
'opencompass/strategy_qa': {
'ms_id': 'opencompass/strategy_qa',
'hf_id': 'opencompass/strategy_qa',
'local': './data/strategyqa/strategyQA_train.json',
"opencompass/strategy_qa": {
"ms_id": "opencompass/strategy_qa",
"hf_id": "opencompass/strategy_qa",
"local": "./data/strategyqa/strategyQA_train.json",
},
# SummEdits
'opencompass/summedits': {
'ms_id': 'opencompass/summedits',
'hf_id': 'opencompass/summedits',
'local': './data/summedits/summedits.jsonl',
"opencompass/summedits": {
"ms_id": "opencompass/summedits",
"hf_id": "opencompass/summedits",
"local": "./data/summedits/summedits.jsonl",
},
# TriviaQA
'opencompass/trivia_qa': {
'ms_id': 'opencompass/trivia_qa',
'hf_id': 'opencompass/trivia_qa',
'local': './data/triviaqa/',
"opencompass/trivia_qa": {
"ms_id": "opencompass/trivia_qa",
"hf_id": "opencompass/trivia_qa",
"local": "./data/triviaqa/",
},
# TydiQA
'opencompass/tydiqa': {
'ms_id': 'opencompass/tydiqa',
'hf_id': 'opencompass/tydiqa',
'local': './data/tydiqa/',
"opencompass/tydiqa": {
"ms_id": "opencompass/tydiqa",
"hf_id": "opencompass/tydiqa",
"local": "./data/tydiqa/",
},
# Winogrande
'opencompass/winogrande': {
'ms_id': 'opencompass/winogrande',
'hf_id': 'opencompass/winogrande',
'local': './data/winogrande/',
"opencompass/winogrande": {
"ms_id": "opencompass/winogrande",
"hf_id": "opencompass/winogrande",
"local": "./data/winogrande/",
},
# XSum
'opencompass/xsum': {
'ms_id': 'opencompass/xsum',
'hf_id': 'opencompass/xsum',
'local': './data/Xsum/dev.jsonl',
}
"opencompass/xsum": {
"ms_id": "opencompass/xsum",
"hf_id": "opencompass/xsum",
"local": "./data/Xsum/dev.jsonl",
},
}
DATASETS_URL = {
'/mmlu/': {
'url':
'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mmlu.zip',
'md5': '761310671509a239e41c4b717f7fab9c',
"/mmlu/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mmlu.zip",
"md5": "761310671509a239e41c4b717f7fab9c",
},
'/gpqa/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/gpqa.zip',
'md5': '2e9657959030a765916f1f2aca29140d'
"/gpqa/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/gpqa.zip",
"md5": "2e9657959030a765916f1f2aca29140d",
},
'/CHARM/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/CHARM.zip',
'md5': 'fdf51e955d1b8e0bb35bc1997eaf37cb'
"/CHARM/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/CHARM.zip",
"md5": "fdf51e955d1b8e0bb35bc1997eaf37cb",
},
'/ifeval/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/ifeval.zip',
'md5': '64d98b6f36b42e7390c9cef76cace75f'
"/ifeval/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/ifeval.zip",
"md5": "64d98b6f36b42e7390c9cef76cace75f",
},
'/mbpp/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mbpp.zip',
'md5': '777739c90f04bce44096a5bc96c8f9e5'
"/mbpp/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mbpp.zip",
"md5": "777739c90f04bce44096a5bc96c8f9e5",
},
'/cmmlu/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/cmmlu.zip',
'md5': 'a59f4003d6918509a719ce3bc2a5d5bc'
"/cmmlu/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/cmmlu.zip",
"md5": "a59f4003d6918509a719ce3bc2a5d5bc",
},
'/math/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/math.zip',
'md5': '8b1b897259684672055e6fd4fc07c808'
"/math/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/math.zip",
"md5": "8b1b897259684672055e6fd4fc07c808",
},
'/hellaswag/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/hellaswag.zip',
'md5': '2b700a02ffb58571c7df8d8d0619256f'
"/hellaswag/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/hellaswag.zip",
"md5": "2b700a02ffb58571c7df8d8d0619256f",
},
'/BBH/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/BBH.zip',
'md5': '60c49f9bef5148aa7e1941328e96a554'
"/BBH/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/BBH.zip",
"md5": "60c49f9bef5148aa7e1941328e96a554",
},
'/mmlu/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mmlu.zip',
'md5': '761310671509a239e41c4b717f7fab9c'
"/mmlu/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mmlu.zip",
"md5": "761310671509a239e41c4b717f7fab9c",
},
'/compass_arena/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/compass_arena.zip',
'md5': 'cd59b54a179d16f2a858b359b60588f6'
"/compass_arena/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/compass_arena.zip",
"md5": "cd59b54a179d16f2a858b359b60588f6",
},
'/TheoremQA/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/TheoremQA.zip',
'md5': 'f2793b07bc26510d507aa710d9bd8622'
"/TheoremQA/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/TheoremQA.zip",
"md5": "f2793b07bc26510d507aa710d9bd8622",
},
'/mathbench_v1/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mathbench_v1.zip',
'md5': '50257a910ca43d1f61a610a79fdb16b5'
"/mathbench_v1/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mathbench_v1.zip",
"md5": "50257a910ca43d1f61a610a79fdb16b5",
},
'/gsm8k/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/gsm8k.zip',
'md5': '901e5dc93a2889789a469da9850cdca8'
"/gsm8k/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/gsm8k.zip",
"md5": "901e5dc93a2889789a469da9850cdca8",
},
'/LCBench2023/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/LCBench2023.zip',
'md5': 'e1a38c94a42ad1809e9e0650476a9306'
"/LCBench2023/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/LCBench2023.zip",
"md5": "e1a38c94a42ad1809e9e0650476a9306",
},
'/humaneval/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/humaneval.zip',
'md5':'88b1b89dc47b7121c81da6bcd85a69c3'
"/humaneval/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/humaneval.zip",
"md5": "88b1b89dc47b7121c81da6bcd85a69c3",
},
'/drop_simple_eval/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/drop_simple_eval.zip',
'md5': 'c912afe5b4a63509851cf16e6b91830e'
"/drop_simple_eval/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/drop_simple_eval.zip",
"md5": "c912afe5b4a63509851cf16e6b91830e",
},
'subjective/alignment_bench/': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/alignment_bench.zip',
'md5': 'd8ae9a0398526479dbbcdb80fafabceb'
"subjective/alignment_bench/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/alignment_bench.zip",
"md5": "d8ae9a0398526479dbbcdb80fafabceb",
},
'subjective/alpaca_eval': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/alpaca_eval.zip',
'md5': 'd7399d63cb46c82f089447160ef49b6a'
"subjective/alpaca_eval": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/alpaca_eval.zip",
"md5": "d7399d63cb46c82f089447160ef49b6a",
},
'subjective/arena_hard': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/arena_hard.zip',
'md5': '02cd09a482cb0f0cd9d2c2afe7a1697f'
"subjective/arena_hard": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/arena_hard.zip",
"md5": "02cd09a482cb0f0cd9d2c2afe7a1697f",
},
'subjective/mtbench': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mtbench.zip',
'md5': 'd1afc0787aeac7f1f24872742e161069'
"subjective/mtbench": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mtbench.zip",
"md5": "d1afc0787aeac7f1f24872742e161069",
},
'subjective/fofo': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/fofo.zip',
'md5': '8a302712e425e27e4292a9369df5b9d3'
"subjective/fofo": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/fofo.zip",
"md5": "8a302712e425e27e4292a9369df5b9d3",
},
'subjective/mtbench101': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mtbench101.zip',
'md5': '5d80257bc9929ebe5cfbf6d11184b04c',
"subjective/mtbench101": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mtbench101.zip",
"md5": "5d80257bc9929ebe5cfbf6d11184b04c",
},
"subjective/WildBench": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/wildbench.zip",
"md5": "b06252857f1f8f44a17b1bfca4888ff4",
},
"/ruler/": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/ruler.zip",
"md5": "c60bdfff3d02358067104cc1dea7c0f7",
},
'subjective/WildBench': {
'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/wildbench.zip',
'md5': 'b06252857f1f8f44a17b1bfca4888ff4',
}
}
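The keys in DATASETS_URL, including the new "/ruler/" entry, look like path fragments; a minimal, hypothetical sketch of how such a mapping could be resolved by substring match (the lookup_url helper below is illustrative and not part of the codebase):

def lookup_url(data_path: str, datasets_url: dict) -> dict:
    """Return the first url/md5 record whose key occurs in the data path."""
    for key, record in datasets_url.items():
        if key in data_path:
            return record
    raise KeyError(f'No download record found for {data_path}')

# e.g. './data/ruler/hotpotqa.json' would match the '/ruler/' entry above.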

View File

@ -2,3 +2,4 @@ alpaca-eval==0.6
faiss_gpu==1.7.2
latex2sympy2
scikit-learn==1.5
wonderwords
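The new wonderwords requirement provides random English words; a minimal sketch of the kind of call the word-based needle generation presumably relies on (exact usage in the dataset code may differ):

from wonderwords import RandomWord

word = RandomWord().word()  # e.g. 'lantern'
print(word)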