Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

Add model postprocess function

Co-authored-by: liushz <liuhongwei@pjlab.rog.cn>
78 lines
2.9 KiB
Python
from functools import partial
from multiprocessing import Pool
from typing import Union

from tqdm import tqdm

from opencompass.registry import TEXT_POSTPROCESSORS

from .postprocessors.xfinder.extractor import Extractor
from .postprocessors.xfinder.xfinder_utils import (DataProcessor,
                                                   convert_to_xfinder_format)


def gen_output(ori_data, extractor):
    ext_cor_pairs = []
    extracted_data = []
    extracted_answers = []
    for item in tqdm(ori_data):
        user_input = extractor.prepare_input(item)
        extracted_answer = extractor.gen_output(user_input)
        ext_cor_pairs.append([
            item['key_answer_type'], item['standard_answer_range'],
            extracted_answer, item['correct_answer']
        ])
        item['xfinder_extracted_answer'] = extracted_answer
        extracted_answers.append(extracted_answer)
        extracted_data.append(item)

    return extracted_answers, ext_cor_pairs, extracted_data
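
# Added illustration (not in the original module; the field values are
# assumptions): each item handled by gen_output is an xFinder-format record
# that carries at least the keys referenced above, e.g.
#
#     {
#         'key_answer_type': 'alphabet_option',
#         'standard_answer_range': "['A', 'B', 'C', 'D']",
#         'correct_answer': 'A',
#         ...  # plus whatever fields extractor.prepare_input() reads
#     }
#
# It returns the extracted answers, [key_answer_type, standard_answer_range,
# extracted_answer, correct_answer] records, and the items themselves with
# 'xfinder_extracted_answer' filled in.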


@TEXT_POSTPROCESSORS.register_module('xfinder')
def xfinder_postprocess(preds: list, question_type: str,
                        xfinder_model_name: str,
                        xfiner_api_url: Union[str, list], **kwargs) -> list:
    """Postprocess the text extracted by xFinder model.

    Args:
        preds (list): The question, reference answer and model prediction.
        question_type (str): The type of the question.
        xfinder_model_name (str): The name of the xFinder model used for
            answer extraction.
        xfiner_api_url (Union[str, list]): The api url(s) of the xFinder
            model; a comma-separated string is treated as multiple urls.

    Returns:
        list: The postprocessed texts.
    """

    def _eval_pred(texts, data_processor, extractor, num_processes=8):
        ori_data = data_processor.read_data(texts)
        extracted_correct_pairs = []
        extracted_data = []
        extracted_answers = []
        batched_ori_data = []
        # Split data into batches
        num_processes = min(num_processes, len(ori_data))
        batch_size = len(ori_data) // num_processes
        for i in range(0, len(ori_data), batch_size):
            batched_ori_data.append(ori_data[i:i + batch_size])
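        # Added worked example: with the default num_processes=8 and, say,
        # 100 items, batch_size = 100 // 8 = 12, so the loop above builds
        # 9 batches (eight of 12 items and a final one of 4) that the Pool
        # below spreads across the 8 worker processes.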
        with Pool(num_processes) as p:
            results = p.map(partial(gen_output, extractor=extractor),
                            batched_ori_data)
            for result in results:
                extracted_answers += result[0]
                extracted_correct_pairs += result[1]
                extracted_data += result[2]
        return extracted_answers

    format_data = convert_to_xfinder_format(question_type, preds)
    assert xfiner_api_url is not None, 'Please provide the api url.'
    data_processor = DataProcessor()
    extractor = Extractor(model_name=xfinder_model_name,
                          url=xfiner_api_url.split(',')
                          if ',' in xfiner_api_url else xfiner_api_url)
    calc_acc_func = partial(_eval_pred,
                            data_processor=data_processor,
                            extractor=extractor)
    extracted_answers = calc_acc_func(format_data)
    return extracted_answers
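

# A minimal usage sketch added for illustration; it is not part of the
# original module. The prediction fields, question type, model name and
# service url below are assumptions, not values taken from this repository,
# and the call only succeeds if an xFinder service is reachable at that url.
if __name__ == '__main__':
    example_preds = [{
        'origin_prompt': 'Which option is correct?\nA. 1 + 1 = 3\nB. 1 + 1 = 2',
        'prediction': 'The correct option is B.',
        'reference': 'B',
    }]
    answers = xfinder_postprocess(
        preds=example_preds,
        question_type='alphabet_option',
        xfinder_model_name='xFinder-qwen1505',
        xfiner_api_url='http://localhost:23333/v1')
    print(answers)  # the answers extracted by the xFinder model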