mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
Add MedXpertQA
This commit is contained in:
parent
a84d8d693e
commit
34e11d136f
@ -89,11 +89,11 @@ repos:
|
|||||||
- mdformat_frontmatter
|
- mdformat_frontmatter
|
||||||
- linkify-it-py
|
- linkify-it-py
|
||||||
exclude: configs/
|
exclude: configs/
|
||||||
- repo: https://gitee.com/openmmlab/mirrors-docformatter
|
# - repo: https://gitee.com/openmmlab/mirrors-docformatter
|
||||||
rev: v1.3.1
|
# rev: v1.3.1
|
||||||
hooks:
|
# hooks:
|
||||||
- id: docformatter
|
# - id: docformatter
|
||||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
# args: ["--in-place", "--wrap-descriptions", "79"]
|
||||||
- repo: local
|
- repo: local
|
||||||
hooks:
|
hooks:
|
||||||
- id: update-dataset-suffix
|
- id: update-dataset-suffix
|
||||||
|
@ -89,11 +89,11 @@ repos:
|
|||||||
- mdformat_frontmatter
|
- mdformat_frontmatter
|
||||||
- linkify-it-py
|
- linkify-it-py
|
||||||
exclude: configs/
|
exclude: configs/
|
||||||
- repo: https://github.com/myint/docformatter
|
# - repo: https://github.com/myint/docformatter
|
||||||
rev: v1.3.1
|
# rev: v1.3.1
|
||||||
hooks:
|
# hooks:
|
||||||
- id: docformatter
|
# - id: docformatter
|
||||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
# args: ["--in-place", "--wrap-descriptions", "79"]
|
||||||
- repo: local
|
- repo: local
|
||||||
hooks:
|
hooks:
|
||||||
- id: update-dataset-suffix
|
- id: update-dataset-suffix
|
||||||
|
@ -1,13 +1,10 @@
|
|||||||
import csv
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from datasets import Dataset, load_dataset
|
from datasets import Dataset, load_dataset
|
||||||
|
|
||||||
from opencompass.openicl import BaseEvaluator
|
from opencompass.openicl import BaseEvaluator
|
||||||
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
|
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
|
||||||
from opencompass.utils import get_data_path
|
from opencompass.utils import get_logger
|
||||||
|
|
||||||
from .base import BaseDataset
|
from .base import BaseDataset
|
||||||
|
|
||||||
@ -30,7 +27,7 @@ class MedXpertQADataset(BaseDataset):
|
|||||||
if prompt_mode == 'zero-shot':
|
if prompt_mode == 'zero-shot':
|
||||||
dataset = dataset.map(lambda item: _parse(item, prompt_mode))
|
dataset = dataset.map(lambda item: _parse(item, prompt_mode))
|
||||||
elif prompt_mode == 'few-shot':
|
elif prompt_mode == 'few-shot':
|
||||||
pass # TODO: Implement few-shot prompt
|
pass # TODO: Implement few-shot prompt
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
@ -46,7 +43,8 @@ class MedXpertQAEvaluator(BaseEvaluator):
|
|||||||
count = 0
|
count = 0
|
||||||
details = []
|
details = []
|
||||||
for idx, (i, j) in enumerate(zip(predictions, references)):
|
for idx, (i, j) in enumerate(zip(predictions, references)):
|
||||||
i = answer_cleansing(method, i, test_set['options'][idx], test_set['label'][idx])
|
i = answer_cleansing(method, i, test_set['options'][idx],
|
||||||
|
test_set['label'][idx])
|
||||||
detail = {'pred': i, 'answer': j, 'correct': False}
|
detail = {'pred': i, 'answer': j, 'correct': False}
|
||||||
count += 1
|
count += 1
|
||||||
if i == j:
|
if i == j:
|
||||||
@ -67,35 +65,39 @@ def answer_cleansing(
|
|||||||
|
|
||||||
# Clean up unwanted phrases in the prediction
|
# Clean up unwanted phrases in the prediction
|
||||||
for unwanted_phrase in [
|
for unwanted_phrase in [
|
||||||
"I understand",
|
'I understand',
|
||||||
"A through J",
|
'A through J',
|
||||||
"A through E",
|
'A through E',
|
||||||
"A through D",
|
'A through D',
|
||||||
]:
|
]:
|
||||||
prediction = prediction.replace(unwanted_phrase, "")
|
prediction = prediction.replace(unwanted_phrase, '')
|
||||||
|
|
||||||
options_num = len(options)
|
options_num = len(options)
|
||||||
options = [chr(65 + i) for i in range(options_num)]
|
options = [chr(65 + i) for i in range(options_num)]
|
||||||
options_str = r"\b(" + "|".join(options) + r")\b"
|
options_str = r'\b(' + '|'.join(options) + r')\b'
|
||||||
prediction = re.findall(options_str, prediction)
|
prediction = re.findall(options_str, prediction)
|
||||||
|
|
||||||
if len(prediction) == 0:
|
if len(prediction) == 0:
|
||||||
prediction = []
|
prediction = []
|
||||||
else:
|
else:
|
||||||
# If there is a "label" and its length is 1, process prediction accordingly
|
# If there is a "label" and its length is 1,
|
||||||
|
# process prediction accordingly
|
||||||
if len(label) == 1:
|
if len(label) == 1:
|
||||||
if method == "few-shot":
|
if method == 'few-shot':
|
||||||
answer_flag = True if len(prediction) > 1 else False
|
answer_flag = True if len(prediction) > 1 else False
|
||||||
# choose the first or last element based on the answer_flag
|
# choose the first or last element based on the answer_flag
|
||||||
prediction = [prediction[0]] if answer_flag else [prediction[-1]]
|
if answer_flag:
|
||||||
elif method == "zero-shot":
|
prediction = [prediction[0]]
|
||||||
|
else:
|
||||||
|
prediction = [prediction[-1]]
|
||||||
|
elif method == 'zero-shot':
|
||||||
# choose the first element in list
|
# choose the first element in list
|
||||||
prediction = [prediction[0]]
|
prediction = [prediction[0]]
|
||||||
else:
|
else:
|
||||||
raise ValueError("Method is not properly defined ...")
|
raise ValueError('Method is not properly defined ...')
|
||||||
|
|
||||||
# Remove trailing period if it exists
|
# Remove trailing period if it exists
|
||||||
if prediction[0] and prediction[0].endswith("."):
|
if prediction[0] and prediction[0].endswith('.'):
|
||||||
prediction[0] = prediction[0][:-1]
|
prediction[0] = prediction[0][:-1]
|
||||||
|
|
||||||
return prediction[0]
|
return prediction[0]
|
||||||
@ -108,7 +110,11 @@ def _generic_llmjudge_postprocess(judgement: str):
|
|||||||
return grade_letter
|
return grade_letter
|
||||||
|
|
||||||
|
|
||||||
def MedXpertQA_llmjudge_postprocess(output: dict, output_path: str, dataset: Dataset) -> dict:
|
def MedXpertQA_llmjudge_postprocess(
|
||||||
|
output: dict,
|
||||||
|
output_path: str,
|
||||||
|
dataset: Dataset,
|
||||||
|
) -> dict:
|
||||||
# Get the original dataset
|
# Get the original dataset
|
||||||
original_dataset = dataset.reader.dataset['test']
|
original_dataset = dataset.reader.dataset['test']
|
||||||
|
|
||||||
|
@ -92,7 +92,7 @@ from .math_intern import * # noqa: F401, F403
|
|||||||
from .mathbench import * # noqa: F401, F403
|
from .mathbench import * # noqa: F401, F403
|
||||||
from .mbpp import * # noqa: F401, F403
|
from .mbpp import * # noqa: F401, F403
|
||||||
from .medbench import * # noqa: F401, F403
|
from .medbench import * # noqa: F401, F403
|
||||||
from .MedXpertQA import * # noqa: F401, F403
|
from .MedXpertQA import * # noqa: F401, F403
|
||||||
from .mgsm import * # noqa: F401, F403
|
from .mgsm import * # noqa: F401, F403
|
||||||
from .mmlu import * # noqa: F401, F403
|
from .mmlu import * # noqa: F401, F403
|
||||||
from .mmlu_cf import * # noqa: F401, F403
|
from .mmlu_cf import * # noqa: F401, F403
|
||||||
|
Loading…
Reference in New Issue
Block a user