mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

* Update with PMMEval * Update * Update __init__.py * Fix Bugs * Delete .pre-commit-config.yaml * Pull merge --------- Co-authored-by: liushz <qq1791167085@163.com>
153 lines
4.3 KiB
Python
Executable File
153 lines
4.3 KiB
Python
Executable File
import json
|
|
import os
|
|
import re
|
|
from typing import Tuple
|
|
|
|
from datasets import Dataset
|
|
|
|
from opencompass.datasets.base import BaseDataset
|
|
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
|
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
|
|
from opencompass.utils import get_data_path
|
|
|
|
# Language code -> lead-in phrases that precede the final answer in a
# free-form response. ``extract_choice`` falls back to these when no
# JSON-style ``"answer"`` field is found in the generation.
langs_dict = {
    'fr': ['La réponse est', 'la réponse est'],
    'en': ['the answer is', 'The answer is'],
    'vi': ['Câu trả lời là', 'câu trả lời là'],
    'ar': ['الجواب هو'],
    'th': ['คำตอบคือ'],
    'zh': ['答案是'],
    'ko': ['답변은'],
    'pt': ['A resposta é'],
    'ja': ['答えは'],
    'es': ['La respuesta es'],
}
|
|
|
|
|
|
def extract_choice(gen, lang):
    """Strictly extract the chosen option letter from a model generation.

    Two stages are tried in order:

    1. JSON-like answers such as ``{"answer": "A"}`` or ``'answer': 'A'``.
       The patterns are tried from most to least specific; for the first
       pattern with any match, its last match is returned.
    2. Otherwise, the lead-in phrases registered for ``lang`` in
       ``langs_dict`` (e.g. ``'The answer is'``). Among all phrases, the
       occurrence closest to the end of ``gen`` is used, and the first
       non-whitespace character after it is taken as the answer.

    Args:
        gen: Raw model output.
        lang: Language code; must be a key of ``langs_dict`` whenever
            stage 2 is reached.

    Returns:
        ``'A'``-``'D'`` (stage 2 may also yield ``'a'``-``'d'``), or
        ``None`` when no option could be extracted.

    Raises:
        KeyError: if stage 2 is reached and ``lang`` is not in
            ``langs_dict``.
    """
    patterns = [
        r"\{\s*?\"answer\"\s*?\:\s*?\"?(A|B|C|D).*?\"?\s*?\}",
        r"\{\s*?[\'\"]answer[\'\"]\s*?\:\s*?[\'\"](A|B|C|D).*?[\'\"]\s*?\}",
        r"\"answer\"\s*:\s*\"?(A|B|C|D)\"?",
        r"[\'\"]answer[\'\"]\s*:\s*[\'\"](A|B|C|D)[\'\"]"
    ]
    for pattern in patterns:
        matches = re.findall(pattern, gen, flags=re.DOTALL)
        if matches:
            return matches[-1]

    # Stage 2. Use the latest occurrence across *all* phrases rather than
    # the first phrase in the list that appears anywhere: with
    # ['the answer is', 'The answer is'], an earlier lowercase mention in
    # the reasoning must not shadow a final "The answer is B".
    tail_start = None
    last_idx = -1
    for phrase in langs_dict[lang]:
        idx = gen.rfind(phrase)
        if idx > last_idx:
            last_idx = idx
            tail_start = idx + len(phrase)
    if tail_start is None:
        return None

    tail = gen[tail_start:].strip()
    if tail and tail[0] in ('A', 'B', 'C', 'D', 'a', 'b', 'c', 'd'):
        return tail[0]
    return None
|
|
|
|
|
|
def extract_choice_fuzzy(gen):
    """Loosely guess an option letter by plain substring search.

    Returns the first of ``'A'``, ``'B'``, ``'C'``, ``'D'`` (checked in that
    order) that occurs anywhere in ``gen``. This is a coarse last resort:
    any capital A-D in the text counts, including e.g. the article "A".

    Args:
        gen: Raw model output. ``None`` or empty text yields ``None``.

    Returns:
        The matched option letter, or ``None``.
    """
    # Guard against missing text (e.g. a failed strict extraction passed
    # straight through), which would otherwise raise TypeError on ``in``.
    if not gen:
        return None
    return next((option for option in ('A', 'B', 'C', 'D') if option in gen),
                None)
|
|
|
|
|
|
@TEXT_POSTPROCESSORS.register_module('pmmeval_mlogiqa')
def pmmeval_mlogiqa_postprocess(text: str,
                                lang_code: str) -> Tuple[str, str]:
    """Pair the raw generation with its language code.

    The text itself is not modified. Presumably the pair is what
    ``PMMEvalMLogiQAEvaluator.score`` receives as each prediction, since it
    unpacks predictions as ``(text, lang)`` — confirm against the config.

    Args:
        text: Raw model output.
        lang_code: Language code of the sample.

    Returns:
        ``(text, lang_code)`` unchanged.
    """
    return text, lang_code
|
|
|
|
|
|
@LOAD_DATASET.register_module()
class PMMEvalMLogiQADataset(BaseDataset):
    """Loader for the ``mlogiqa`` subset of PMMEval (one split per language)."""

    @staticmethod
    def load(path: str, lang: str):
        """Load the mlogiqa test split for one language.

        Args:
            path: Dataset location, passed through ``get_data_path``. When
                ``DATASET_SOURCE=ModelScope`` the result is used as the
                ModelScope dataset name; otherwise as a local root directory.
            lang: Language code selecting the ``test/{lang}`` split.

        Returns:
            The object returned by ``MsDataset.load`` in ModelScope mode,
            otherwise a ``datasets.Dataset`` built from
            ``<root>/mlogiqa/test/{lang}.jsonl``.
        """
        data_path = get_data_path(path)

        if os.environ.get('DATASET_SOURCE') == 'ModelScope':
            from modelscope import MsDataset
            return MsDataset.load(dataset_name=data_path,
                                  subset_name='mlogiqa',
                                  split=f'test/{lang}')

        filename = os.path.join(data_path, f'mlogiqa/test/{lang}.jsonl')
        records = []
        with open(filename, mode='r', encoding='utf-8') as f:
            for raw_line in f:
                raw_line = raw_line.strip()
                # Tolerate blank lines (e.g. stray empty lines at EOF),
                # which json.loads would otherwise reject.
                if not raw_line:
                    continue
                records.append(json.loads(raw_line))
        return Dataset.from_list(records)
|
|
|
|
|
|
class PMMEvalMLogiQAEvaluator(BaseEvaluator):
    """Accuracy evaluator for PMMEval mlogiqa multiple-choice predictions."""

    def score(self, predictions, references):
        """Score predictions against gold option indices.

        Args:
            predictions: Sequence of ``(generation, lang)`` pairs.
            references: Gold answers as option indices (``0`` -> ``'A'``,
                ``1`` -> ``'B'``, ...); any value ``int()`` accepts.

        Returns:
            dict with ``accuracy`` (percentage rounded to 2 decimals) and
            per-sample ``details``. Each detail has ``acc``, ``failed``
            (strict extraction failed), ``failed_strict`` (strict and fuzzy
            extraction both failed) and ``extracted_answer`` (the strict
            result, or ``'no answer'``).
        """
        assert len(predictions) == len(references)

        all_results = []
        for (generation, lang), ref in zip(predictions, references):
            answer = chr(int(ref) + 65)
            pred = extract_choice(generation, lang)
            failed = 0
            failed_strict = 0
            if pred is not None:
                acc = int(answer.lower() == pred.lower())
            else:
                failed = 1
                # Fuzzy-match the raw generation: the strict result is None
                # here, and searching None would raise TypeError.
                pred_fuzzy = extract_choice_fuzzy(generation)
                if pred_fuzzy is None:
                    acc = 0
                    failed_strict = 1
                else:
                    acc = int(answer.lower() == pred_fuzzy.lower())

            all_results.append({
                'acc': float(acc),
                'failed': float(failed),
                'failed_strict': float(failed_strict),
                'extracted_answer': pred if pred else 'no answer',
            })

        final_result = {
            'accuracy':
            round(
                sum(x['acc'] for x in all_results) / len(all_results) * 100,
                2),
            'details':
            all_results,
        }
        return final_result
|