Add MedXpertQA

This commit is contained in:
Yejin0111 2025-04-07 10:36:32 +00:00
parent a84d8d693e
commit 34e11d136f
4 changed files with 36 additions and 30 deletions

View File

@ -89,11 +89,11 @@ repos:
- mdformat_frontmatter
- linkify-it-py
exclude: configs/
- repo: https://gitee.com/openmmlab/mirrors-docformatter
rev: v1.3.1
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
# - repo: https://gitee.com/openmmlab/mirrors-docformatter
# rev: v1.3.1
# hooks:
# - id: docformatter
# args: ["--in-place", "--wrap-descriptions", "79"]
- repo: local
hooks:
- id: update-dataset-suffix

View File

@ -89,11 +89,11 @@ repos:
- mdformat_frontmatter
- linkify-it-py
exclude: configs/
- repo: https://github.com/myint/docformatter
rev: v1.3.1
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
# - repo: https://github.com/myint/docformatter
# rev: v1.3.1
# hooks:
# - id: docformatter
# args: ["--in-place", "--wrap-descriptions", "79"]
- repo: local
hooks:
- id: update-dataset-suffix

View File

@ -1,13 +1,10 @@
import csv
import os
import random
import re
from datasets import Dataset, load_dataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path
from opencompass.utils import get_logger
from .base import BaseDataset
@ -46,7 +43,8 @@ class MedXpertQAEvaluator(BaseEvaluator):
count = 0
details = []
for idx, (i, j) in enumerate(zip(predictions, references)):
i = answer_cleansing(method, i, test_set['options'][idx], test_set['label'][idx])
i = answer_cleansing(method, i, test_set['options'][idx],
test_set['label'][idx])
detail = {'pred': i, 'answer': j, 'correct': False}
count += 1
if i == j:
@ -67,35 +65,39 @@ def answer_cleansing(
# Clean up unwanted phrases in the prediction
for unwanted_phrase in [
"I understand",
"A through J",
"A through E",
"A through D",
'I understand',
'A through J',
'A through E',
'A through D',
]:
prediction = prediction.replace(unwanted_phrase, "")
prediction = prediction.replace(unwanted_phrase, '')
options_num = len(options)
options = [chr(65 + i) for i in range(options_num)]
options_str = r"\b(" + "|".join(options) + r")\b"
options_str = r'\b(' + '|'.join(options) + r')\b'
prediction = re.findall(options_str, prediction)
if len(prediction) == 0:
prediction = []
else:
# If there is a "label" and its length is 1, process prediction accordingly
# If there is a "label" and its length is 1,
# process prediction accordingly
if len(label) == 1:
if method == "few-shot":
if method == 'few-shot':
answer_flag = True if len(prediction) > 1 else False
# choose the first or last element based on the answer_flag
prediction = [prediction[0]] if answer_flag else [prediction[-1]]
elif method == "zero-shot":
if answer_flag:
prediction = [prediction[0]]
else:
prediction = [prediction[-1]]
elif method == 'zero-shot':
# choose the first element in list
prediction = [prediction[0]]
else:
raise ValueError("Method is not properly defined ...")
raise ValueError('Method is not properly defined ...')
# Remove trailing period if it exists
if prediction[0] and prediction[0].endswith("."):
if prediction[0] and prediction[0].endswith('.'):
prediction[0] = prediction[0][:-1]
return prediction[0]
@ -108,7 +110,11 @@ def _generic_llmjudge_postprocess(judgement: str):
return grade_letter
def MedXpertQA_llmjudge_postprocess(output: dict, output_path: str, dataset: Dataset) -> dict:
def MedXpertQA_llmjudge_postprocess(
output: dict,
output_path: str,
dataset: Dataset,
) -> dict:
# Get the original dataset
original_dataset = dataset.reader.dataset['test']