mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
Add MedXpertQA
This commit is contained in:
parent
a84d8d693e
commit
34e11d136f
@ -89,11 +89,11 @@ repos:
|
||||
- mdformat_frontmatter
|
||||
- linkify-it-py
|
||||
exclude: configs/
|
||||
- repo: https://gitee.com/openmmlab/mirrors-docformatter
|
||||
rev: v1.3.1
|
||||
hooks:
|
||||
- id: docformatter
|
||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
||||
# - repo: https://gitee.com/openmmlab/mirrors-docformatter
|
||||
# rev: v1.3.1
|
||||
# hooks:
|
||||
# - id: docformatter
|
||||
# args: ["--in-place", "--wrap-descriptions", "79"]
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: update-dataset-suffix
|
||||
|
@ -89,11 +89,11 @@ repos:
|
||||
- mdformat_frontmatter
|
||||
- linkify-it-py
|
||||
exclude: configs/
|
||||
- repo: https://github.com/myint/docformatter
|
||||
rev: v1.3.1
|
||||
hooks:
|
||||
- id: docformatter
|
||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
||||
# - repo: https://github.com/myint/docformatter
|
||||
# rev: v1.3.1
|
||||
# hooks:
|
||||
# - id: docformatter
|
||||
# args: ["--in-place", "--wrap-descriptions", "79"]
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: update-dataset-suffix
|
||||
|
@ -1,13 +1,10 @@
|
||||
import csv
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
|
||||
from opencompass.openicl import BaseEvaluator
|
||||
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
|
||||
from opencompass.utils import get_data_path
|
||||
from opencompass.utils import get_logger
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
@ -30,7 +27,7 @@ class MedXpertQADataset(BaseDataset):
|
||||
if prompt_mode == 'zero-shot':
|
||||
dataset = dataset.map(lambda item: _parse(item, prompt_mode))
|
||||
elif prompt_mode == 'few-shot':
|
||||
pass # TODO: Implement few-shot prompt
|
||||
pass # TODO: Implement few-shot prompt
|
||||
|
||||
return dataset
|
||||
|
||||
@ -46,7 +43,8 @@ class MedXpertQAEvaluator(BaseEvaluator):
|
||||
count = 0
|
||||
details = []
|
||||
for idx, (i, j) in enumerate(zip(predictions, references)):
|
||||
i = answer_cleansing(method, i, test_set['options'][idx], test_set['label'][idx])
|
||||
i = answer_cleansing(method, i, test_set['options'][idx],
|
||||
test_set['label'][idx])
|
||||
detail = {'pred': i, 'answer': j, 'correct': False}
|
||||
count += 1
|
||||
if i == j:
|
||||
@ -67,35 +65,39 @@ def answer_cleansing(
|
||||
|
||||
# Clean up unwanted phrases in the prediction
|
||||
for unwanted_phrase in [
|
||||
"I understand",
|
||||
"A through J",
|
||||
"A through E",
|
||||
"A through D",
|
||||
'I understand',
|
||||
'A through J',
|
||||
'A through E',
|
||||
'A through D',
|
||||
]:
|
||||
prediction = prediction.replace(unwanted_phrase, "")
|
||||
prediction = prediction.replace(unwanted_phrase, '')
|
||||
|
||||
options_num = len(options)
|
||||
options = [chr(65 + i) for i in range(options_num)]
|
||||
options_str = r"\b(" + "|".join(options) + r")\b"
|
||||
options_str = r'\b(' + '|'.join(options) + r')\b'
|
||||
prediction = re.findall(options_str, prediction)
|
||||
|
||||
if len(prediction) == 0:
|
||||
prediction = []
|
||||
else:
|
||||
# If there is a "label" and its length is 1, process prediction accordingly
|
||||
# If there is a "label" and its length is 1,
|
||||
# process prediction accordingly
|
||||
if len(label) == 1:
|
||||
if method == "few-shot":
|
||||
if method == 'few-shot':
|
||||
answer_flag = True if len(prediction) > 1 else False
|
||||
# choose the first or last element based on the answer_flag
|
||||
prediction = [prediction[0]] if answer_flag else [prediction[-1]]
|
||||
elif method == "zero-shot":
|
||||
if answer_flag:
|
||||
prediction = [prediction[0]]
|
||||
else:
|
||||
prediction = [prediction[-1]]
|
||||
elif method == 'zero-shot':
|
||||
# choose the first element in list
|
||||
prediction = [prediction[0]]
|
||||
else:
|
||||
raise ValueError("Method is not properly defined ...")
|
||||
raise ValueError('Method is not properly defined ...')
|
||||
|
||||
# Remove trailing period if it exists
|
||||
if prediction[0] and prediction[0].endswith("."):
|
||||
if prediction[0] and prediction[0].endswith('.'):
|
||||
prediction[0] = prediction[0][:-1]
|
||||
|
||||
return prediction[0]
|
||||
@ -108,7 +110,11 @@ def _generic_llmjudge_postprocess(judgement: str):
|
||||
return grade_letter
|
||||
|
||||
|
||||
def MedXpertQA_llmjudge_postprocess(output: dict, output_path: str, dataset: Dataset) -> dict:
|
||||
def MedXpertQA_llmjudge_postprocess(
|
||||
output: dict,
|
||||
output_path: str,
|
||||
dataset: Dataset,
|
||||
) -> dict:
|
||||
# Get the original dataset
|
||||
original_dataset = dataset.reader.dataset['test']
|
||||
|
||||
|
@ -92,7 +92,7 @@ from .math_intern import * # noqa: F401, F403
|
||||
from .mathbench import * # noqa: F401, F403
|
||||
from .mbpp import * # noqa: F401, F403
|
||||
from .medbench import * # noqa: F401, F403
|
||||
from .MedXpertQA import * # noqa: F401, F403
|
||||
from .MedXpertQA import * # noqa: F401, F403
|
||||
from .mgsm import * # noqa: F401, F403
|
||||
from .mmlu import * # noqa: F401, F403
|
||||
from .mmlu_cf import * # noqa: F401, F403
|
||||
|
Loading…
Reference in New Issue
Block a user