Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

[Refactor] Refactorize openicl eval task (#1990)

* [Refactor] Refactorize openicl eval task
* update

This commit is contained in:
parent 6ac9b06bc2
commit 12213207b6
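
At a high level, this PR splits the old monolithic `OpenICLEvalTask._score` into focused helper methods. A minimal sketch of the resulting control flow (method names are taken from the diff below; bodies are elided, so this is an outline rather than the actual implementation):

```python
class OpenICLEvalTask(BaseTask):

    def _score(self):
        # Stage 1: dataset loading plus optional dataset_postprocessor
        test_set = self._load_and_preprocess_test_data()
        # Stage 2: read the (possibly sharded) prediction files from disk
        pred_dicts, pred_strs = self._load_predictions()
        # Stage 3: role extraction and pred_postprocessor steps
        pred_strs = self._process_predictions(pred_strs)
        # Stage 4: build the evaluator and compute metrics
        result = self._evaluate_predictions(pred_strs, test_set, pred_dicts)
        # Stage 5: write the result file
        self._save_results(result)
```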
@@ -91,7 +91,8 @@ class BaseEvaluator:
     ):
         # Check if predictions and references have the
         # same length if both are provided
-        if 'predictions' in score_kwargs and 'references' in score_kwargs:
+        if ('predictions' in score_kwargs and 'references' in score_kwargs
+                and score_kwargs['references'] is not None):
             if len(score_kwargs['predictions']) != len(
                     score_kwargs['references']):
                 raise ValueError(
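
The added guard means evaluators that legitimately receive `references=None` no longer hit the length comparison. A small standalone illustration of the guarded check (sketch only, not code from this PR):

```python
def _check_lengths(**score_kwargs):
    # Mirror of the guarded comparison added above: only compare lengths
    # when references are actually provided.
    if ('predictions' in score_kwargs and 'references' in score_kwargs
            and score_kwargs['references'] is not None):
        if len(score_kwargs['predictions']) != len(score_kwargs['references']):
            raise ValueError('predictions and references differ in length')


_check_lengths(predictions=['a', 'b'], references=None)   # no error after this change
_check_lengths(predictions=['a', 'b'], references=['a'])  # raises ValueError
```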
@@ -7,7 +7,6 @@ import random
 import statistics
 import sys
 import time
-from collections import Counter
 from inspect import signature
 from typing import List

@@ -19,7 +18,7 @@ from opencompass.registry import (ICL_EVALUATORS, MODELS, TASKS,
                                   TEXT_POSTPROCESSORS)
 from opencompass.tasks.base import BaseTask, extract_role_pred
 from opencompass.utils import (build_dataset_from_cfg, get_infer_output_path,
-                               get_logger, task_abbr_from_cfg)
+                               get_logger)


 @TASKS.register_module()
@@ -86,6 +85,26 @@ class OpenICLEvalTask(BaseTask):
         self._score()

     def _score(self):
+        # Load and preprocess test data
+        test_set = self._load_and_preprocess_test_data()
+        # Load predictions
+        pred_dicts, pred_strs = self._load_predictions()
+
+        # Process predictions
+        pred_strs = self._process_predictions(pred_strs)
+
+        # Evaluate predictions
+        result = self._evaluate_predictions(
+            pred_strs,
+            test_set,
+            pred_dicts,
+        )
+
+        # Save results
+        self._save_results(result)
+
+    def _load_and_preprocess_test_data(self):
+        """Load test dataset and apply postprocessing if needed."""
         test_set = build_dataset_from_cfg(self.dataset_cfg).test
         # Postprocess dataset if necessary
         if 'dataset_postprocessor' in self.eval_cfg:
@@ -100,7 +119,10 @@ class OpenICLEvalTask(BaseTask):

             test_set = test_set.map(postprocess)

-        # Load predictions
+        return test_set
+
+    def _load_predictions(self):
+        """Load model predictions from files."""
         filename = get_infer_output_path(
             self.model_cfg,
             self.dataset_cfg,
@@ -110,217 +132,188 @@ class OpenICLEvalTask(BaseTask):
         root, ext = osp.splitext(filename)
         partial_filename = root + '_0' + ext

-        # Get sc_size if use Self-Consistency
-        sc_size = self.eval_cfg.get('sc_size')
-
         if not osp.exists(osp.realpath(filename)) and not osp.exists(
                 osp.realpath(partial_filename)):
-            result = {'error': 'No predictions found.'}
-        else:
-            if osp.exists(osp.realpath(filename)):
-                preds = mmengine.load(filename)
-                preds = [preds[str(i)] for i in range(len(preds))]
-            else:
-                filename = partial_filename
-                preds = []
-                i = 1
-                while osp.exists(osp.realpath(filename)):
-                    sub_preds = mmengine.load(filename)
-                    preds.extend(
-                        [sub_preds[str(i)] for i in range(len(sub_preds))])
-                    filename = root + f'_{i}' + ext
-                    i += 1
-            pred_dicts = copy.deepcopy(preds)
-            preds = {k: [pred.get(k) for pred in preds] for k in preds[0]}
-
-            pred_strs = preds.pop('prediction', None)
-            pred_list_flag = pred_strs is not None and isinstance(
-                pred_strs[0], list)
-            if ('pred_role' in self.eval_cfg
-                    and 'meta_template' in self.model_cfg
-                    and not MODELS.get(self.model_cfg['type']).is_api):
-                # Create a prompt template for role config parsing
-                from opencompass.models.base import LMTemplateParser
-
-                parser = LMTemplateParser(self.model_cfg['meta_template'])
-                role = parser.roles[self.eval_cfg['pred_role']]
-                if sc_size is not None:
-                    assert pred_list_flag, (
-                        'The prediction for Self-Consistency'
-                        'must be list.')
-                if pred_list_flag:
-                    pred_strs = [[
-                        extract_role_pred(
-                            _pred,
-                            role.get('begin', None),
-                            role.get('end', None),
-                        ) for _pred in pred
-                    ] for pred in pred_strs]
-                else:
-                    pred_strs = [
-                        extract_role_pred(
-                            pred,
-                            role.get('begin', None),
-                            role.get('end', None),
-                        ) for pred in pred_strs
-                    ]
-
-            # Postprocess predictions if necessary
-            # Model Specified Postprocessor
-            if 'pred_postprocessor' in self.model_cfg:
-                kwargs = copy.deepcopy(self.model_cfg['pred_postprocessor'])
-                proc = kwargs.pop('type')
-                if isinstance(proc, str):
-                    proc = TEXT_POSTPROCESSORS.get(proc)
-                if pred_list_flag:
-                    pred_strs = [[proc(s, **kwargs) for s in preds]
-                                 for preds in pred_strs]
-                else:
-                    pred_strs = [proc(s, **kwargs) for s in pred_strs]
-            # Dataset Specified Postprocessor
-            if 'pred_postprocessor' in self.eval_cfg:
-                kwargs = copy.deepcopy(self.eval_cfg['pred_postprocessor'])
-                proc = kwargs.pop('type')
-                if isinstance(proc, str):
-                    proc = TEXT_POSTPROCESSORS.get(proc)
-                if pred_list_flag:
-                    pred_strs = [[proc(s, **kwargs) for s in preds]
-                                 for preds in pred_strs]
-                else:
-                    pred_strs = [proc(s, **kwargs) for s in pred_strs]
-
-            model_pred_strs = []
-            if 'model_postprocessor' in self.eval_cfg:
-                references = (test_set[self.output_column]
-                              if self.output_column else None)
-                model_pred_dicts = copy.deepcopy(pred_dicts)
-                for i, pred_dict in enumerate(model_pred_dicts):
-                    pred_dict['reference'] = [references[i]]
-                self.logger.info('Postprocessing model predictions...')
-                kwargs = self.eval_cfg['model_postprocessor']
-                proc = kwargs.pop('type')
-                if isinstance(proc, str):
-                    proc = TEXT_POSTPROCESSORS.get(proc)
-                if pred_list_flag:
-                    model_pred_strs = [[
-                        proc(model_pred_dict, **kwargs)
-                        for model_pred_dict in model_pred_dicts
-                    ]]
-                else:
-                    model_pred_strs = proc(model_pred_dicts, **kwargs)
-
-            # Get majority voting predictions if use self-consistency
-            if sc_size is not None:
-                pred_strs = [
-                    Counter(s).most_common(1)[0][0] for s in pred_strs
-                ]
-
-        icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
-        # need results dir to save other files
-        out_path = get_infer_output_path(
-            self.model_cfg,
-            self.dataset_cfg,
-            osp.join(self.work_dir, 'results'),
-        )
-        icl_evaluator._out_dir = osp.splitext(out_path)[
-            0]  # strip extension
-
-        preds['predictions'] = pred_strs
-        preds['references'] = (test_set[self.output_column]
-                               if self.output_column else None)
-        preds['test_set'] = test_set
-        if 'origin_prompt' not in preds:
-            try:
-                preds['origin_prompt'] = [
-                    None for _ in range(len(pred_strs))
-                ]
-            except TypeError:
-                preds['origin_prompt'] = None
-        preds = {
-            k: preds[k]
-            for k in signature(icl_evaluator.score).parameters
-        }
-        k = self.dataset_cfg.get('k', 1)
-        n = self.dataset_cfg.get('n', 1)
-        result = icl_evaluator.evaluate(k, n, copy.deepcopy(test_set),
-                                        **preds)
-
-        # Get model postprocess result
-        model_details = None
-        model_result = None
-        if 'model_postprocessor' in self.eval_cfg:
-            model_preds = copy.deepcopy(preds)
-            model_preds['predictions'] = model_pred_strs
-            model_result = icl_evaluator.evaluate(k, n,
-                                                  copy.deepcopy(test_set),
-                                                  **model_preds)
-            for key in model_result:
-                if key == 'details':
-                    model_details = model_result[key]
-                    continue
-                new_key = 'model_postprocess_' + key
-                result[new_key] = model_result[key]
-
-        if self.dump_details:
-            details = result.get('details', None)
-            # Try to format details is details is not provided by evaluator
-            if details is None:
-                self.logger.info(
-                    'Details is not give by evaluator, try to format it')
-            try:
-                result['details'] = self.format_details(
-                    pred_strs,
-                    model_pred_strs,
-                    test_set[self.output_column],
-                    details,
-                    model_details,
-                    pred_dicts,
-                )
-                self.logger.warning(
-                    f"result['details'] : {result['details']}"),
-                result['type'] = result['details'].pop('type', None)
-                if self.cal_extract_rate:
-                    # Calculate the extraction success
-                    # rate for prediction
-                    result['extract_rate'] = self.extract_rate(result)
-
-                if 'PPL' in str(
-                        self.dataset_cfg.infer_cfg.inferencer.type):
-                    result['correct_bpb'], result['incorrect_bpb'] = (
-                        self.calculate_bpb(pred_dicts))
-            except Exception as e:
-                self.logger.warning(
-                    f'Skip dumping details due to: {e}.')
-        else:
-            result.pop('details', None)
-
-        if 'error' in result:
-            self.logger.error(
-                f'Task {task_abbr_from_cfg(self.cfg)}: {result["error"]}')
-            return
-        elif model_result is None:
-            result_wo_details = {
-                i: result[i]
-                for i in result if i != 'details'
-            }
-            self.logger.info(
-                f'Task {task_abbr_from_cfg(self.cfg)}: {result_wo_details}')
-        else:
-            result_wo_details = {
-                i: result[i]
-                for i in result if i != 'details'
-            }
-            model_result_wo_details = {
-                i: model_result[i]
-                for i in model_result if i != 'details'
-            }
-            self.logger.info(
-                f'Task {task_abbr_from_cfg(self.cfg)}: {result_wo_details}')
-            self.logger.info(
-                'Model Postprocess Task: ' +
-                f'{task_abbr_from_cfg(self.cfg)}:{model_result_wo_details}')
-
-        # Save result
+            raise FileNotFoundError(
+                f'Prediction files not found: neither {filename} '
+                f'nor {partial_filename} exists')
+
+        if osp.exists(osp.realpath(filename)):
+            preds = mmengine.load(filename)
+            preds = [preds[str(i)] for i in range(len(preds))]
+        else:
+            filename = partial_filename
+            preds = []
+            i = 1
+            while osp.exists(osp.realpath(filename)):
+                sub_preds = mmengine.load(filename)
+                preds.extend(
+                    [sub_preds[str(i)] for i in range(len(sub_preds))])
+                filename = root + f'_{i}' + ext
+                i += 1
+
+        pred_dicts = copy.deepcopy(preds)
+        preds = {k: [pred.get(k) for pred in preds] for k in preds[0]}
+
+        pred_strs = preds.pop('prediction', None)
+
+        return pred_dicts, pred_strs
+
+    def _process_predictions(self, pred_strs):
+        """Apply various processing steps to predictions."""
+        # Check if we're dealing with a list of lists (pred_list_flag)
+        pred_list_flag = pred_strs is not None and isinstance(
+            pred_strs[0], list)
+
+        # Extract role predictions if needed
+        if ('pred_role' in self.eval_cfg and 'meta_template' in self.model_cfg
+                and not MODELS.get(self.model_cfg['type']).is_api):
+            # Create a prompt template for role config parsing
+            from opencompass.models.base import LMTemplateParser
+
+            parser = LMTemplateParser(self.model_cfg['meta_template'])
+            role = parser.roles[self.eval_cfg['pred_role']]
+            if pred_list_flag:
+                pred_strs = [[
+                    extract_role_pred(
+                        _pred,
+                        role.get('begin', None),
+                        role.get('end', None),
+                    ) for _pred in pred
+                ] for pred in pred_strs]
+            else:
+                pred_strs = [
+                    extract_role_pred(
+                        pred,
+                        role.get('begin', None),
+                        role.get('end', None),
+                    ) for pred in pred_strs
+                ]
+
+        # Apply postprocessors if configured
+        # Postprocess predictions if necessary
+        # Model Specified Postprocessor
+        if 'pred_postprocessor' in self.model_cfg:
+            kwargs = copy.deepcopy(self.model_cfg['pred_postprocessor'])
+            proc = kwargs.pop('type')
+            if isinstance(proc, str):
+                proc = TEXT_POSTPROCESSORS.get(proc)
+            if pred_list_flag:
+                pred_strs = [[proc(s, **kwargs) for s in preds]
+                             for preds in pred_strs]
+            else:
+                pred_strs = [proc(s, **kwargs) for s in pred_strs]
+
+        # Dataset Specified Postprocessor
+        if 'pred_postprocessor' in self.eval_cfg:
+            kwargs = copy.deepcopy(self.eval_cfg['pred_postprocessor'])
+            proc = kwargs.pop('type')
+            if isinstance(proc, str):
+                proc = TEXT_POSTPROCESSORS.get(proc)
+            if pred_list_flag:
+                pred_strs = [[proc(s, **kwargs) for s in preds]
+                             for preds in pred_strs]
+            else:
+                pred_strs = [proc(s, **kwargs) for s in pred_strs]
+
+        return pred_strs
+
+    def _evaluate_predictions(
+        self,
+        pred_strs,
+        test_set,
+        pred_dicts,
+    ):
+        """Evaluate predictions using the configured evaluator."""
+        # Get references from test set
+        references = (None if self.output_column is None else
+                      [sample[self.output_column] for sample in test_set])
+        # Build evaluator from config
+        evaluator_cfg = self.eval_cfg.get('evaluator', {})
+        evaluator_type = evaluator_cfg.get('type')
+        if isinstance(evaluator_type, str):
+            evaluator_type = ICL_EVALUATORS.get(evaluator_type)
+
+        # Prepare evaluator inputs
+        evaluator_cfg_copy = copy.deepcopy(evaluator_cfg)
+        evaluator_cfg_copy.pop('type', None)
+        # Initialize evaluator with appropriate parameters
+        sig = signature(evaluator_type)
+        if 'predictions' in sig.parameters and 'references' in sig.parameters:
+            evaluator = evaluator_type(
+                predictions=pred_strs,
+                references=references,
+                **evaluator_cfg_copy,
+            )
+        else:
+            evaluator = evaluator_type(**evaluator_cfg_copy)
+
+        # Set output directory for the evaluator
+        out_path = get_infer_output_path(
+            self.model_cfg,
+            self.dataset_cfg,
+            osp.join(self.work_dir, 'results'),
+        )
+        evaluator._out_dir = osp.splitext(out_path)[0]  # strip extension
+
+        # If preds contains keys that match the score method
+        # parameters, include them
+        if pred_dicts:
+            preds = {
+                k: [pred.get(k) for pred in pred_dicts]
+                for k in pred_dicts[0]
+            }
+        # Add predictions and references if they're expected
+        # by the score method
+        preds['predictions'] = pred_strs
+        preds['references'] = (test_set[self.output_column]
+                               if self.output_column else None)
+        preds['test_set'] = test_set
+        if 'origin_prompt' not in preds:
+            try:
+                preds['origin_prompt'] = [None for _ in range(len(pred_strs))]
+            except TypeError:
+                preds['origin_prompt'] = None
+        preds = {k: preds[k] for k in signature(evaluator.score).parameters}
+        # Call evaluate with the appropriate parameters
+        k = self.dataset_cfg.get('k', 1)
+        n = self.dataset_cfg.get('n', 1)
+        result = evaluator.evaluate(k, n, copy.deepcopy(test_set), **preds)
+
+        # Format details if needed
+        if self.dump_details:
+            # Get detailed results if available
+            details = result.get('details', None)
+            if details is None:
+                self.logger.info(
+                    'Details is not give by evaluator, try to format it')
+            try:
+                result['details'] = self.format_details(
+                    pred_strs,
+                    references,
+                    details,
+                    pred_dicts,
+                )
+
+                # Calculate extraction rate if needed
+                if self.cal_extract_rate and details is not None:
+                    result['extract_rate'] = self.extract_rate(result)
+
+                # Calculate BPB if applicable
+                if pred_dicts and 'BPB' in pred_dicts[0].get(
+                        list(pred_dicts[0].keys())[0], {}):
+                    correct_bpb, incorrect_bpb = self.calculate_bpb(
+                        pred_dicts)
+                    result['correct_bpb'] = correct_bpb
+                    result['incorrect_bpb'] = incorrect_bpb
+            except Exception as e:
+                self.logger.warning(f'Skip dumping details due to: {e}.')
+        else:
+            result.pop('details', None)
+
+        return result
+
+    def _save_results(self, result):
+        """Save evaluation results to file."""
         out_path = get_infer_output_path(
             self.model_cfg,
             self.dataset_cfg,
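
One detail worth calling out from the new `_evaluate_predictions`: the keyword arguments handed to the evaluator are filtered down to whatever its `score` method actually declares. A minimal standalone illustration of that pattern (sketch only, with a made-up `score` signature, not the task code itself):

```python
from inspect import signature


def score(predictions, references):
    """A stand-in for an evaluator's score() method."""
    correct = sum(p == r for p, r in zip(predictions, references))
    return {'accuracy': 100 * correct / len(references)}


preds = {
    'predictions': ['A', 'B'],
    'references': ['A', 'C'],
    'test_set': None,
    'origin_prompt': None,
}
# Keep only the keys that score() declares, mirroring
# `preds = {k: preds[k] for k in signature(evaluator.score).parameters}` above.
kept = {k: preds[k] for k in signature(score).parameters}
print(score(**kept))  # {'accuracy': 50.0}
```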
@@ -351,10 +344,8 @@ class OpenICLEvalTask(BaseTask):
     def format_details(
         self,
         predictions,
-        model_pred_strs,
         references,
         details,
-        model_details,
         pred_dicts,
     ):
         """This function is responsible for formatting prediction details.
@@ -393,20 +384,6 @@ class OpenICLEvalTask(BaseTask):
                 result['predictions'] = str(predictions[i])
                 result['references'] = str(references[i])
                 result['correct'] = str(predictions[i]) == str(references[i])
-            elif details is not None and model_details is not None:
-                assert (
-                    model_pred_strs != []
-                ), 'Model details is not None, but model_pred_strs is empty'
-                self.logger.info(
-                    f"model_details[i]['pred']: {model_details[i]['pred']}")
-                results['type'] = 'GEN'
-                result['prompt'] = origin_prediction['origin_prompt']
-                result['origin_prediction'] = pred_dicts[i]['prediction']
-                result['predictions'] = details[i]['pred']
-                result['model_extract_predictions'] = model_details[i]['pred']
-                result['references'] = details[i]['answer']
-                result['correct'] = details[i]['correct']
-                result['model_extract_correct'] = model_details[i]['correct']
             elif details is not None:
                 results['type'] = 'GEN'
                 result['prompt'] = origin_prediction['origin_prompt']
@@ -10,9 +10,7 @@ from .fileio import * # noqa
 from .lark import * # noqa
 from .logging import * # noqa
 from .menu import * # noqa
-from .model_postprocessors import * # noqa
 from .network import * # noqa
-from .postprocessors import * # noqa
 from .prompt import * # noqa
 from .result_station import * # noqa
 from .text_postprocessors import * # noqa
@@ -1,135 +0,0 @@
from functools import partial
from multiprocessing import Pool
from typing import Union

from tqdm import tqdm

from opencompass.registry import TEXT_POSTPROCESSORS

from .postprocessors.naive import NaiveExtractor, format_input_naive
from .postprocessors.xfinder.extractor import Extractor
from .postprocessors.xfinder.xfinder_utils import (DataProcessor,
                                                   convert_to_xfinder_format)


def gen_output_naive(ori_data, extractor):
    extracted_answers = []
    for item in tqdm(ori_data):
        user_input = extractor.prepare_input(item)
        extracted_answer = extractor.gen_output(user_input)
        item['extracted_answer'] = extracted_answer
        extracted_answers.append(extracted_answer)

    return extracted_answers


@TEXT_POSTPROCESSORS.register_module('naive')
def naive_model_postprocess(preds: list,
                            model_name: str,
                            custom_instruction: str,
                            api_url: Union[str, list],
                            num_processes: int = 8,
                            **kwargs) -> list:
    """Postprocess the text extracted by custom model.
    Args:
        preds (list): The question, reference answer and model prediction.
        model_name (str): The name of the model.
        custom_instruction (str): Custom instruction for the dataset.
        url (Union[str, list]): The api url of the model.

    Returns:
        list: The postprocessed answers.
    """

    def _eval_pred(texts, extractor, num_processes):
        ori_data = texts
        extracted_answers = []
        batched_ori_data = []
        # Split data into batches
        num_processes = min(num_processes, len(ori_data))
        batch_size = len(ori_data) // num_processes
        for i in range(0, len(ori_data), batch_size):
            batched_ori_data.append(ori_data[i:i + batch_size])
        with Pool(num_processes) as p:
            results = p.map(partial(gen_output_naive, extractor=extractor),
                            batched_ori_data)
            for result in results:
                extracted_answers.extend(result)
        return extracted_answers

    format_data = format_input_naive(preds)
    assert api_url is not None, 'Please provide the api url.'
    extractor = NaiveExtractor(
        model_name=model_name,
        custom_instruction=custom_instruction,
        url=api_url.split(',') if ',' in api_url else api_url)
    calc_acc_func = partial(_eval_pred,
                            extractor=extractor,
                            num_processes=num_processes)
    extracted_answers = calc_acc_func(format_data)
    return extracted_answers


def gen_output_xfinder(ori_data, extractor):
    ext_cor_pairs = []
    extracted_data = []
    extracted_answers = []
    for item in tqdm(ori_data):
        user_input = extractor.prepare_input(item)
        extracted_answer = extractor.gen_output(user_input)
        ext_cor_pairs.append([
            item['key_answer_type'], item['standard_answer_range'],
            extracted_answer, item['correct_answer']
        ])
        item['xfinder_extracted_answer'] = extracted_answer
        extracted_answers.append(extracted_answer)
        extracted_data.append(item)

    return extracted_answers, ext_cor_pairs, extracted_data


@TEXT_POSTPROCESSORS.register_module('xfinder')
def xfinder_postprocess(preds: list, question_type: str, model_name: str,
                        api_url: Union[str, list], **kwargs) -> list:
    """Postprocess the text extracted by xFinder model.
    Args:
        preds (list): The question, reference answer and model prediction.
        question_type (str): The type of the question.
        url (Union[str, list]): The api url of the xFinder model.

    Returns:
        list: The postprocessed texts.
    """

    def _eval_pred(texts, data_processor, extractor, num_processes=8):
        ori_data = data_processor.read_data(texts)
        extracted_correct_pairs = []
        extracted_data = []
        extracted_answers = []
        batched_ori_data = []
        # Split data into batches
        num_processes = min(num_processes, len(ori_data))
        batch_size = len(ori_data) // num_processes
        for i in range(0, len(ori_data), batch_size):
            batched_ori_data.append(ori_data[i:i + batch_size])
        with Pool(num_processes) as p:
            results = p.map(partial(gen_output_xfinder, extractor=extractor),
                            batched_ori_data)
            for result in results:
                extracted_answers += result[0]
                extracted_correct_pairs += result[1]
                extracted_data += result[2]
        return extracted_answers

    format_data = convert_to_xfinder_format(question_type, preds)
    assert api_url is not None, 'Please provide the api url.'
    data_processor = DataProcessor()
    extractor = Extractor(
        model_name=model_name,
        url=api_url.split(',') if ',' in api_url else api_url)
    calc_acc_func = partial(_eval_pred,
                            data_processor=data_processor,
                            extractor=extractor)
    extracted_answers = calc_acc_func(format_data)
    return extracted_answers
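
For orientation, a hedged sketch of how this (now removed) naive entry point was invoked directly; the prediction dict fields follow `format_input_naive` above, while the model name and server address are placeholders:

```python
from opencompass.utils.model_postprocessors import naive_model_postprocess
from opencompass.utils.postprocessors.naive import MATH_NAVIE_PROMPT_TEMPLATE

preds = [{
    'origin_prompt': [{'role': 'HUMAN', 'prompt': 'What is 2 + 3?'}],
    'prediction': 'The answer is 5.',
    'reference': '5',
}]
answers = naive_model_postprocess(
    preds,
    model_name='llama3-8b-instruct',    # placeholder deployment name
    custom_instruction=MATH_NAVIE_PROMPT_TEMPLATE,
    api_url='http://0.0.0.0:23333/v1',  # placeholder server address
)
print(answers)
```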
@@ -1,11 +0,0 @@
OPTION_NAVIE_PROMPT_TEMPLATE = """
There is a detailed explanation of the final answer you should extract:
1. You should extract the final answer option like 'A', 'B', 'C', 'D' ... from the given output sentences.
2. The question is a single choice question, so the final answer option should be one of the options, not a combination of options.
""" # noqa

MATH_NAVIE_PROMPT_TEMPLATE = """
This is a detailed explanation of the final answer you should extract:
1. The question type is a math question, so the final answer should be a number, set, vector, matrix, interval, expression, function, equation, or inequality and any combination of them.
2. If the final answer includes additional symbols, such as units, you should exclude them and only extract the pure final answer.
""" # noqa
@@ -1,71 +0,0 @@
## Short Usage Introduction for Naive Model Postprocessor with Custom Model

<!-- Now OC can use -->

### Step 1: Deploy an API server using vLLM or LMDeploy

```bash
lmdeploy serve api_server meta-llama/Meta-Llama-3-8B-Instruct --model-name llama3-8b-instruct --server-port 23333 --backend turbomind --tp 1
```
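
If you want to sanity-check the deployed server before wiring it into a config, a quick hedged probe with the OpenAI-compatible client (the same client the extractor uses internally; the address and model name below are placeholders that must match your deployment):

```python
from openai import OpenAI

client = OpenAI(api_key='EMPTY', base_url='http://0.0.0.0:23333/v1')
reply = client.chat.completions.create(
    model='llama3-8b-instruct',  # placeholder: the --model-name you served
    messages=[{'role': 'user', 'content': 'Reply with the single word: pong'}],
)
print(reply.choices[0].message.content)
```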

### Step 2: Add Naive Model Postprocessor to the configuration file

Take GSM8K as an example: you can add the following lines to the configuration file and replace the `api_url` with the correct address of the API server.

```python
...
from opencompass.utils.model_postprocessors import navie_model_postprocess
from opencompass.utils.postprocessors.naive import MATH_NAVIE_PROMPT_TEMPLATE

...

gsm8k_eval_cfg = dict(
    evaluator=dict(type=MATHEvaluator, version='v2'),
    pred_postprocessor=dict(type=math_postprocess_v2),
    dataset_postprocessor=dict(type=gsm8k_dataset_postprocess),
    # Add the following line to use the naive model postprocessor
    model_postprocessor=dict(
        type=navie_model_postprocess,
        custom_instruction=MATH_NAVIE_PROMPT_TEMPLATE,
        model_name='llama3-8b-instruct',
        api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
)
...
```

The prompt for extraction can also be customized by changing the `custom_instruction` parameter. Two default templates are currently supported: `MATH_NAVIE_PROMPT_TEMPLATE` for math-problem extraction (e.g. GSM8K and MATH) and `OPTION_NAVIE_PROMPT_TEMPLATE` for option-problem extraction (e.g. MMLU). You can also write your own prompt template, like:

```python
OPTION_NAVIE_PROMPT_TEMPLATE = """
There is a detailed explanation of the final answer you should extract:
1. You should extract the final answer option like 'A', 'B', 'C', 'D' ... from the given output sentences.
2. The question is a single choice question, so the final answer option should be one of the options, not a combination of options.
"""
```

Your prompt should start with `There is a detailed explanation of the final answer you should extract:` and follow with your customized instructions.

### Step 3: Run the Evaluation as Usual

Now you can run the evaluation as usual with the configuration file you modified. The evaluation will use the custom model as the post-process model to get the final result. The final result will be the `model_postprocess_accuracy` in the evaluation result, like:

```Markdown
dataset    version    metric                      mode    llama-3-8b-instruct-turbomind
---------  ---------  --------------------------  ------  -------------------------------
gsm8k      a58960     accuracy                    gen     73.46
gsm8k      a58960     model_postprocess_accuracy  gen     78.77
```

## Experiment Results

We have tested the model postprocess method with different models (Qwen2-72B-Chat, Llama3-8b-Chat) as the post-process model on the GSM8K and MMLU datasets for `Meta-Llama-3-8B-Instruct` with the above settings, and the results are as follows:

```Markdown
| Dataset | Type   | Config ID | Regex Postprocess Score | Model Postprocess Score (Llama3-8b-Instruct) | Model Postprocess Score (Qwen2-72B-Chat) |
| ------- | ------ | --------- | ----------------------- | -------------------------------------------- | ---------------------------------------- |
| gsm8k   | math   | a58960    | 73.46                   | 79.08                                        | 78.77                                    |
| mmlu    | option | 4d595a    | 67.89                   | 65.26                                        | 67.94                                    |
```

The `metric` column with `model_postprocess_accuracy` is the final result after the `Naive Model Postprocessor` is applied.
@@ -1,2 +0,0 @@
from .extractor import * # noqa
from .PROMPT_TEMPLATE import * # noqa
@@ -1,121 +0,0 @@
# Naive model extractor for OpenCompass, modified from xFinder: https://github.com/IAAR-Shanghai/xFinder # noqa
import json
import time
from logging import getLogger

from openai import OpenAI

Meta_Instruction = """I will provide you with a question, output sentences along with an answer range. The output sentences are the response of the question provided. The answer range could either describe the type of answer expected or list all possible valid answers. Using the information provided, you must accurately and precisely determine and extract the intended key answer from the output sentences. Please don't have your subjective thoughts about the question.
First, you need to determine whether the content of the output sentences is relevant to the given question. If the entire output sentences are unrelated to the question (meaning the output sentences are not addressing the question), then output [No valid answer].
Otherwise, ignore the parts of the output sentences that have no relevance to the question and then extract the key answer that matches the answer range.
Below are some special cases you need to be aware of:
(1) If the output sentences present multiple different answers, carefully determine if the later provided answer is a correction or modification of a previous one. If so, extract this corrected or modified answer as the final response. Conversely, if the output sentences fluctuate between multiple answers without a clear final answer, you should output [No valid answer].
(2) If the answer range is a list and the key answer in the output sentences is not explicitly listed among the candidate options in the answer range, also output [No valid answer].
(3) You should only return the precise answer you extract, without processing the answer. Please return only the answer and do not add any additional content.

""" # noqa


def format_input_naive(data):
    format_data = []
    for item in data:
        template = {}
        question = item['origin_prompt'][-1]['prompt']
        llm_output = item['prediction']
        correct_answer = item['reference'] if item['reference'] else item[
            'gold']
        template['correct_answer'] = correct_answer
        template['question'] = question
        template['llm_output'] = llm_output

        format_data.append(template)
    return format_data


class NaiveExtractor:

    def __init__(
            self,
            model_name,
            model_path=None,
            url=None,
            temperature=0,
            max_tokens=3000,
            api_key='EMPTY',
            SYSTEM='You are a help assistant tasked with extracting the precise key answer from given output sentences. You must only provide the extracted key answer without including any additional text.',  # noqa
            custom_instruction=''):
        self.model_name = model_name
        self.SYSTEM = SYSTEM
        self.model_path = model_path
        self.url = url
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.custom_instruction = custom_instruction
        self.logger = getLogger(__name__)

    def prepare_input(self, item):
        user_input = Meta_Instruction + self.custom_instruction + \
            "Question: \"\"\"" + item['question'] + "\"\"\"\n\n" + \
            "Output sentences: \"\"\"" + item['llm_output'] + "\"\"\"\n\n" + \
            'Key extracted answer: '

        return user_input

    def gen_output(self, query):
        return self.openai_infer(query)

    def openai_infer(self, query: str, retry=9) -> str:
        """Perform inference on the OpenAI model.

        Args:
            query (str): The input query.

        Returns:
            str: The extracted answer (xFinder's output).
        """
        if isinstance(self.url, list):
            # Randomly api for better load balancing
            import random
            self.url = random.choice(self.url)
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.url,
        )
        self.retry = retry

        t = time.time()
        retry = self.retry
        response = ''
        while retry > 0:
            try:
                chat_response = self.client.chat.completions.create(
                    model=self.client.models.list().data[0].id
                    if self.model_name == '' else self.model_name,
                    messages=[
                        {
                            'role': 'system',
                            'content': self.SYSTEM
                        },
                        {
                            'role': 'user',
                            'content': query
                        },
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                js_response = json.loads(chat_response.model_dump_json())
                response = js_response['choices'][0]['message']['content']
                break
            except Exception as e:
                self.logger.info(f'Error: {e}')
                self.logger.info(f'{self.url} is down. Retrying...')
                self.logger.info(f'Time elapsed: {time.time() - t} seconds')
                time.sleep(6)
                retry -= 1
        if retry == 0:
            response = 'Error: Failed to get response.'
            self.logger.info(f'{response} after {self.retry} tries.')
            raise ValueError('The api is down')
        return response.strip()
@@ -1,194 +0,0 @@
## Extract Final Answers with Postprocess Models

OpenCompass now supports postprocessing (extracting) prediction answers with postprocess models, to get the true ability level of models. Currently, we use [XFinder](https://github.com/IAAR-Shanghai/xFinder) as our first postprocess model to extract the final answers from the model outputs.

We support three task types now:

1. **math**: for math questions with numerical or formula answers, like GSM8K, MATH, etc.
2. **alphabet_option**: for alphabet option questions with alphabet answers, like CommonsenseQA, MMLU, etc.
3. **short_text**: for questions whose answer is a short text selected from a set of candidate short texts.

Here are some examples of the supported question types from the official xFinder repo:

```json
[
    {
        "model_name": "Llama3_8B_Instruct",
        "dataset": "MetaMathQA",
        "key_answer_type": "math",
        "question": "Find the smallest, positive four-digit multiple of $15.$",
        "llm_output": "The smallest four-digit multiple of $15$ is $\\boxed{1005}.$",
        "correct_answer": "1005",
        "standard_answer_range": "a(n) number / set / vector / matrix / interval / expression / function / equation / inequality"
    },
    {
        "model_name": "LLaMA2_7B_Chat",
        "dataset": "SIQA_text",
        "key_answer_type": "short_text",
        "question": "Jordan kept Carson up all night talking because Jordan's girlfriend left them and they were quite upset. How would Carson feel as a result? Answer Choices: drowsy / upset their girlfriend left / energetic",
        "llm_output": "Based on the information provided in the question, Carson would likely feel upset because Jordan's girlfriend left them and they were quite upset. Therefore, the correct answer is (B) upset their girlfriend left.",
        "correct_answer": "drowsy",
        "standard_answer_range": [
            "drowsy",
            "upset their girlfriend left",
            "energetic"
        ]
    },
    {
        "model_name": "PHI2",
        "dataset": "CommonsenseQA_alpha",
        "key_answer_type": "alphabet_option",
        "question": "Why would you go to an office? Answer Choices: (A) work (B) school building (C) paper (D) city (E) habit",
        "llm_output": "A) work (B) school building (C) paper (D) city (E) habit\n\nExplanation:\n We are looking for a reason why someone would go to an office. The only answer choice that matches this is option A) work. Therefore, the correct answer is A) work.",
        "correct_answer": "A",
        "standard_answer_range": [
            ["A", "work"],
            ["B", "school building"],
            ["C", "paper"],
            ["D", "city"],
            ["E", "habit"]
        ]
    }
]
```

## How to Use Model Postprocess in OpenCompass

### Step 1: Deploy the Postprocess Model Server

For now, there are two xFinder models you can use; you can download them from the Huggingface model hub:

1. **IAAR-Shanghai/xFinder-qwen1505**
2. **IAAR-Shanghai/xFinder-llama38it**

You can use LMDeploy or vLLM to deploy the xFinder model server. For example, you can use the following command to deploy the xFinder model server with LMDeploy:

```bash
lmdeploy serve api_server IAAR-Shanghai/xFinder-qwen1505 --model-name xFinder-qwen1505 --server-port 23333 --backend turbomind --tp 1
```

### Step 2: Set the Postprocess Model Config in the Dataset Configuration

The postprocess is implemented as a common postprocess function in OpenCompass, and it can be used together with the default regex-based postprocess extraction at the same time. The only thing you need to do is to deploy the postprocess model server and add the `model_postprocessor` to the original `eval_cfg` in the dataset configuration, like the following example:

```python
from opencompass.utils.model_postprocessors import xfinder_postprocess

...

    model_postprocessor=dict(
        type=xfinder_postprocess,
        question_type='math',
        xfinder_model_name='xFinder-qwen1505',
        xfiner_api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
```

Explanation of the parameters:

- `question_type`: the type of the question, which can be one of the three types mentioned above.
- `xfinder_model_name`: the name you gave the model when deploying the model server.
- `xfiner_api_url`: the URL of the model server; you can set multiple URLs separated by `,` to use multiple model servers, which can accelerate the postprocessing.

📢 **Please pay attention to the following points**:

1. Only extraction under the zero-shot setting is supported for now.
2. For alphabet_option problems, the options should follow a '\nA. xxx\nB. xxx\nC. xxx\nD. xxx\nE. xxx\n ...' or '\n(A) xxx\n(B) xxx\n(C) xxx\n(D) xxx\n(E) xxx\n ...' format, and the correct answer should be the letter of the correct option, like 'A', 'B', 'C', 'D', 'E'.

For more details about the xFinder model, you can refer to [xFinder](https://github.com/IAAR-Shanghai/xFinder). For a complete example, you can refer to the following configuration of the GSM8K dataset with the xFinder postprocess model:

```python
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import GSM8KDataset, gsm8k_dataset_postprocess, Gsm8kEvaluator
from opencompass.datasets import MATHEvaluator, math_postprocess_v2
from opencompass.utils.model_postprocessors import xfinder_postprocess

gsm8k_reader_cfg = dict(input_columns=['question'], output_column='answer')

gsm8k_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(role='HUMAN', prompt='{question}\nPlease reason step by step, and put your final answer within \\boxed{}.'),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=512),
)

gsm8k_eval_cfg = dict(
    evaluator=dict(type=MATHEvaluator, version='v2'),
    pred_postprocessor=dict(type=math_postprocess_v2),
    dataset_postprocessor=dict(type=gsm8k_dataset_postprocess),
    model_postprocessor=dict(
        type=xfinder_postprocess,
        question_type='math',
        xfinder_model_name='xFinder-qwen1505',
        xfiner_api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
)

gsm8k_datasets = [
    dict(
        abbr='gsm8k',
        type=GSM8KDataset,
        path='opencompass/gsm8k',
        reader_cfg=gsm8k_reader_cfg,
        infer_cfg=gsm8k_infer_cfg,
        eval_cfg=gsm8k_eval_cfg,
    )
]
```

In the evaluation results, `accuracy` is the result using the default postprocess and `model_postprocess_accuracy` is the result using the xFinder postprocess; the gap can be wider when the model is not good at answering the questions properly.

You can also use the `--dump-eval-details` option to dump detailed evaluation records and inspect the model postprocess results in the `results` folder.

## Results Comparison with Different Question Types

We have tested the model postprocess method with the XFinder model on the GSM8K, MMLU, and Natural Questions (NQ) datasets for `Meta-Llama-3-8B-Instruct` with the above settings, and the results are as follows:

| Dataset | Type            | Config Name              | Regex Postprocess Score | Model Postprocess Score |
| ------- | --------------- | ------------------------ | ----------------------- | ----------------------- |
| gsm8k   | math            | gsm8k_xfinder_gen_a58960 | 73.46                   | 78.09                   |
| nq      | short_text      | nq_xfinder_gen_3dcea1    | 22.33                   | 37.53                   |
| mmlu    | alphabet_option | mmlu_xfinder_gen_4d595a  | 67.89                   | 67.93                   |

## Citation

```bibtex
@misc{2023opencompass,
    title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
    author={OpenCompass Contributors},
    howpublished = {\url{https://github.com/open-compass/opencompass}},
    year={2023}
}

@misc{yu2024xfinderrobustpinpointanswer,
    title={xFinder: Robust and Pinpoint Answer Extraction for Large Language Models},
    author={Qingchen Yu and Zifan Zheng and Shichao Song and Zhiyu Li and Feiyu Xiong and Bo Tang and Ding Chen},
    year={2024},
    eprint={2405.11874},
    archivePrefix={arXiv},
    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2405.11874},
}
```
@@ -1,175 +0,0 @@
import json
import time
from logging import getLogger

import requests
from openai import OpenAI

from .xfinder_utils import PROMPT_TEMPLATE

Instruction = """I will provide you with a question, output sentences along with an answer range. The output sentences are the response of the question provided. The answer range could either describe the type of answer expected or list all possible valid answers. Using the information provided, you must accurately and precisely determine and extract the intended key answer from the output sentences. Please don't have your subjective thoughts about the question.
First, you need to determine whether the content of the output sentences is relevant to the given question. If the entire output sentences are unrelated to the question (meaning the output sentences are not addressing the question), then output [No valid answer].
Otherwise, ignore the parts of the output sentences that have no relevance to the question and then extract the key answer that matches the answer range.
Below are some special cases you need to be aware of:
(1) If the output sentences present multiple different answers, carefully determine if the later provided answer is a correction or modification of a previous one. If so, extract this corrected or modified answer as the final response. Conversely, if the output sentences fluctuate between multiple answers without a clear final answer, you should output [No valid answer].
(2) If the answer range is a list and the key answer in the output sentences is not explicitly listed among the candidate options in the answer range, also output [No valid answer].

""" # noqa


class Extractor:

    def __init__(
        self,
        model_name,
        model_path=None,
        url=None,
        temperature=0,
        max_tokens=3000,
        api_key='EMPTY',
        SYSTEM='You are a help assistant tasked with extracting the precise key answer from given output sentences. You must only provide the extracted key answer without including any additional text.'  # noqa
    ):
        self.model_name = model_name
        self.PROMPT_TEMPLATE = PROMPT_TEMPLATE[model_name]
        self.SYSTEM = SYSTEM
        self.model_path = model_path
        self.url = url
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.mode = 'API' if self.url is not None else 'Local'
        self.logger = getLogger(__name__)

        if self.mode == 'Local':
            from vllm import LLM, SamplingParams
            self.sampling_params = SamplingParams(temperature=self.temperature,
                                                  max_tokens=self.max_tokens,
                                                  stop=[
                                                      '<|endoftext|>',
                                                      '<|im_end|>', '<eoa>',
                                                      '<||>', '<end_of_turn>',
                                                      '<|eot_id|>'
                                                  ])
            self.llm = LLM(model=self.model_path, gpu_memory_utilization=0.5)

    @staticmethod
    def prepare_input(item):
        user_input = Instruction + \
            "Question: \"\"\"" + item['question'] + "\"\"\"\n\n" + \
            "Output sentences: \"\"\"" + item['llm_output'] + "\"\"\"\n\n" + \
            'Answer range: ' + item['standard_answer_range'] + '\n\n' + \
            'Key extracted answer: '

        return user_input

    def gen_output(self, query):
        if self.mode == 'API':
            # return self.send_request(query)
            return self.openai_infer(query)
        else:
            return self.offline_infer(query)

    def send_request(self, query: str) -> str:
        """Send a request to the model's API and return the response.

        Args:
            query (str): The input query.

        Returns:
            str: The extracted answer (xFinder's output).
        """
        prompt = self.PROMPT_TEMPLATE.format(system=self.SYSTEM, input=query)
        payload = json.dumps({
            'prompt':
            prompt,
            'temperature':
            self.temperature,
            'max_tokens':
            self.max_tokens,
            'stop': [
                '<|endoftext|>', '<|im_end|>', '<eoa>', '<||>',
                '<end_of_turn>', '<|eot_id|>'
            ],
        })
        headers = {'Content-Type': 'application/json'}
        res = requests.request('POST', self.url, headers=headers, data=payload)
        res = res.json()['text'][0]
        res = res.replace(prompt, '')
        # res = requests.post(self.url, json=payload)
        # res = res.json()['text']
        res = res.strip()
        return res

    def openai_infer(self, query: str, retry=9) -> str:
        """Perform inference on the OpenAI model.

        Args:
            query (str): The input query.

        Returns:
            str: The extracted answer (xFinder's output).
        """
        if isinstance(self.url, list):
            # Randomly api for better load balancing
            import random
            self.url = random.choice(self.url)
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.url,
        )
        self.retry = retry

        t = time.time()
        retry = self.retry
        response = ''
        while retry > 0:
            try:
                chat_response = self.client.chat.completions.create(
                    model=self.client.models.list().data[0].id
                    if self.model_name == '' else self.model_name,
                    messages=[
                        {
                            'role': 'system',
                            'content': self.SYSTEM
                        },
                        {
                            'role': 'user',
                            'content': query
                        },
                    ],
                    stop=[
                        '<|endoftext|>', '<|im_end|>', '<eoa>', '<||>',
                        '<end_of_turn>', '<|eot_id|>'
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                js_response = json.loads(chat_response.model_dump_json())
                response = js_response['choices'][0]['message']['content']
                break
            except Exception as e:
                self.logger.info(f'Error: {e}')
                self.logger.info(f'{self.url} is down. Retrying...')
                self.logger.info(f'Time elapsed: {time.time() - t} seconds')
                time.sleep(6)
                retry -= 1
        if retry == 0:
            response = 'Error: Failed to get response.'
            self.logger.info(f'{response} after {self.retry} tries.')
            raise ValueError('The api is down')
        return response.strip()

    def offline_infer(self, query: str) -> str:
        """Perform inference on the local xFinder model.

        Args:
            query (str): The input query.

        Returns:
            str: The extracted answer (xFinder's output).
        """
        prompt = self.PROMPT_TEMPLATE.format(system=self.SYSTEM, input=query)
        res = self.llm.generate(prompt, self.sampling_params)
        res = res[0]
        res = res.outputs[0].text.strip()
        return res
@@ -1,14 +0,0 @@
PROMPT_TEMPLATE = {
    'xFinder-qwen1505':
    """<|System|>:{system}
<|User|>:{input}
<|Bot|>:""",
    'xFinder-llama38it':
    """<|start_header_id|>system<|end_header_id|>

{system}<|eot_id|><|start_header_id|>user<|end_header_id|>

{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

""",
}
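
A quick hedged illustration of how these chat templates were filled in by the `Extractor` above (the system and input strings are abbreviated placeholders):

```python
rendered = PROMPT_TEMPLATE['xFinder-qwen1505'].format(
    system='You are a help assistant tasked with extracting the precise key answer ...',
    input='Question: """..."""\n\nOutput sentences: """..."""\n\n'
          'Answer range: ...\n\nKey extracted answer: ',
)
print(rendered)  # the fully rendered prompt sent to the xFinder server
```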
@@ -1,3 +0,0 @@
from .convert_data import * # noqa
from .data_process import * # noqa
from .PROMPT_TEMPLATE import * # noqa
@@ -1,123 +0,0 @@
# Convert OpenCompass prediction data to XFinder format
import copy
import json
import re

xfinder_template = {
    'math': {
        'model_name':
        '',
        'dataset':
        '',
        'key_answer_type':
        'math',
        'question':
        '',
        'llm_output':
        '',
        'correct_answer':
        '',
        'standard_answer_range':
        'a(n) number / set / vector / matrix / interval / expression / function / equation / inequality'  # noqa
    },
    'alphabet_option': {
        'model_name': '',
        'dataset': '',
        'key_answer_type': 'alphabet_option',
        'question': '',
        'llm_output': '.',
        'correct_answer': '',
        'standard_answer_range': []
    },
    'categorical_label': {
        'model_name': '',
        'dataset': '',
        'key_answer_type': '',
        'question': '',
        'llm_output': '',
        'correct_answer': '',
        'standard_answer_range': []
    },
    'short_text': {
        'model_name': '',
        'dataset': '',
        'key_answer_type': 'short_text',
        'question': '',
        'llm_output': '',
        'correct_answer': '',
        'standard_answer_range': []
    }
}


def parse_options(text: str):
    lines = text.split('\n')
    parsed_options = []
    option_pattern = r'^[A-Z]\)|[A-Z]\.|[A-Z]\)|[A-Z]:|\([A-Z]\)'
    for line in lines:
        line = line.strip()
        match = re.match(option_pattern, line)
        if match:
            option = ''
            # The first alphabetic character on the line is the option letter
            for c in line:
                if c.isalpha():
                    option = c
                    break
            content_start = match.end() + 1
            content = line[content_start:].strip()
            parsed_options.append([option, content])

    return parsed_options


def convert_to_xfinder_format(typ, data, model_name='', dataset_name=''):
    assert typ in xfinder_template.keys(), f'Invalid type {typ}'
    format_data = []
    for item in data:
        template = copy.deepcopy(xfinder_template[typ])
        question = item['origin_prompt'][-1]['prompt']
        llm_output = item['prediction']
        correct_answer = item['reference'] if item['reference'] else item[
            'gold']
        template['correct_answer'] = correct_answer
        template['model_name'] = model_name
        template['dataset'] = dataset_name
        template['question'] = question
        template['llm_output'] = llm_output
        try:
            assert typ in list(xfinder_template.keys())
            if typ == 'alphabet_option':
                options = parse_options(question)
                template['standard_answer_range'] = options
            elif typ == 'short_text':
                template['standard_answer_range'] = item['gold']
            elif typ == 'categorical_label':
                pass
        except Exception as e:
            print(f'Error when parsing question options: {e}, skipping...')
            continue

        format_data.append(template)
    return format_data


if __name__ == '__main__':
    # Test
    example_data = {
        'origin_prompt': [{
            'role':
            'HUMAN',
            'prompt':
            'Alice, Bob, Claire, Dave, and Eve are dancers at a square dance. At the start of a song, they each have a partner: Alice is dancing with Ophelia, Bob is dancing with Jamie, Claire is dancing with Melissa, Dave is dancing with Rodrigo, and Eve is dancing with Patrick.\nThroughout the song, the dancers often trade partners. First, Claire and Bob switch partners. Then, Claire and Eve switch partners. Then, Claire and Bob switch partners. Then, Eve and Dave switch partners. Finally, Claire and Alice switch partners. At the end of the dance, Alice is dancing with\nOptions:\n(A) Ophelia\n(B) Jamie\n(C) Melissa\n(D) Rodrigo\n(E) Patrick'  # noqa
        }],
        'origin_prediction':
        '\n 答案: B) 前者小于后者',
        'prediction':
        'B',
        'reference':
        'A'
    }
    example_data = convert_to_xfinder_format('alphabet_option', [example_data],
                                             'GPT-3', 'OpenAI')
    print(json.dumps(example_data, indent=4, ensure_ascii=False))
@@ -1,24 +0,0 @@
import ast


class DataProcessor:

    def __init__(self):
        pass

    def read_data(self, data):
        for item in data:
            if isinstance(item['standard_answer_range'],
                          str) and item['key_answer_type'] != 'math':
                try:
                    item['standard_answer_range'] = ast.literal_eval(
                        item['standard_answer_range'])
                except Exception as e:
                    print(f'Error: {e}')
                    print('Please check the form of standard_answer_range')
                    exit(0)

            item['standard_answer_range'] = str(item['standard_answer_range'])
            item['key_answer_type'] = str(item['key_answer_type'])

        return data
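
A small hedged example of what `read_data` normalizes (the input dict below is made up for illustration):

```python
processor = DataProcessor()
items = processor.read_data([{
    'key_answer_type': 'alphabet_option',
    'standard_answer_range': "[['A', 'work'], ['B', 'school building']]",
}])
# The stringified list is parsed with ast.literal_eval, then stored back as a
# string, so downstream prompt formatting always receives plain strings.
print(items[0]['standard_answer_range'])
```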