From f36c0496f37d8dc159fa0c8d1c21897dff3cfaf6 Mon Sep 17 00:00:00 2001
From: liushz
Date: Tue, 18 Jul 2023 14:54:35 +0800
Subject: [PATCH] [Feature] Add tydiqa-goldp (#75)

Co-authored-by: liuhongwei
---
 configs/datasets/tydiqa/tydiqa_gen.py        |  4 ++
 configs/datasets/tydiqa/tydiqa_gen_978d2a.py | 51 ++++++++++++++
 opencompass/datasets/__init__.py             |  1 +
 opencompass/datasets/tydiqa.py               | 71 ++++++++++++++++++++
 4 files changed, 127 insertions(+)
 create mode 100644 configs/datasets/tydiqa/tydiqa_gen.py
 create mode 100644 configs/datasets/tydiqa/tydiqa_gen_978d2a.py
 create mode 100644 opencompass/datasets/tydiqa.py

diff --git a/configs/datasets/tydiqa/tydiqa_gen.py b/configs/datasets/tydiqa/tydiqa_gen.py
new file mode 100644
index 00000000..269c6334
--- /dev/null
+++ b/configs/datasets/tydiqa/tydiqa_gen.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .tydiqa_gen_978d2a import tydiqa_datasets  # noqa: F401, F403
diff --git a/configs/datasets/tydiqa/tydiqa_gen_978d2a.py b/configs/datasets/tydiqa/tydiqa_gen_978d2a.py
new file mode 100644
index 00000000..07ff7fa3
--- /dev/null
+++ b/configs/datasets/tydiqa/tydiqa_gen_978d2a.py
@@ -0,0 +1,51 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import TydiQADataset, TydiQAEvaluator
+
+# All configs are for the TydiQA GoldP task
+tydiqa_reader_cfg = dict(
+    input_columns=["passage_text", "question_text"],
+    output_column="answer",
+    test_split='validation',
+    train_split='validation')
+
+langs = ['arabic', 'bengali', 'english', 'finnish', 'indonesian', 'japanese', 'korean', 'russian', 'swahili', 'telugu', 'thai']
+
+prefixs_prompt = {
+    "english": ("Answer the following question based on the information in the given passage.", "Passage:", "Question:", "Answer:"),
+    "arabic": ("أجب على السؤال التالي بناءً على المعلومات في المقطع المعطى.", "المقطع:", "السؤال:", "الإجابة:"),
+    "bengali": ("প্রদত্ত অধ্যায়ের তথ্যের উপর ভিত্তি করে নিম্নলিখিত প্রশ্নের উত্তর দিন।", "অধ্যায়:", "প্রশ্ন:", "উত্তর:"),
+    "finnish": ("Vastaa seuraavaan kysymykseen annetun kappaleen tiedon perusteella.", "Kappale:", "Kysymys:", "Vastaus:"),
+    "indonesian": ("Jawab pertanyaan berikut berdasarkan informasi di bagian yang diberikan.", "Bagian:", "Pertanyaan:", "Jawaban:"),
+    "korean": ("주어진 문단의 정보에 기반하여 다음 질문에 답하십시오.", "문단:", "질문:", "답변:"),
+    "japanese": ("文脈に基づいて質問に答えてください。", "ぶんしょう:", "しつもん:", "かいとう:"),
+    "russian": ("Ответьте на следующий вопрос на основе информации в данном отрывке.", "Отрывок:", "Вопрос:", "Ответ:"),
+    "swahili": ("Jibu swali lifuatalo kulingana na habari kwenye kifungu kilichotolewa.", "Kifungu:", "Swali:", "Jibu:"),
+    "telugu": ("ఇచ్చిన పేరాలోని సమాచారం ఆధారంగా కింది ప్రశ్నకు సమాధానం ఇవ్వండి.", "పేరా:", "ప్రశ్న:", "సమాధానం:"),
+    "thai": ("ตอบคำถามต่อไปนี้โดยอิงตามข้อมูลในตอนข้อความที่กำหนด:", "ตอนข้อความ:", "คำถาม:", "คำตอบ:")
+}
+
+tydiqa_datasets = []
+for _lang in langs:
+    _hint = prefixs_prompt[_lang]
+    tydiqa_infer_cfg = dict(
+        prompt_template=dict(
+            type=PromptTemplate,
+            template=f"{_hint[0]}\n\n</E>{_hint[1]}{{passage_text}}\n{_hint[2]} {{question_text}}\n{_hint[3]} {{answer}}",
+            ice_token='</E>'),
+        retriever=dict(type=ZeroRetriever),
+        inferencer=dict(type=GenInferencer, max_out_len=50))
+
+    tydiqa_eval_cfg = dict(evaluator=dict(type=TydiQAEvaluator),
+                           ds_split='validation',
+                           ds_column='answer',
+                           )
+    tydiqa_datasets.append(
+        dict(abbr=f'tydiqa-goldp_{_lang}',
+            type=TydiQADataset,
+            path='khalidalt/tydiqa-goldp',
+            name=_lang,
+            reader_cfg=tydiqa_reader_cfg,
+            infer_cfg=tydiqa_infer_cfg,
+            eval_cfg=tydiqa_eval_cfg))
\ No newline at end of file
diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py
index 7e559f42..b8f773e4 100644
--- a/opencompass/datasets/__init__.py
+++ b/opencompass/datasets/__init__.py
@@ -55,6 +55,7 @@ from .tnews import *  # noqa: F401, F403
 from .triviaqa import *  # noqa: F401, F403
 from .triviaqarc import *  # noqa: F401, F403
 from .truthfulqa import *  # noqa: F401, F403
+from .tydiqa import *  # noqa: F401, F403
 from .wic import *  # noqa: F401, F403
 from .winograd import *  # noqa: F401, F403
 from .winogrande import *  # noqa: F401, F403
diff --git a/opencompass/datasets/tydiqa.py b/opencompass/datasets/tydiqa.py
new file mode 100644
index 00000000..7b048594
--- /dev/null
+++ b/opencompass/datasets/tydiqa.py
@@ -0,0 +1,71 @@
+import re
+from collections import Counter
+
+from datasets import load_dataset
+
+from opencompass.openicl.icl_evaluator import BaseEvaluator
+from opencompass.utils.text_postprocessors import general_postprocess
+
+from .base import BaseDataset
+
+
+class TydiQADataset(BaseDataset):
+
+    @staticmethod
+    def load(**kwargs):
+        dataset = load_dataset(**kwargs)
+
+        def pre_process(example):
+            example['answer'] = example['answers']['text']
+            return example
+
+        dataset = dataset.map(pre_process).remove_columns(['id', 'answers'])
+        return dataset
+
+
+class TydiQAEvaluator(BaseEvaluator):
+    # This evaluation class is edited from:
+    # https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py
+    def f1_score(self, prediction, ground_truth):
+        prediction_tokens = general_postprocess(prediction).split()
+        ground_truth_tokens = general_postprocess(ground_truth).split()
+        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+        num_same = sum(common.values())
+        if num_same == 0:
+            return 0
+        precision = 1.0 * num_same / len(prediction_tokens)
+        recall = 1.0 * num_same / len(ground_truth_tokens)
+        f1 = (2 * precision * recall) / (precision + recall)
+        return f1
+
+    def exact_match_score(self, prediction, ground_truth):
+        return (general_postprocess(prediction) == general_postprocess(
+            ground_truth))
+
+    def metric_max_over_ground_truths(self, metric_fn, prediction,
+                                      ground_truths):
+        scores_for_ground_truths = []
+        for ground_truth in ground_truths:
+            score = metric_fn(prediction, ground_truth)
+            scores_for_ground_truths.append(score)
+        return max(scores_for_ground_truths)
+
+    def score(self, predictions, references):
+        f1 = exact_match = total = 0
+        if len(predictions) != len(references):
+            return {
+                'error': 'predictions and references have different '
+                'length'
+            }
+        for prediction, reference in zip(predictions, references):
+            prediction = re.split(r'[\n]', prediction, 1)[0].lower()
+            exact_match += self.metric_max_over_ground_truths(
+                self.exact_match_score, prediction, reference)
+            f1 += self.metric_max_over_ground_truths(self.f1_score, prediction,
+                                                     reference)
+            total += 1
+
+        exact_match = 100.0 * exact_match / total
+        f1 = 100.0 * f1 / total
+
+        return {'exact_match': exact_match, 'f1': f1}
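
A minimal usage sketch of the new TydiQAEvaluator on toy data (not part of the patch; the predictions and references below are invented, and the expected numbers assume general_postprocess leaves these plain lowercase strings unchanged):

    # Toy scoring run for TydiQAEvaluator (assumes opencompass is installed).
    from opencompass.datasets.tydiqa import TydiQAEvaluator

    evaluator = TydiQAEvaluator()

    # Each reference is a list of acceptable gold answers, mirroring the
    # answers['text'] field that TydiQADataset.load() exposes as 'answer'.
    predictions = ['nile river basin', 'paris']
    references = [['nile river'], ['paris']]

    # score() cuts each prediction at the first newline, lowercases it, then
    # takes the best exact-match / token-level F1 over the gold answers.
    print(evaluator.score(predictions, references))
    # Expected: {'exact_match': 50.0, 'f1': 90.0}
    #   - 'paris' vs 'paris': exact match, token F1 = 1.0
    #   - 'nile river basin' vs 'nile river': no exact match;
    #     precision 2/3, recall 2/2 -> F1 = 0.8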