"""WMT19 translation benchmark configs for OpenCompass.

Builds 0-shot and 5-shot (BM25-retrieved) generation configs for nine
WMT19 language pairs, evaluated with BLEU.
"""
import copy

from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import BM25Retriever, ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import BleuEvaluator
from opencompass.datasets.wmt19 import WMT19TranslationDataset

# ISO 639-1 code -> English language name, used to render prompt text.
LANG_CODE_TO_NAME = {
    'cs': 'Czech',
    'de': 'German',
    'en': 'English',
    'fi': 'Finnish',
    'fr': 'French',
    'gu': 'Gujarati',
    'kk': 'Kazakh',
    'lt': 'Lithuanian',
    'ru': 'Russian',
    'zh': 'Chinese',
}

# Only a validation split exists in this parquet export, so it serves as
# both the few-shot retrieval pool and the test set.
wmt19_reader_cfg = dict(
    input_columns=['input'],
    output_column='target',
    train_split='validation',
    test_split='validation')

# Base 0-shot config. {src_lang_name}/{tgt_lang_name} are filled per pair
# below via str.format(); the doubled braces in {{input}} survive format()
# as {input}, which the OpenCompass templater then substitutes at run time.
wmt19_infer_cfg_0shot = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[
            dict(role='HUMAN',
                 prompt='Translate the following {src_lang_name} text to '
                        '{tgt_lang_name}:\n{{input}}\n'),
            dict(role='BOT', prompt='Translation:\n'),
        ])),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer))

# Base 5-shot config. NOTE(review): the original had ice_token='' and no
# token in the template, so retrieved examples could never be inserted;
# restored the conventional OpenCompass '</E>' marker — confirm against
# other *_gen configs in this repo.
wmt19_infer_cfg_5shot = dict(
    ice_template=dict(
        type=PromptTemplate,
        template='Example:\n{src_lang_name}: {{input}}\n'
                 '{tgt_lang_name}: {{target}}'),
    prompt_template=dict(
        type=PromptTemplate,
        template='</E>\nTranslate the following {src_lang_name} text to '
                 '{tgt_lang_name}:\n{{input}}\nTranslation:\n',
        ice_token='</E>'),
    retriever=dict(type=BM25Retriever, ice_num=5),
    inferencer=dict(type=GenInferencer))

wmt19_eval_cfg = dict(
    evaluator=dict(type=BleuEvaluator),
    pred_role='BOT')

# (source, target) code pairs shipped with WMT19.
language_pairs = [
    ('cs', 'en'), ('de', 'en'), ('fi', 'en'), ('fr', 'de'),
    ('gu', 'en'), ('kk', 'en'), ('lt', 'en'), ('ru', 'en'), ('zh', 'en'),
]


def _fill_names(template_str, src_name, tgt_name):
    """Substitute language names, leaving {{...}} run-time slots intact."""
    return template_str.format(src_lang_name=src_name, tgt_lang_name=tgt_name)


def _infer_cfg_0shot(src_name, tgt_name):
    """Return a per-pair copy of the 0-shot config with names filled in."""
    cfg = copy.deepcopy(wmt19_infer_cfg_0shot)
    human_turn = cfg['prompt_template']['template']['round'][0]
    human_turn['prompt'] = _fill_names(human_turn['prompt'], src_name, tgt_name)
    return cfg


def _infer_cfg_5shot(src_name, tgt_name):
    """Return a per-pair copy of the 5-shot config with names filled in."""
    cfg = copy.deepcopy(wmt19_infer_cfg_5shot)
    cfg['ice_template']['template'] = _fill_names(
        cfg['ice_template']['template'], src_name, tgt_name)
    cfg['prompt_template']['template'] = _fill_names(
        cfg['prompt_template']['template'], src_name, tgt_name)
    return cfg


wmt19_datasets = []

for src_lang, tgt_lang in language_pairs:
    src_name = LANG_CODE_TO_NAME[src_lang]
    tgt_name = LANG_CODE_TO_NAME[tgt_lang]
    for shot_tag, infer_cfg in (
            ('0shot', _infer_cfg_0shot(src_name, tgt_name)),
            ('5shot', _infer_cfg_5shot(src_name, tgt_name))):
        wmt19_datasets.append(
            dict(
                abbr=f'wmt19_{src_lang}-{tgt_lang}_{shot_tag}',
                type=WMT19TranslationDataset,
                # TODO: placeholder — point at the real WMT19 parquet root.
                path='/path/to/wmt19',
                src_lang=src_lang,
                tgt_lang=tgt_lang,
                reader_cfg=wmt19_reader_cfg,
                infer_cfg=infer_cfg,
                eval_cfg=wmt19_eval_cfg))
"""WMT19 machine-translation dataset loader (local parquet export)."""
import os

import pandas as pd
from datasets import Dataset, DatasetDict

from opencompass.datasets.base import BaseDataset
from opencompass.registry import LOAD_DATASET


@LOAD_DATASET.register_module()
class WMT19TranslationDataset(BaseDataset):
    """Loads one WMT19 language pair from parquet files on disk.

    Expects ``<path>/<src>-<tgt>/validation-00000-of-00001.parquet`` (or
    the reversed ``<tgt>-<src>`` directory) whose ``translation`` column
    holds ``{lang_code: text}`` mappings, as in the HuggingFace export.
    """

    @staticmethod
    def load(path: str, src_lang: str, tgt_lang: str) -> DatasetDict:
        """Return a DatasetDict with a single ``validation`` split.

        Args:
            path: Root directory containing per-pair subdirectories.
            src_lang: ISO 639-1 code of the source language.
            tgt_lang: ISO 639-1 code of the target language.

        Raises:
            ValueError: If neither direction's directory exists under
                ``path``.
        """
        # The pair may be stored under either direction on disk.
        lang_pair_dir = os.path.join(path, f'{src_lang}-{tgt_lang}')
        if not os.path.exists(lang_pair_dir):
            lang_pair_dir = os.path.join(path, f'{tgt_lang}-{src_lang}')
            if not os.path.exists(lang_pair_dir):
                raise ValueError(
                    f'Cannot find directory for language pair '
                    f'{src_lang}-{tgt_lang} or {tgt_lang}-{src_lang}')

        val_file = os.path.join(lang_pair_dir,
                                'validation-00000-of-00001.parquet')
        val_df = pd.read_parquet(val_file)

        # Single pass over the translation dicts instead of two .apply()
        # scans of the same column.
        inputs, targets = [], []
        for pair in val_df['translation']:
            inputs.append(pair[src_lang])
            targets.append(pair[tgt_lang])

        validation = Dataset.from_dict({
            'input': inputs,
            'target': targets,
        })
        return DatasetDict({'validation': validation})

    @classmethod
    def get_dataset(cls, path, src_lang, tgt_lang):
        """Thin alias kept for callers using the ``get_dataset`` entry point."""
        return cls.load(path, src_lang, tgt_lang)