diff --git a/.pre-commit-config-zh-cn.yaml b/.pre-commit-config-zh-cn.yaml index 2e21c85d..c7ca70be 100644 --- a/.pre-commit-config-zh-cn.yaml +++ b/.pre-commit-config-zh-cn.yaml @@ -89,11 +89,11 @@ repos: - mdformat_frontmatter - linkify-it-py exclude: configs/ - - repo: https://gitee.com/openmmlab/mirrors-docformatter - rev: v1.3.1 - hooks: - - id: docformatter - args: ["--in-place", "--wrap-descriptions", "79"] + # - repo: https://gitee.com/openmmlab/mirrors-docformatter + # rev: v1.3.1 + # hooks: + # - id: docformatter + # args: ["--in-place", "--wrap-descriptions", "79"] - repo: local hooks: - id: update-dataset-suffix diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9f72ae42..a1b3dd14 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -89,11 +89,11 @@ repos: - mdformat_frontmatter - linkify-it-py exclude: configs/ - - repo: https://github.com/myint/docformatter - rev: v1.3.1 - hooks: - - id: docformatter - args: ["--in-place", "--wrap-descriptions", "79"] + # - repo: https://github.com/myint/docformatter + # rev: v1.3.1 + # hooks: + # - id: docformatter + # args: ["--in-place", "--wrap-descriptions", "79"] - repo: local hooks: - id: update-dataset-suffix diff --git a/opencompass/datasets/MedXpertQA.py b/opencompass/datasets/MedXpertQA.py index c5cbe004..f016297a 100644 --- a/opencompass/datasets/MedXpertQA.py +++ b/opencompass/datasets/MedXpertQA.py @@ -1,13 +1,10 @@ -import csv -import os -import random import re from datasets import Dataset, load_dataset from opencompass.openicl import BaseEvaluator from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS -from opencompass.utils import get_data_path +from opencompass.utils import get_logger from .base import BaseDataset @@ -30,7 +27,7 @@ class MedXpertQADataset(BaseDataset): if prompt_mode == 'zero-shot': dataset = dataset.map(lambda item: _parse(item, prompt_mode)) elif prompt_mode == 'few-shot': - pass # TODO: Implement few-shot prompt + pass # TODO: Implement few-shot prompt return dataset @@ -46,7 +43,8 @@ class MedXpertQAEvaluator(BaseEvaluator): count = 0 details = [] for idx, (i, j) in enumerate(zip(predictions, references)): - i = answer_cleansing(method, i, test_set['options'][idx], test_set['label'][idx]) + i = answer_cleansing(method, i, test_set['options'][idx], + test_set['label'][idx]) detail = {'pred': i, 'answer': j, 'correct': False} count += 1 if i == j: @@ -67,35 +65,39 @@ def answer_cleansing( # Clean up unwanted phrases in the prediction for unwanted_phrase in [ - "I understand", - "A through J", - "A through E", - "A through D", + 'I understand', + 'A through J', + 'A through E', + 'A through D', ]: - prediction = prediction.replace(unwanted_phrase, "") + prediction = prediction.replace(unwanted_phrase, '') options_num = len(options) options = [chr(65 + i) for i in range(options_num)] - options_str = r"\b(" + "|".join(options) + r")\b" + options_str = r'\b(' + '|'.join(options) + r')\b' prediction = re.findall(options_str, prediction) if len(prediction) == 0: prediction = [] else: - # If there is a "label" and its length is 1, process prediction accordingly + # If there is a "label" and its length is 1, + # process prediction accordingly if len(label) == 1: - if method == "few-shot": + if method == 'few-shot': answer_flag = True if len(prediction) > 1 else False # choose the first or last element based on the answer_flag - prediction = [prediction[0]] if answer_flag else [prediction[-1]] - elif method == "zero-shot": + if answer_flag: + prediction = [prediction[0]] + else: + prediction = [prediction[-1]] + elif method == 'zero-shot': # choose the first element in list prediction = [prediction[0]] else: - raise ValueError("Method is not properly defined ...") + raise ValueError('Method is not properly defined ...') # Remove trailing period if it exists - if prediction[0] and prediction[0].endswith("."): + if prediction[0] and prediction[0].endswith('.'): prediction[0] = prediction[0][:-1] return prediction[0] @@ -108,7 +110,11 @@ def _generic_llmjudge_postprocess(judgement: str): return grade_letter -def MedXpertQA_llmjudge_postprocess(output: dict, output_path: str, dataset: Dataset) -> dict: +def MedXpertQA_llmjudge_postprocess( + output: dict, + output_path: str, + dataset: Dataset, +) -> dict: # Get the original dataset original_dataset = dataset.reader.dataset['test'] diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index a11b5183..3e2d0eef 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -92,7 +92,7 @@ from .math_intern import * # noqa: F401, F403 from .mathbench import * # noqa: F401, F403 from .mbpp import * # noqa: F401, F403 from .medbench import * # noqa: F401, F403 -from .MedXpertQA import * # noqa: F401, F403 +from .MedXpertQA import * # noqa: F401, F403 from .mgsm import * # noqa: F401, F403 from .mmlu import * # noqa: F401, F403 from .mmlu_cf import * # noqa: F401, F403