mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Sync] Initial support of subjective evaluation (#421)
Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent
0f2c388280
commit
a1ea3c094a
17
opencompass/datasets/lmeval.py
Normal file
17
opencompass/datasets/lmeval.py
Normal file
@ -0,0 +1,17 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
|
||||
from opencompass.datasets import BaseDataset
|
||||
|
||||
|
||||
class LMEvalDataset(BaseDataset):
|
||||
"""A dataset wrapper around the evaluator inputs, designed for
|
||||
OpenCompass's internal use."""
|
||||
|
||||
@staticmethod
|
||||
def load(predictions: List, references: Optional[List] = None):
|
||||
content = {'prediction': predictions}
|
||||
if references:
|
||||
content['reference'] = references
|
||||
return DatasetDict(dict(test=Dataset.from_dict(content)))
|
@ -58,7 +58,7 @@ class DatasetReader:
|
||||
def __init__(self,
|
||||
dataset: Union[Dataset, DatasetDict, str],
|
||||
input_columns: Union[List[str], str],
|
||||
output_column: str,
|
||||
output_column: Optional[str],
|
||||
input_template: Optional[PromptTemplate] = None,
|
||||
output_template: Optional[PromptTemplate] = None,
|
||||
train_split: str = 'train',
|
||||
@ -68,7 +68,9 @@ class DatasetReader:
|
||||
self.input_columns = _check_type_list(input_columns, [List, str])
|
||||
if isinstance(self.input_columns, str):
|
||||
self.input_columns = self.input_columns.split()
|
||||
self.output_column = _check_str(output_column)
|
||||
self.output_column = None
|
||||
if output_column:
|
||||
self.output_column = _check_str(output_column)
|
||||
|
||||
train_range = _check_type_list(train_range, [None, int, float, str])
|
||||
test_range = _check_type_list(test_range, [None, int, float, str])
|
||||
|
@ -4,3 +4,4 @@ from .icl_base_evaluator import BaseEvaluator # noqa
|
||||
from .icl_em_evaluator import EMEvaluator # noqa
|
||||
from .icl_hf_evaluator import * # noqa
|
||||
from .icl_toxic_evaluator import ToxicEvaluator # noqa
|
||||
from .lm_evaluator import LMEvaluator # noqa
|
||||
|
94
opencompass/openicl/icl_evaluator/lm_evaluator.py
Normal file
94
opencompass/openicl/icl_evaluator/lm_evaluator.py
Normal file
@ -0,0 +1,94 @@
|
||||
import os.path as osp
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import mmengine
|
||||
from mmengine.config import ConfigDict
|
||||
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.registry import ICL_PROMPT_TEMPLATES
|
||||
from opencompass.utils import build_dataset_from_cfg, build_model_from_cfg
|
||||
from opencompass.utils.logging import get_logger
|
||||
from opencompass.utils.text_postprocessors import first_number_postprocess
|
||||
from opencompass.utils.types import get_type_from_cfg
|
||||
|
||||
|
||||
class LMEvaluator:
|
||||
"""Evaluate output with language model.
|
||||
|
||||
Args:
|
||||
prompt_template (ConfigDict): Prompt template configuration. Used to
|
||||
prompt the language model for scores. User can use two reserved
|
||||
keywords, ``{prediction}`` and ``{reference}``, referring to
|
||||
the prediction and optionally the reference answer.
|
||||
judge_cfg (ConfigDict): The config of language model as a judge.
|
||||
output_path (str): The path to prediction output.
|
||||
dataset_cfg (ConfigDict, optional): The config of the dataset to be
|
||||
evaluated.
|
||||
postprocessor (ConfigDict): The model prediction's postprocessor
|
||||
config.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_template: ConfigDict,
|
||||
judge_cfg: ConfigDict,
|
||||
output_path: str,
|
||||
dataset_cfg: Optional[ConfigDict] = None,
|
||||
postprocessor: ConfigDict = dict(type=first_number_postprocess)
|
||||
) -> None:
|
||||
self.output_path = output_path
|
||||
out_dir, out_name = osp.split(output_path)
|
||||
if not out_dir:
|
||||
out_dir = './'
|
||||
|
||||
self.prompt_tmpl = ICL_PROMPT_TEMPLATES.build(prompt_template)
|
||||
|
||||
max_out_len = judge_cfg.get('max_out_len', None)
|
||||
batch_size = judge_cfg.get('batch_size', None)
|
||||
model = build_model_from_cfg(model_cfg=judge_cfg)
|
||||
self.inferencer = GenInferencer(model,
|
||||
max_out_len=max_out_len,
|
||||
batch_size=batch_size,
|
||||
output_json_filepath=out_dir,
|
||||
output_json_filename=out_name)
|
||||
self.postprocessor = get_type_from_cfg(postprocessor)
|
||||
self.logger = get_logger()
|
||||
self.dataset_cfg = dataset_cfg
|
||||
|
||||
def score(self, predictions, references: Optional[List] = None) -> Dict:
|
||||
if self.dataset_cfg:
|
||||
dataset = build_dataset_from_cfg(self.dataset_cfg)
|
||||
dataset.reader.dataset['test'] = dataset.test.add_column(
|
||||
'prediction', predictions)
|
||||
dataset.reader.input_columns.append('prediction')
|
||||
if references:
|
||||
dataset.reader.input_columns.append('reference')
|
||||
dataset.reader.dataset['test'] = dataset.test.add_column(
|
||||
'reference', references)
|
||||
else:
|
||||
from opencompass.datasets.lmeval import LMEvalDataset
|
||||
input_columns = ['prediction']
|
||||
if references:
|
||||
input_columns.append('reference')
|
||||
dataset = LMEvalDataset(reader_cfg=dict(
|
||||
input_columns=input_columns,
|
||||
output_column=None,
|
||||
train_split='test'),
|
||||
predictions=predictions,
|
||||
references=references)
|
||||
retriever = ZeroRetriever(dataset)
|
||||
self.inferencer.inference(retriever=retriever,
|
||||
prompt_template=self.prompt_tmpl)
|
||||
|
||||
output = mmengine.load(self.output_path)
|
||||
scores = []
|
||||
for k, v in output.items():
|
||||
score = self.postprocessor(v['prediction'])
|
||||
output[k]['score'] = score
|
||||
scores.append(score)
|
||||
try:
|
||||
output['score'] = sum(scores) / len(scores)
|
||||
except Exception:
|
||||
pass
|
||||
return output
|
@ -13,11 +13,16 @@ class BasePartitioner:
|
||||
|
||||
Args:
|
||||
out_dir (str): The output directory of tasks.
|
||||
keep_keys (List[str]): The keys to be kept from the experiment config
|
||||
to the task config.
|
||||
"""
|
||||
|
||||
def __init__(self, out_dir: str):
|
||||
def __init__(self,
|
||||
out_dir: str,
|
||||
keep_keys: List[str] = ['eval.runner.task.judge_cfg']):
|
||||
self.logger = get_logger()
|
||||
self.out_dir = out_dir
|
||||
self.keep_keys = keep_keys
|
||||
|
||||
def __call__(self, cfg: ConfigDict) -> List[Dict]:
|
||||
"""Generate tasks from config. Each task is defined as a
|
||||
@ -45,7 +50,26 @@ class BasePartitioner:
|
||||
datasets = cfg['datasets']
|
||||
work_dir = cfg['work_dir']
|
||||
|
||||
tasks = self.partition(models, datasets, work_dir, self.out_dir)
|
||||
add_cfg = {}
|
||||
for k in self.keep_keys:
|
||||
try:
|
||||
key_chain = k.split('.')
|
||||
ori_ptr = cfg
|
||||
tgt_ptr = add_cfg
|
||||
for key in key_chain[:-1]:
|
||||
ori_ptr = ori_ptr[key]
|
||||
if key not in tgt_ptr:
|
||||
tgt_ptr[key] = {}
|
||||
tgt_ptr = tgt_ptr[key]
|
||||
tgt_ptr[key_chain[-1]] = ori_ptr[key_chain[-1]]
|
||||
except AttributeError:
|
||||
self.logger.warning(f'Key {k} not found in config, ignored.')
|
||||
|
||||
tasks = self.partition(models,
|
||||
datasets,
|
||||
work_dir,
|
||||
self.out_dir,
|
||||
add_cfg=add_cfg)
|
||||
|
||||
self.logger.info(f'Partitioned into {len(tasks)} tasks.')
|
||||
for i, task in enumerate(tasks):
|
||||
@ -54,8 +78,12 @@ class BasePartitioner:
|
||||
return tasks
|
||||
|
||||
@abstractmethod
|
||||
def partition(self, models: List[ConfigDict], datasets: List[ConfigDict],
|
||||
work_dir: str, out_dir: str) -> List[Dict]:
|
||||
def partition(self,
|
||||
models: List[ConfigDict],
|
||||
datasets: List[ConfigDict],
|
||||
work_dir: str,
|
||||
out_dir: str,
|
||||
add_cfg: Dict = {}) -> List[Dict]:
|
||||
"""Partition model-dataset pairs into tasks. Each task is defined as a
|
||||
dict and will run independently as a unit. Its structure is as
|
||||
follows:
|
||||
@ -67,6 +95,7 @@ class BasePartitioner:
|
||||
'datasets': [[]], # a nested list of dataset configs, each
|
||||
list corresponds to a model
|
||||
'work_dir': '', # the work dir
|
||||
**add_cfg # other keys to be added in the config
|
||||
}
|
||||
|
||||
Args:
|
||||
@ -76,6 +105,8 @@ class BasePartitioner:
|
||||
out_dir (str): The full output path for the task, intended for
|
||||
Partitioners to check whether the task is finished via the
|
||||
existency of result file in this directory.
|
||||
add_cfg (dict): Other common keys to be added in the task config,
|
||||
used to share the same config among tasks. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
List[Dict]: A list of tasks.
|
||||
|
@ -15,11 +15,17 @@ class NaivePartitioner(BasePartitioner):
|
||||
model-dataset pair.
|
||||
|
||||
Args:
|
||||
config (ConfigDict): The full config dict.
|
||||
out_dir (str): The output directory of tasks.
|
||||
keep_keys (List[str]): The keys to be kept from the experiment config
|
||||
to the task config.
|
||||
"""
|
||||
|
||||
def partition(self, models: List[ConfigDict], datasets: List[ConfigDict],
|
||||
work_dir: str, out_dir: str) -> List[Dict]:
|
||||
def partition(self,
|
||||
models: List[ConfigDict],
|
||||
datasets: List[ConfigDict],
|
||||
work_dir: str,
|
||||
out_dir: str,
|
||||
add_cfg: Dict = {}) -> List[Dict]:
|
||||
"""Partition model-dataset pairs into tasks. Each task is defined as a
|
||||
dict and will run independently as a unit. Its structure is as
|
||||
follows:
|
||||
@ -54,7 +60,8 @@ class NaivePartitioner(BasePartitioner):
|
||||
task = Config({
|
||||
'models': [model],
|
||||
'datasets': [[dataset]],
|
||||
'work_dir': work_dir
|
||||
'work_dir': work_dir,
|
||||
**add_cfg
|
||||
})
|
||||
tasks.append(task)
|
||||
return tasks
|
||||
|
@ -2,7 +2,7 @@ import copy
|
||||
import math
|
||||
import os.path as osp
|
||||
from fnmatch import fnmatch
|
||||
from typing import List, Tuple, Union
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import mmengine
|
||||
from mmengine.config import Config, ConfigDict
|
||||
@ -25,20 +25,27 @@ class SizePartitioner(BasePartitioner):
|
||||
gen_task_coef (int): The dataset cost measurement coefficient for
|
||||
generation tasks.
|
||||
dataset_size_path (str): The path to the dataset size cache file.
|
||||
keep_keys (list[str]): The keys to be kept from the experiment config
|
||||
to the task config.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
out_dir: str,
|
||||
max_task_size: int = 40000,
|
||||
gen_task_coef: int = 20,
|
||||
dataset_size_path: str = '.cache/dataset_size.json'):
|
||||
super().__init__(out_dir)
|
||||
dataset_size_path: str = '.cache/dataset_size.json',
|
||||
keep_keys: List[str] = ['eval.runner.task.judge_cfg']):
|
||||
super().__init__(out_dir=out_dir, keep_keys=keep_keys)
|
||||
self.max_task_size = max_task_size
|
||||
self.gen_task_coef = gen_task_coef
|
||||
self.dataset_size_path = dataset_size_path
|
||||
|
||||
def partition(self, models: List[ConfigDict], datasets: List[ConfigDict],
|
||||
work_dir: str, out_dir: str) -> List[ConfigDict]:
|
||||
def partition(self,
|
||||
models: List[ConfigDict],
|
||||
datasets: List[ConfigDict],
|
||||
work_dir: str,
|
||||
out_dir: str,
|
||||
add_cfg: Dict = {}) -> List[ConfigDict]:
|
||||
"""Partition model-dataset pairs into tasks. Each task is defined as a
|
||||
dict and will run independently as a unit. Its structure is as
|
||||
follows:
|
||||
@ -50,6 +57,7 @@ class SizePartitioner(BasePartitioner):
|
||||
'datasets': [[]], # a nested list of dataset configs, each
|
||||
list corresponds to a model
|
||||
'work_dir': '', # the work dir
|
||||
**add_cfg # other keys to be kept in the config
|
||||
}
|
||||
|
||||
Args:
|
||||
@ -59,6 +67,8 @@ class SizePartitioner(BasePartitioner):
|
||||
out_dir (str): The full output path for the task, intended for
|
||||
Partitioners to check whether the task is finished via the
|
||||
existency of result file in this directory.
|
||||
add_cfg (dict): Other common keys to be added in the task config,
|
||||
used to share the same config among tasks. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
List[ConfigDict]: A list of tasks.
|
||||
@ -72,7 +82,8 @@ class SizePartitioner(BasePartitioner):
|
||||
task = Config({
|
||||
'models': [model],
|
||||
'datasets': [[]],
|
||||
'work_dir': work_dir
|
||||
'work_dir': work_dir,
|
||||
**add_cfg
|
||||
})
|
||||
num_data = 0
|
||||
for dataset in datasets:
|
||||
@ -91,7 +102,8 @@ class SizePartitioner(BasePartitioner):
|
||||
Config({
|
||||
'models': [model],
|
||||
'datasets': [[dataset_split]],
|
||||
'work_dir': work_dir
|
||||
'work_dir': work_dir,
|
||||
**add_cfg
|
||||
}))
|
||||
else:
|
||||
if num_data + dataset_size > self.max_task_size:
|
||||
@ -99,7 +111,8 @@ class SizePartitioner(BasePartitioner):
|
||||
task = Config({
|
||||
'models': [model],
|
||||
'datasets': [[]],
|
||||
'work_dir': work_dir
|
||||
'work_dir': work_dir,
|
||||
**add_cfg
|
||||
})
|
||||
num_data = 0
|
||||
task['datasets'][0].append(dataset)
|
||||
|
@ -63,11 +63,11 @@ class DLCRunner(BaseRunner):
|
||||
status = [self._launch(task, random_sleep=False) for task in tasks]
|
||||
return status
|
||||
|
||||
def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True):
|
||||
def _launch(self, cfg: ConfigDict, random_sleep: bool = True):
|
||||
"""Launch a single task.
|
||||
|
||||
Args:
|
||||
task_cfg (ConfigDict): Task config.
|
||||
cfg (ConfigDict): Task config.
|
||||
random_sleep (bool): Whether to sleep for a random time before
|
||||
running the command. This avoids cluster error when launching
|
||||
multiple tasks at the same time. Default: True.
|
||||
@ -76,10 +76,7 @@ class DLCRunner(BaseRunner):
|
||||
tuple[str, int]: Task name and exit code.
|
||||
"""
|
||||
|
||||
task_type = self.task_cfg.type
|
||||
if isinstance(self.task_cfg.type, str):
|
||||
task_type = TASKS.get(task_type)
|
||||
task = task_type(task_cfg)
|
||||
task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type']))
|
||||
num_gpus = task.num_gpus
|
||||
task_name = task.name
|
||||
|
||||
@ -87,7 +84,7 @@ class DLCRunner(BaseRunner):
|
||||
mmengine.mkdir_or_exist('tmp/')
|
||||
param_file = f'tmp/{os.getpid()}_params.py'
|
||||
try:
|
||||
task_cfg.dump(param_file)
|
||||
cfg.dump(param_file)
|
||||
|
||||
# Build up DLC command
|
||||
pwd = os.getcwd()
|
||||
|
@ -57,7 +57,7 @@ class LocalRunner(BaseRunner):
|
||||
status = []
|
||||
if self.debug:
|
||||
for task in tasks:
|
||||
task = TASKS.build(dict(type=self.task_cfg.type, cfg=task))
|
||||
task = TASKS.build(dict(cfg=task, type=self.task_cfg['type']))
|
||||
task_name = task.name
|
||||
# get cmd
|
||||
mmengine.mkdir_or_exist('tmp/')
|
||||
@ -94,7 +94,7 @@ class LocalRunner(BaseRunner):
|
||||
lock = Lock()
|
||||
|
||||
def submit(task, index):
|
||||
task = TASKS.build(dict(type=self.task_cfg.type, cfg=task))
|
||||
task = TASKS.build(dict(cfg=task, type=self.task_cfg['type']))
|
||||
num_gpus = task.num_gpus
|
||||
assert len(gpus) >= num_gpus
|
||||
|
||||
|
@ -69,11 +69,11 @@ class SlurmRunner(BaseRunner):
|
||||
status = [self._launch(task, random_sleep=False) for task in tasks]
|
||||
return status
|
||||
|
||||
def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True):
|
||||
def _launch(self, cfg: ConfigDict, random_sleep: bool = True):
|
||||
"""Launch a single task.
|
||||
|
||||
Args:
|
||||
task_cfg (ConfigDict): Task config.
|
||||
cfg (ConfigDict): Task config.
|
||||
random_sleep (bool): Whether to sleep for a random time before
|
||||
running the command. This avoids cluster error when launching
|
||||
multiple tasks at the same time. Default: True.
|
||||
@ -81,10 +81,7 @@ class SlurmRunner(BaseRunner):
|
||||
Returns:
|
||||
tuple[str, int]: Task name and exit code.
|
||||
"""
|
||||
task_type = self.task_cfg.type
|
||||
if isinstance(self.task_cfg.type, str):
|
||||
task_type = TASKS.get(task_type)
|
||||
task = task_type(task_cfg)
|
||||
task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type']))
|
||||
num_gpus = task.num_gpus
|
||||
task_name = task.name
|
||||
|
||||
@ -92,7 +89,7 @@ class SlurmRunner(BaseRunner):
|
||||
mmengine.mkdir_or_exist('tmp/')
|
||||
param_file = f'tmp/{os.getpid()}_params.py'
|
||||
try:
|
||||
task_cfg.dump(param_file)
|
||||
cfg.dump(param_file)
|
||||
|
||||
# Build up slurm command
|
||||
tmpl = 'srun'
|
||||
|
@ -1,6 +1,8 @@
|
||||
import argparse
|
||||
import copy
|
||||
import fnmatch
|
||||
import os.path as osp
|
||||
import random
|
||||
import time
|
||||
from collections import Counter
|
||||
from inspect import signature
|
||||
@ -10,12 +12,14 @@ import mmengine
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
|
||||
from opencompass.openicl.icl_evaluator.lm_evaluator import LMEvaluator
|
||||
from opencompass.registry import (ICL_EVALUATORS, MODELS, TASKS,
|
||||
TEXT_POSTPROCESSORS)
|
||||
from opencompass.tasks.base import BaseTask
|
||||
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
|
||||
get_infer_output_path, get_logger,
|
||||
task_abbr_from_cfg)
|
||||
from opencompass.utils.types import get_type_from_cfg
|
||||
|
||||
|
||||
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
|
||||
@ -24,6 +28,9 @@ class OpenICLEvalTask(BaseTask):
|
||||
|
||||
This task is used to evaluate the metric between predictions and
|
||||
references.
|
||||
|
||||
Args:
|
||||
cfg (ConfigDict): The configuration of the entire evaluation task.
|
||||
"""
|
||||
|
||||
name_prefix = 'OpenICLEval'
|
||||
@ -32,12 +39,30 @@ class OpenICLEvalTask(BaseTask):
|
||||
|
||||
def __init__(self, cfg: ConfigDict):
|
||||
super().__init__(cfg)
|
||||
self.num_gpus = 0
|
||||
self.logger = get_logger()
|
||||
judge_cfg = cfg.eval.runner.task.get('judge_cfg', {})
|
||||
run_cfg = judge_cfg.get('run_cfg', {})
|
||||
self.num_gpus = run_cfg.get('num_gpus', 0)
|
||||
self.num_procs = run_cfg.get('num_procs', 1)
|
||||
self.judge_cfg = copy.deepcopy(judge_cfg)
|
||||
|
||||
def get_command(self, cfg_path, template):
|
||||
"""Get the command template for the task.
|
||||
|
||||
Args:
|
||||
cfg_path (str): The path to the config file of the task.
|
||||
template (str): The template which have '{task_cmd}' to format
|
||||
the command.
|
||||
"""
|
||||
script_path = __file__
|
||||
command = f'python3 {script_path} {cfg_path}'
|
||||
if self.num_gpus > 0:
|
||||
port = random.randint(12000, 32000)
|
||||
command = (f'torchrun --master_port={port} '
|
||||
f'--nproc_per_node {self.num_procs} '
|
||||
f'{script_path} {cfg_path}')
|
||||
else:
|
||||
command = f'python {script_path} {cfg_path}'
|
||||
|
||||
return template.format(task_cmd=command)
|
||||
|
||||
def run(self):
|
||||
@ -94,6 +119,10 @@ class OpenICLEvalTask(BaseTask):
|
||||
# Get sc_size if use Self-Consistency
|
||||
sc_size = self.eval_cfg.get('sc_size')
|
||||
|
||||
# Get out_path
|
||||
out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
|
||||
osp.join(self.work_dir, 'results'))
|
||||
|
||||
if not osp.exists(osp.realpath(filename)) and not osp.exists(
|
||||
osp.realpath(partial_filename)):
|
||||
result = {'error': 'No predictions found.'}
|
||||
@ -160,9 +189,18 @@ class OpenICLEvalTask(BaseTask):
|
||||
Counter(s).most_common(1)[0][0] for s in pred_strs
|
||||
]
|
||||
|
||||
if get_type_from_cfg(self.eval_cfg['evaluator']) == LMEvaluator:
|
||||
if not self.judge_cfg:
|
||||
raise ValueError('Using LMEvaluator in dataset, but '
|
||||
'missing "eval.runner.task.judge_cfg" '
|
||||
'as the judge configuration.')
|
||||
self.eval_cfg['evaluator']['judge_cfg'] = self.judge_cfg
|
||||
self.eval_cfg['evaluator']['dataset_cfg'] = self.dataset_cfg
|
||||
self.eval_cfg['evaluator']['output_path'] = out_path
|
||||
icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
|
||||
preds['predictions'] = pred_strs
|
||||
preds['references'] = test_set[self.output_column]
|
||||
preds['references'] = (test_set[self.output_column]
|
||||
if self.output_column else None)
|
||||
preds = {
|
||||
k: preds[k]
|
||||
for k in signature(icl_evaluator.score).parameters
|
||||
@ -177,10 +215,12 @@ class OpenICLEvalTask(BaseTask):
|
||||
self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}: {result}')
|
||||
|
||||
# Save result
|
||||
out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
|
||||
osp.join(self.work_dir, 'results'))
|
||||
mkdir_or_exist(osp.split(out_path)[0])
|
||||
mmengine.dump(result, out_path)
|
||||
mmengine.dump(result,
|
||||
open(out_path, 'w', encoding='utf-8'),
|
||||
file_format='json',
|
||||
ensure_ascii=False,
|
||||
indent=4)
|
||||
|
||||
def _extract_role_pred(self, s: str, begin_str: Optional[str],
|
||||
end_str: Optional[str]) -> str:
|
||||
|
@ -5,7 +5,7 @@ from mmengine.config import ConfigDict
|
||||
from opencompass.registry import LOAD_DATASET, MODELS
|
||||
|
||||
|
||||
def build_dataset_from_cfg(dataset_cfg: ConfigDict) -> ConfigDict:
|
||||
def build_dataset_from_cfg(dataset_cfg: ConfigDict):
|
||||
dataset_cfg = copy.deepcopy(dataset_cfg)
|
||||
dataset_cfg.pop('infer_cfg', None)
|
||||
dataset_cfg.pop('eval_cfg', None)
|
||||
@ -13,7 +13,7 @@ def build_dataset_from_cfg(dataset_cfg: ConfigDict) -> ConfigDict:
|
||||
return LOAD_DATASET.build(dataset_cfg)
|
||||
|
||||
|
||||
def build_model_from_cfg(model_cfg: ConfigDict) -> ConfigDict:
|
||||
def build_model_from_cfg(model_cfg: ConfigDict):
|
||||
model_cfg = copy.deepcopy(model_cfg)
|
||||
model_cfg.pop('run_cfg', None)
|
||||
model_cfg.pop('max_out_len', None)
|
||||
|
@ -86,3 +86,15 @@ def last_option_postprocess(text: str, options: str) -> str:
|
||||
if match:
|
||||
return match[-1]
|
||||
return ''
|
||||
|
||||
|
||||
def first_number_postprocess(text: str) -> float:
|
||||
"""Return the first number in a string."""
|
||||
# regex pattern to match numbers (both integers and decimals)
|
||||
pattern = r'(-?\d*\.?\d+)'
|
||||
|
||||
# search the string for the pattern
|
||||
match = re.search(pattern, text)
|
||||
|
||||
# if a match is found, return it. Otherwise, return None.
|
||||
return float(match.group(1)) if match else None
|
||||
|
@ -1,6 +1,22 @@
|
||||
from typing import Dict, List, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
from mmengine.config import Config
|
||||
|
||||
from opencompass.registry import TASKS
|
||||
|
||||
|
||||
def get_type_from_cfg(cfg: Union[Config, Dict]) -> Any:
|
||||
"""Get the object type given MMEngine's Config.
|
||||
|
||||
It loads the "type" field and return the corresponding object type.
|
||||
"""
|
||||
type = cfg['type']
|
||||
if isinstance(type, str):
|
||||
# FIXME: This has nothing to do with any specific registry, to be fixed
|
||||
# in MMEngine
|
||||
type = TASKS.get(type)
|
||||
return type
|
||||
|
||||
|
||||
def _check_type_list(obj, typelist: List):
|
||||
|
Loading…
Reference in New Issue
Block a user