mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
add HuStandardFIB under new paradigm (#3)
Co-authored-by: weixingjian <weixingjian@pjlab.org.cn>
This commit is contained in:
parent
6527fdf70a
commit
5f72e96d5b
@ -0,0 +1,50 @@
|
|||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
|
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||||
|
from opencompass.datasets.OpenHuEval.HuStandardFIB import HuStandardFIBDataset, HuStandardFIBEvaluator
|
||||||
|
|
||||||
|
with read_base():
|
||||||
|
from .HuStandardFIB_setting import INSTRUCTIONS, DATASET_PATH
|
||||||
|
|
||||||
|
# Languages for which a dataset entry is generated, and the prompt revision
# tag that is embedded in every dataset abbreviation.
ALL_LANGUAGES = ['hu']
PROMPT_VERSION = INSTRUCTIONS['version']

# 'question' and 'subject' fill the prompt placeholders; 'reference' is the
# column handed to the evaluator.
FIB2_reader_cfg = {
    'input_columns': ['question', 'subject'],
    'output_column': 'reference',
}

# NOTE(review): top-level names in this config are kept unchanged on purpose;
# presumably other configs import `FIB2_datasets` -- confirm before renaming.
FIB2_datasets = []
for lan in ALL_LANGUAGES:
    instruction = INSTRUCTIONS[lan]

    # Single HUMAN turn carrying the instruction; ZeroRetriever means no
    # retrieved in-context examples are inserted at the '</E>' token.
    FIB2_infer_cfg = {
        'prompt_template': {
            'type': PromptTemplate,
            'template': {
                'begin': '</E>',
                'round': [
                    {
                        'role': 'HUMAN',
                        'prompt': instruction,
                    },
                ],
            },
            'ice_token': '</E>',
        },
        'retriever': {'type': ZeroRetriever},
        'inferencer': {'type': GenInferencer},
    }

    FIB2_eval_cfg = {'evaluator': {'type': HuStandardFIBEvaluator}}

    FIB2_datasets.append({
        'abbr': f'HuStandardFIB-{lan}-1shot-{PROMPT_VERSION}',
        'type': HuStandardFIBDataset,
        'path': DATASET_PATH,
        'lan': lan,
        'reader_cfg': FIB2_reader_cfg,
        'infer_cfg': FIB2_infer_cfg,
        'eval_cfg': FIB2_eval_cfg,
    })
|
@ -0,0 +1,15 @@
|
|||||||
|
# Prompt settings for HuStandardFIB.
#   'hu'          -- prompt template; {subject} and {question} are the
#                    placeholders, every other brace is literal example JSON.
#   'version'     -- revision tag of this prompt.
#   'description' -- free-text note about this revision.
INSTRUCTIONS = {
    'hu': """The following question is in hungarian language on {subject}, please read the question, and try to fill in the blank in the sub question list. Please organize the answer in a list. An example:
{
"q_main": "Írd be a megfelelő meghatározás mellé a fogalmat!",
"q_sub": ["A.A szerzetesi közösségek szabályzatának elnevezése latinul: #0#", "B.Az első ún. kolduló rend: #1#", "C.A szerzetesek által kézzel másolt mű: #2#", "D.Papi nőtlenség: #3#", "E.A pápát megválasztó egyházi méltóságok: #4#", "F.A bencés rend megújítása ebben a kolostorban kezdődött a 10. században: #5#"],
"formatted_std_ans": ["#0#regula", "#1#ferencesrend;ferences", "#2#kódex", "#3#cölibátus", "#4#bíborosok;bíboros", "#5#Cluny"]
}
Now try to answer the following question, your response should be in a JSON format. Contain the std_ans like the case given above.
The question is: {question}.
""",
    'version': 'V1',
    'description': 'Initial version, using 1shot, incontext, #0# as place holder, output in JSON format',
}

# Directory that holds the dataset file(s).
# NOTE(review): absolute, machine-specific path -- confirm it exists on the
# target environment.
DATASET_PATH = '/mnt/hwfile/opendatalab/weixingjian/test/test2/'
|
@ -0,0 +1 @@
|
|||||||
|
from .HuStandardFIB import * # noqa: F401, F403
|
140
opencompass/datasets/OpenHuEval/HuStandardFIB.py
Normal file
140
opencompass/datasets/OpenHuEval/HuStandardFIB.py
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
from datasets import Dataset, DatasetDict
|
||||||
|
from fuzzywuzzy import fuzz
|
||||||
|
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||||
|
from ..base import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
|
class HuStandardFIBDataset(BaseDataset):
    """Loader for the HuStandardFIB fill-in-the-blank JSON-lines data."""

    @staticmethod
    def load(**kwargs):
        """Read one JSON-lines file found under ``path`` into a ``Dataset``.

        Args:
            **kwargs: Loader options. Only ``path`` (a directory containing
                the JSON-lines data file) is read; any other key, such as
                ``lan``, is accepted and ignored.

        Returns:
            Dataset: One row per non-blank line with three columns:
            ``question`` (a dict with ``q_main`` and ``q_sub``, the latter
            taken from the record's ``formatted_q_sub``), ``subject`` (the
            record's ``major``) and ``reference`` (the full record).

        Raises:
            ValueError: If ``path`` is missing.
            FileNotFoundError: If ``path`` contains no regular file.
            KeyError: If a record lacks ``q_main``, ``formatted_q_sub`` or
                ``major``.
        """
        path = kwargs.get('path', None)
        if path is None:
            raise ValueError(
                "HuStandardFIBDataset.load() requires a 'path' argument.")

        # os.listdir() returns entries in arbitrary order; sort them so the
        # file that gets loaded is the same on every run and every machine.
        file_list = sorted(
            os.path.join(path, name) for name in os.listdir(path))
        file_list = [
            candidate for candidate in file_list
            if os.path.isfile(candidate)
        ]
        if not file_list:
            raise FileNotFoundError(f'No data file found under {path!r}.')
        f_path = file_list[0]  # TODO only work for a single split.

        out_dict_list = []
        # The context manager closes the handle even if a line fails to
        # parse.
        with open(f_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines, e.g. a trailing newline.
                    continue
                obj = json.loads(line)
                question = dict(q_main=obj['q_main'],
                                q_sub=obj['formatted_q_sub'])  # TODO
                out_dict_list.append(
                    dict(question=question,
                         subject=obj['major'],
                         reference=obj))

        return Dataset.from_list(out_dict_list)
|
||||||
|
|
||||||
|
|
||||||
|
class HuStandardFIBEvaluator(BaseEvaluator):
    """Blank-level and question-level scorer for HuStandardFIB answers.

    ref: opencompass.openicl.icl_evaluator.AccwithDetailsEvaluator
    """

    # A candidate matches an accepted answer when ``fuzz.ratio`` is strictly
    # greater than this value.
    MATCH_THRESHOLD = 70

    @staticmethod
    def _split_answers(raw_answers):
        """Drop ``#<n>#`` markers and split each entry on ``;``."""
        # str() guards against non-string entries (e.g. numbers) in model
        # output, which would otherwise make re.sub raise TypeError.
        return [
            re.sub(r'#\d+#', '', str(ans)).split(';') for ans in raw_answers
        ]

    @classmethod
    def _parse_model_answers(cls, pred, idx):
        """Return the model's per-blank answers, or None without valid JSON."""
        match = re.search(r'\{.*?\}', pred, re.DOTALL)
        if not match:
            return None
        json_str = match.group(0).strip().replace('\\xa0', '')
        try:
            # The matched text starts with '{', so a successful parse is
            # always a dict.
            data = json.loads(json_str)
        except json.JSONDecodeError:
            print(f'Invalid JSON format. {idx}')
            return None
        raw_answers = data.get('formatted_std_ans', [])
        if not isinstance(raw_answers, list):
            # A malformed answer field is treated as "no blanks answered".
            return []
        return cls._split_answers(raw_answers)

    @classmethod
    def _is_blank_correct(cls, accepted, candidates):
        """Tell whether any candidate fuzzily matches any accepted answer."""
        return any(
            fuzz.ratio(ans, cand) > cls.MATCH_THRESHOLD for ans in accepted
            for cand in candidates)

    def score(self, predictions, references, origin_prompt) -> dict:
        """Score predictions against the reference records.

        Args:
            predictions: Raw model outputs, one string per sample. The first
                ``{...}`` span of each is parsed as JSON and its
                ``formatted_std_ans`` list is taken as the model's answers.
            references: Reference records, one per sample; each provides
                ``formatted_std_ans``, a list with one entry per blank in
                which ``;`` separates accepted alternatives.
            origin_prompt: Prompts, one per sample; only copied into the
                details.

        Returns:
            dict: ``blank_level_correctness`` and
            ``question_level_correctness`` as percentages rounded to two
            decimals (0.0 when there is nothing to score), plus ``details``
            keyed by sample index. If ``predictions`` and ``references``
            differ in length, a dict with a single ``error`` key instead.
        """
        if len(predictions) != len(references):
            return {'error': 'preds and refers have different length.'}
        details = {}
        blank_correct, blank_total = 0, 0
        question_correct, question_total = 0, 0

        for idx, (pred, refer, prompt) in enumerate(
                zip(predictions, references, origin_prompt)):
            std_ans = self._split_answers(refer['formatted_std_ans'])
            pred = pred.strip()
            parsed_ans = self._parse_model_answers(pred, idx)

            # Every sample contributes to the totals exactly once, whether
            # or not its prediction could be parsed.
            blank_total += len(std_ans)
            question_total += 1

            if parsed_ans is not None:
                # A blank counts once, when any accepted alternative is
                # matched; blanks the model did not answer count as wrong
                # because zip() stops at the shorter list.
                hits = sum(
                    self._is_blank_correct(accepted, candidates)
                    for accepted, candidates in zip(std_ans, parsed_ans))
                blank_correct += hits
                if hits == len(std_ans):
                    question_correct += 1

            # Keyed by the sample index; the blank loop must never reuse it.
            details[idx] = {
                'detail': refer,
                'model_ans': parsed_ans if parsed_ans is not None else [],
                'gt': std_ans,
                'prompt': prompt,
                'raw_pred': pred,
            }

        results = {
            'blank_level_correctness':
            round(blank_correct / blank_total * 100, 2)
            if blank_total else 0.0,
            'question_level_correctness':
            round(question_correct / question_total * 100, 2)
            if question_total else 0.0,
            'details':
            details
        }

        return results
|
@ -1 +1,2 @@
|
|||||||
from .HuMatchingFIB import * # noqa: F401, F403
|
from .HuMatchingFIB import * # noqa: F401, F403
|
||||||
|
from .HuStandardFIB import * # noqa: F401, F403
|
Loading…
Reference in New Issue
Block a user