diff --git a/.codespellrc b/.codespellrc
new file mode 100644
index 00000000..60e189e2
--- /dev/null
+++ b/.codespellrc
@@ -0,0 +1,5 @@
+[codespell]
+skip = *.ipynb
+count =
+quiet-level = 3
+ignore-words-list = nd, ans, ques
diff --git a/configs/datasets/ARC_e/ARC_e_ppl_e6b2c5.py b/configs/datasets/ARC_e/ARC_e_ppl_e6b2c5.py
new file mode 100644
index 00000000..87418cd3
--- /dev/null
+++ b/configs/datasets/ARC_e/ARC_e_ppl_e6b2c5.py
@@ -0,0 +1,33 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import PPLInferencer
+from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.datasets import ARCDataset
+
+ARC_e_reader_cfg = dict(
+    input_columns=['question', 'textA', 'textB', 'textC', 'textD'],
+    output_column='answerKey')
+
+ARC_e_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template={
+            "A": "Question: {question}\nAnswer: {textA}",
+            "B": "Question: {question}\nAnswer: {textB}",
+            "C": "Question: {question}\nAnswer: {textC}",
+            "D": "Question: {question}\nAnswer: {textD}"
+        }),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=PPLInferencer))
+
+ARC_e_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
+
+ARC_e_datasets = [
+    dict(
+        type=ARCDataset,
+        abbr='ARC-e',
+        path='./data/ARC/ARC-e/ARC-Easy-Dev.jsonl',
+        reader_cfg=ARC_e_reader_cfg,
+        infer_cfg=ARC_e_infer_cfg,
+        eval_cfg=ARC_e_eval_cfg)
+]
diff --git a/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl.py b/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl.py
new file mode 100644
index 00000000..bfee7b3c
--- /dev/null
+++ b/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .CLUE_afqmc_ppl_c83c36 import afqmc_datasets  # noqa: F401, F403
diff --git a/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl.py b/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl.py
new file mode 100644
index 00000000..22da8d3b
--- /dev/null
+++ b/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .FewCLUE_bustm_ppl_47f2ab import bustm_datasets  # noqa: F401, F403
diff --git a/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_gen.py b/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_gen.py
new file mode 100644
index 00000000..ec5c483d
--- /dev/null
+++ b/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_gen.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .FewCLUE_ocnli_fc_gen_bef37f import ocnli_fc_datasets  # noqa: F401, F403
diff --git a/configs/datasets/SuperGLUE_RTE/SuperGLUE_RTE_ppl_f28ad6.py b/configs/datasets/SuperGLUE_RTE/SuperGLUE_RTE_ppl_f28ad6.py
new file mode 100644
index 00000000..0ceb8371
--- /dev/null
+++ b/configs/datasets/SuperGLUE_RTE/SuperGLUE_RTE_ppl_f28ad6.py
@@ -0,0 +1,34 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import PPLInferencer
+from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.datasets import HFDataset
+
+RTE_reader_cfg = dict(
+    input_columns=['hypothesis', 'premise'],
+    output_column='label',
+    test_split='train')
+
+RTE_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template={
+            'entailment': '{premise}?entailment, {hypothesis}',
+            'not_entailment': '{premise}?not_entailment, {hypothesis}'
+        }),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=PPLInferencer))
+
+RTE_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
+
+RTE_datasets = [
+    dict(
+        type=HFDataset,
+        abbr='RTE',
+        path='json',
+        data_files='./data/SuperGLUE/RTE/val.jsonl',
+        split='train',
+        reader_cfg=RTE_reader_cfg,
+        infer_cfg=RTE_infer_cfg,
+        eval_cfg=RTE_eval_cfg)
+]
diff --git a/configs/datasets/SuperGLUE_WiC/SuperGLUE_WiC_ppl.py b/configs/datasets/SuperGLUE_WiC/SuperGLUE_WiC_ppl.py
new file mode 100644
index 00000000..95bb8c36
--- /dev/null
+++ b/configs/datasets/SuperGLUE_WiC/SuperGLUE_WiC_ppl.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .SuperGLUE_WiC_ppl_4118db import WiC_datasets  # noqa: F401, F403
diff --git a/configs/datasets/civilcomments/civilcomments_ppl.py b/configs/datasets/civilcomments/civilcomments_ppl.py
new file mode 100644
index 00000000..2b4c9b4c
--- /dev/null
+++ b/configs/datasets/civilcomments/civilcomments_ppl.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .civilcomments_ppl_e01497 import civilcomments_datasets  # noqa: F401, F403
diff --git a/configs/datasets/commonsenseqa/commonsenseqa_gen.py b/configs/datasets/commonsenseqa/commonsenseqa_gen.py
new file mode 100644
index 00000000..86964ccd
--- /dev/null
+++ b/configs/datasets/commonsenseqa/commonsenseqa_gen.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .commonsenseqa_gen_a58dbd import commonsenseqa_datasets  # noqa: F401, F403
diff --git a/configs/datasets/glm/chid.py b/configs/datasets/glm/chid.py
new file mode 100644
index 00000000..63bbaff1
--- /dev/null
+++ b/configs/datasets/glm/chid.py
@@ -0,0 +1,30 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import PPLInferencer
+from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.datasets import CHIDDataset
+
+chid_reader_cfg = dict(
+    input_columns=[f'content{i}' for i in range(7)], output_column='answer')
+
+chid_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template={answer: f"{{content{answer}}}"
+                  for answer in range(7)}),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=PPLInferencer))
+
+chid_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
+
+chid_datasets = [
+    dict(
+        type=CHIDDataset,
+        path='json',
+        abbr='chid',
+        data_files='./data/FewCLUE/chid/test_public.json',
+        split='train',
+        reader_cfg=chid_reader_cfg,
+        infer_cfg=chid_infer_cfg,
+        eval_cfg=chid_eval_cfg)
+]
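A quick note on the template comprehension in chid.py above: the doubled braces escape f-string interpolation, so each of the seven candidate slots maps to a plain placeholder string that PPLInferencer then scores. A standalone illustration (not part of the patch):

# The comprehension in chid_infer_cfg expands to
# {0: '{content0}', 1: '{content1}', ..., 6: '{content6}'}.
template = {answer: f"{{content{answer}}}" for answer in range(7)}
assert template[0] == '{content0}'
assert template[6] == '{content6}'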
diff --git a/configs/datasets/hellaswag/hellaswag_gen.py b/configs/datasets/hellaswag/hellaswag_gen.py
new file mode 100644
index 00000000..a0e437ee
--- /dev/null
+++ b/configs/datasets/hellaswag/hellaswag_gen.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .hellaswag_gen_cae9cb import hellaswag_datasets  # noqa: F401, F403
diff --git a/configs/datasets/humaneval/humaneval_gen_d428f1.py b/configs/datasets/humaneval/humaneval_gen_d428f1.py
new file mode 100644
index 00000000..0b71e8e8
--- /dev/null
+++ b/configs/datasets/humaneval/humaneval_gen_d428f1.py
@@ -0,0 +1,35 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import HFDataset, HumanEvaluator
+
+humaneval_reader_cfg = dict(
+    input_columns=['prompt'], output_column='task_id', train_split='test')
+
+# TODO: allow empty output-column
+humaneval_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(round=[
+            dict(
+                role='HUMAN',
+                prompt='Complete the following python code:\n{prompt}'),
+        ])),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer, max_out_len=512))
+
+humaneval_eval_cfg = dict(
+    evaluator=dict(type=HumanEvaluator),
+    pred_role='BOT',
+    k=[1, 10, 100],  # this parameter is specific to HumanEval
+    pred_postprocessor=dict(type='humaneval'),
+)
+
+humaneval_datasets = [
+    dict(
+        type=HFDataset,
+        path='openai_humaneval',
+        reader_cfg=humaneval_reader_cfg,
+        infer_cfg=humaneval_infer_cfg,
+        eval_cfg=humaneval_eval_cfg)
+]
diff --git a/configs/datasets/iwslt2017/iwslt2017_gen_95def3.py b/configs/datasets/iwslt2017/iwslt2017_gen_95def3.py
new file mode 100644
index 00000000..3b51fba9
--- /dev/null
+++ b/configs/datasets/iwslt2017/iwslt2017_gen_95def3.py
@@ -0,0 +1,31 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import BM25Retriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.openicl.icl_evaluator import BleuEvaluator
+from opencompass.datasets import IWSLT2017Dataset
+
+iwslt2017_reader_cfg = dict(
+    input_columns='en', output_column='de', train_split='validation')
+
+iwslt2017_infer_cfg = dict(
+    ice_template=dict(type='PromptTemplate',
+                      template='</E>{en} = {de}',
+                      ice_token='</E>'),
+    retriever=dict(type=BM25Retriever, ice_num=1),
+    inferencer=dict(type=GenInferencer))
+
+iwslt2017_eval_cfg = dict(
+    evaluator=dict(type=BleuEvaluator),
+    pred_role='BOT',
+    pred_postprocessor=dict(type='general_cn'),
+    dataset_postprocessor=dict(type='general_cn'))
+
+iwslt2017_datasets = [
+    dict(
+        type=IWSLT2017Dataset,
+        path='iwslt2017',
+        name='iwslt2017-en-de',
+        reader_cfg=iwslt2017_reader_cfg,
+        infer_cfg=iwslt2017_infer_cfg,
+        eval_cfg=iwslt2017_eval_cfg)
+]
\ No newline at end of file
diff --git a/configs/datasets/mbpp/mbpp_gen.py b/configs/datasets/mbpp/mbpp_gen.py
new file mode 100644
index 00000000..9398c835
--- /dev/null
+++ b/configs/datasets/mbpp/mbpp_gen.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .mbpp_gen_4104e4 import mbpp_datasets  # noqa: F401, F403
diff --git a/configs/datasets/narrativeqa/narrativeqa_gen_5786a7.py b/configs/datasets/narrativeqa/narrativeqa_gen_5786a7.py
new file mode 100644
index 00000000..41c1b17d
--- /dev/null
+++ b/configs/datasets/narrativeqa/narrativeqa_gen_5786a7.py
@@ -0,0 +1,37 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import NarrativeQADataset, TriviaQAEvaluator
+
+narrativeqa_reader_cfg = dict(
+    input_columns=['question', 'evidence'],
+    output_column='answer',
+    train_split='valid',
+    test_split='valid')
+
+narrativeqa_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(
+            round=[
+                dict(
+                    role='HUMAN',
+                    prompt='{evidence}\nAnswer these questions:\nQ: {question}?A:'),
+                dict(role='BOT', prompt=''),
+            ], )),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(
+        type=GenInferencer, max_out_len=50, max_seq_len=8192, batch_size=4))
+
+narrativeqa_eval_cfg = dict(
+    evaluator=dict(type=TriviaQAEvaluator), pred_role='BOT')
+
+narrativeqa_datasets = [
+    dict(
+        type=NarrativeQADataset,
+        abbr='NarrativeQA',
+        path='./data/narrativeqa/',
+        reader_cfg=narrativeqa_reader_cfg,
+        infer_cfg=narrativeqa_infer_cfg,
+        eval_cfg=narrativeqa_eval_cfg)
+]
diff --git a/configs/datasets/qabench/qabench_gen_0d5967.py b/configs/datasets/qabench/qabench_gen_0d5967.py
new file mode 100644
index 00000000..d335e5d9
--- /dev/null
+++ b/configs/datasets/qabench/qabench_gen_0d5967.py
@@ -0,0 +1,29 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import HFDataset
+
+qabench_reader_cfg = dict(
+    input_columns=['prompt'],
+    output_column='reference',
+)
+
+# TODO: allow empty output-column
+qabench_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(round=[dict(role="HUMAN", prompt="{prompt}")])),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer))
+
+qabench_datasets = [
+    dict(
+        type=HFDataset,
+        path='csv',
+        data_files='./data/qabench/qabench-test.qa.csv',
+        abbr="qabench",
+        split='train',
+        reader_cfg=qabench_reader_cfg,
+        infer_cfg=qabench_infer_cfg,
+        eval_cfg=dict(ds_column="reference"))
+]
diff --git a/configs/datasets/qaspercut/qaspercut_gen.py b/configs/datasets/qaspercut/qaspercut_gen.py
new file mode 100644
index 00000000..cb16e9fd
--- /dev/null
+++ b/configs/datasets/qaspercut/qaspercut_gen.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .qaspercut_gen_943606 import qaspercut_datasets  # noqa: F401, F403
diff --git a/configs/datasets/strategyqa/strategyqa_gen.py b/configs/datasets/strategyqa/strategyqa_gen.py
new file mode 100644
index 00000000..f23e1741
--- /dev/null
+++ b/configs/datasets/strategyqa/strategyqa_gen.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .strategyqa_gen_be3f8d import strategyqa_datasets  # noqa: F401, F403
diff --git a/configs/datasets/triviaqarc/triviaqarc_gen.py b/configs/datasets/triviaqarc/triviaqarc_gen.py
new file mode 100644
index 00000000..66c346d5
--- /dev/null
+++ b/configs/datasets/triviaqarc/triviaqarc_gen.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .triviaqarc_gen_6c1726 import triviaqarc_datasets  # noqa: F401, F403
diff --git a/configs/datasets/winogrande/winogrande_ppl_00f8ad.py b/configs/datasets/winogrande/winogrande_ppl_00f8ad.py
new file mode 100644
index 00000000..500166c4
--- /dev/null
+++ b/configs/datasets/winogrande/winogrande_ppl_00f8ad.py
@@ -0,0 +1,36 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import PPLInferencer
+from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.datasets import winograndeDataset
+
+winogrande_reader_cfg = dict(
+    input_columns=['opt1', 'opt2'],
+    output_column='answer',
+    train_split='validation',
+    test_split='validation')
+
+winogrande_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template={
+            i: dict(round=[
+                dict(role="HUMAN", prompt=f"Good sentence: {{opt{i+1}}}"),
+            ])
+            for i in range(2)
+        }),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=PPLInferencer))
+
+winogrande_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
+
+winogrande_datasets = [
+    dict(
+        abbr='winogrande',
+        type=winograndeDataset,
+        path='winogrande',
+        name='winogrande_xs',
+        reader_cfg=winogrande_reader_cfg,
+        infer_cfg=winogrande_infer_cfg,
+        eval_cfg=winogrande_eval_cfg)
+]
diff --git a/configs/models/classic/hf_llama.py b/configs/models/classic/hf_llama.py
new file mode 100644
index 00000000..64bcec55
--- /dev/null
+++ b/configs/models/classic/hf_llama.py
@@ -0,0 +1,22 @@
+from opencompass.models import HuggingFaceCausalLM
+
+
+models = [
+    # LLaMA 7B
+    dict(
+        type=HuggingFaceCausalLM,
+        path="decapoda-research/llama-7b-hf",
+        tokenizer_path='decapoda-research/llama-7b-hf',
+        tokenizer_kwargs=dict(padding_side='left',
+                              truncation_side='left',
+                              use_fast=False,
+                              ),
+        max_out_len=100,
+        max_seq_len=2048,
+        batch_size=8,
+        model_kwargs=dict(device_map='auto'),
+        batch_padding=False,  # if False, infer samples one by one in a for-loop instead of padding a batch
+        run_cfg=dict(num_gpus=2, num_procs=1),
+    )
+
+]
diff --git a/configs/summarizers/groups/flores.py b/configs/summarizers/groups/flores.py
new file mode 100644
index 00000000..42514afb
--- /dev/null
+++ b/configs/summarizers/groups/flores.py
@@ -0,0 +1,25 @@
+flores_summary_groups = []
+
+_flores_lang_map = {
+    'Indo-European-Germanic': ['afr', 'dan', 'deu', 'isl', 'ltz', 'nld', 'nob', 'swe'],
+    'Indo-European-Romance': ['ast', 'cat', 'fra', 'glg', 'oci', 'por', 'ron', 'spa'],
+    'Indo-European-Slavic': ['bel', 'bos', 'bul', 'ces', 'hrv', 'mkd', 'pol', 'rus', 'slk', 'slv', 'srp', 'ukr'],
+    'Indo-European-Indo-Aryan': ['asm', 'ben', 'guj', 'hin', 'mar', 'npi', 'ory', 'pan', 'snd', 'urd'],
+    'Indo-European-Other': ['ckb', 'cym', 'ell', 'fas', 'gle', 'hye', 'ita', 'lav', 'lit', 'pus', 'tgk'],
+    'Austronesian': ['ceb', 'ind', 'jav', 'mri', 'msa', 'tgl'],
+    'Atlantic-Congo': ['ibo', 'kam', 'kea', 'lin', 'lug', 'nso', 'nya', 'sna', 'swh', 'umb', 'wol', 'xho', 'yor', 'zul'],
+    'Afro-Asiatic': ['amh', 'ara', 'ful', 'mlt', 'orm', 'som'],
+    'Turkic': ['azj', 'kaz', 'kir', 'tur', 'uzb'],
+    'Dravidian': ['kan', 'mal', 'tam', 'tel'],
+    'Sino-Tibetan': ['mya', 'zho_simpl', 'zho_trad'],
+    'Other': ['est', 'fin', 'hau', 'heb', 'hun', 'jpn', 'kat', 'khm', 'kor', 'lao', 'luo', 'mon', 'tha', 'vie'],
+}
+for _lang_serie in _flores_lang_map:
+    flores_summary_groups.append({
+        'name': f'flores_100_{_lang_serie}_English',
+        'subsets': [f'flores_100_{lang_name}-eng' for lang_name in _flores_lang_map[_lang_serie]]
+    })
+    flores_summary_groups.append({
+        'name': f'flores_100_English_{_lang_serie}',
+        'subsets': [f'flores_100_eng-{lang_name}' for lang_name in _flores_lang_map[_lang_serie]]
+    })
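The loop in flores.py above emits two summary groups per language family, one for each translation direction. A standalone illustration of what it produces (not part of the patch; relies on Python dicts preserving insertion order):

# The first family in _flores_lang_map is 'Indo-European-Germanic', so the
# first group averages that family's *-to-English subsets, and its twin
# covers the English-to-* direction.
assert flores_summary_groups[0]['name'] == 'flores_100_Indo-European-Germanic_English'
assert 'flores_100_afr-eng' in flores_summary_groups[0]['subsets']
assert flores_summary_groups[1]['name'] == 'flores_100_English_Indo-European-Germanic'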

diff --git a/docs/en/prompt/meta_template.md b/docs/en/prompt/meta_template.md
new file mode 100644
index 00000000..ae0a50d2
--- /dev/null
+++ b/docs/en/prompt/meta_template.md
@@ -0,0 +1 @@
+# Meta-Prompt
\ No newline at end of file
diff --git a/docs/en/user_guides/config.md b/docs/en/user_guides/config.md
new file mode 100644
index 00000000..0ffd292c
--- /dev/null
+++ b/docs/en/user_guides/config.md
@@ -0,0 +1,2 @@
+# Learn About Config
+
diff --git a/docs/zh_cn/_templates/404.html b/docs/zh_cn/_templates/404.html
new file mode 100644
index 00000000..64910175
--- /dev/null
+++ b/docs/zh_cn/_templates/404.html
@@ -0,0 +1,18 @@
+{% extends "layout.html" %}
+
+{% block body %}
+
+<h1>Page Not Found</h1>
+<p>
+  The page you are looking for cannot be found.
+</p>
+<p>
+  If you just switched documentation versions, the page you were on has likely been moved. You can look for it in
+  the table of contents on the left, or go to the homepage.
+</p>
+
+
+{% endblock %}
diff --git a/docs/zh_cn/prompt/prompt_template.md b/docs/zh_cn/prompt/prompt_template.md
new file mode 100644
index 00000000..bd8af7dc
--- /dev/null
+++ b/docs/zh_cn/prompt/prompt_template.md
@@ -0,0 +1 @@
+# Prompt Template
\ No newline at end of file
diff --git a/docs/zh_cn/user_guides/models.md b/docs/zh_cn/user_guides/models.md
new file mode 100644
index 00000000..d24ab7bc
--- /dev/null
+++ b/docs/zh_cn/user_guides/models.md
@@ -0,0 +1 @@
+# Prepare Models
diff --git a/opencompass/datasets/cmnli.py b/opencompass/datasets/cmnli.py
new file mode 100644
index 00000000..9cd9243c
--- /dev/null
+++ b/opencompass/datasets/cmnli.py
@@ -0,0 +1,27 @@
+import json
+
+from datasets import Dataset
+
+from opencompass.registry import LOAD_DATASET
+
+from .base import BaseDataset
+
+
+@LOAD_DATASET.register_module()
+class cmnliDataset_V2(BaseDataset):
+
+    @staticmethod
+    def load(path):
+        data = []
+        with open(path, 'r') as f:
+            for line in f:
+                line = json.loads(line)
+                if line['label'] == '-':
+                    continue
+                line['label'] = {
+                    'entailment': 'A',
+                    'contradiction': 'B',
+                    'neutral': 'C',
+                }[line['label']]
+                data.append(line)
+        return Dataset.from_list(data)
diff --git a/opencompass/datasets/xlsum.py b/opencompass/datasets/xlsum.py
new file mode 100644
index 00000000..6830f22b
--- /dev/null
+++ b/opencompass/datasets/xlsum.py
@@ -0,0 +1,33 @@
+from datasets import concatenate_datasets, load_dataset
+
+from opencompass.registry import LOAD_DATASET
+
+from .base import BaseDataset
+
+
+@LOAD_DATASET.register_module()
+class XLSUMDataset(BaseDataset):
+
+    @staticmethod
+    def load(**kwargs):
+        path = kwargs.get('path', None)
+        lans = [
+            'oromo', 'french', 'amharic', 'arabic', 'azerbaijani', 'bengali',
+            'burmese', 'chinese_simplified', 'chinese_traditional', 'welsh',
+            'english', 'kirundi', 'gujarati', 'hausa', 'hindi', 'igbo',
+            'indonesian', 'japanese', 'korean', 'kyrgyz', 'marathi', 'spanish',
+            'scottish_gaelic', 'nepali', 'pashto', 'persian', 'pidgin',
+            'portuguese', 'punjabi', 'russian', 'serbian_cyrillic',
+            'serbian_latin', 'sinhala', 'somali', 'swahili', 'tamil', 'telugu',
+            'thai', 'tigrinya', 'turkish', 'ukrainian', 'urdu', 'uzbek',
+            'vietnamese', 'yoruba'
+        ]
+
+        datasets = []
+        for lan in lans:
+            dataset = load_dataset(path, lan)['validation']
+            datasets.append(dataset)
+
+        combined_dataset = concatenate_datasets(datasets)
+
+        return combined_dataset
diff --git a/opencompass/datasets/xsum.py b/opencompass/datasets/xsum.py
new file mode 100644
index 00000000..4ece9132
--- /dev/null
+++ b/opencompass/datasets/xsum.py
@@ -0,0 +1,36 @@
+import json
+
+from datasets import Dataset
+
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
+
+from .base import BaseDataset
+
+
+@LOAD_DATASET.register_module()
+class XsumDataset(BaseDataset):
+
+    @staticmethod
+    def load(path: str):
+        with open(path, 'r', errors='ignore') as in_f:
+            rows = []
+            for i, line in enumerate(in_f):
+                if i == 1000:
+                    break
+                sample = json.loads(line.strip())
+                dialogue = sample['dialogue']
+                summary = sample['summary']
+                if isinstance(dialogue, float) or isinstance(summary, float):
+                    continue
+                rows.append({'dialogue': dialogue, 'summary': summary})
+            dataset = Dataset.from_dict({
+                'dialogue': [row['dialogue'] for row in rows],
+                'summary': [row['summary'] for row in rows]
+            })
+            return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module('Xsum')
+def Xsum_postprocess(text: str) -> str:
+    text = text.strip().split('\n')[0].strip()
+    return text
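A quick sanity check of the postprocessor above, for reference (hypothetical snippet; the import assumes the module layout introduced in this diff):

from opencompass.datasets.xsum import Xsum_postprocess

# Keeps only the first line of a generated summary, trimmed of whitespace.
assert Xsum_postprocess('  A one-line summary.\nExtra text.  ') == 'A one-line summary.'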
diff --git a/opencompass/openicl/icl_evaluator/icl_base_evaluator.py b/opencompass/openicl/icl_evaluator/icl_base_evaluator.py
new file mode 100644
index 00000000..14fa8a20
--- /dev/null
+++ b/opencompass/openicl/icl_evaluator/icl_base_evaluator.py
@@ -0,0 +1,10 @@
+"""Base Evaluator."""
+from typing import List
+
+
+class BaseEvaluator:
+    def __init__(self) -> None:
+        pass
+
+    def score(self):
+        raise NotImplementedError("Method hasn't been implemented yet")
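For orientation, concrete evaluators (such as the AccEvaluator referenced throughout the configs above) are expected to subclass BaseEvaluator and override score. A minimal hypothetical sketch of that pattern, not code from this diff:

from typing import List

from opencompass.openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator


class ExactMatchEvaluator(BaseEvaluator):
    """Toy evaluator: percentage of predictions that exactly match their reference."""

    def score(self, predictions: List[str], references: List[str]) -> dict:
        assert len(predictions) == len(references), 'pred/ref length mismatch'
        correct = sum(pred.strip() == ref.strip()
                      for pred, ref in zip(predictions, references))
        return {'accuracy': 100 * correct / max(len(references), 1)}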