[Sync] Initial support of subjective evaluation (#421)

Co-authored-by: Leymore <zfz-960727@163.com>
Authored by Tong Gao on 2023-09-22 15:42:31 +08:00; committed by GitHub
parent 0f2c388280
commit a1ea3c094a
14 changed files with 270 additions and 43 deletions

View File

@@ -0,0 +1,17 @@
from typing import List, Optional
from datasets import Dataset, DatasetDict
from opencompass.datasets import BaseDataset
class LMEvalDataset(BaseDataset):
"""A dataset wrapper around the evaluator inputs, designed for
OpenCompass's internal use."""
@staticmethod
def load(predictions: List, references: Optional[List] = None):
content = {'prediction': predictions}
if references:
content['reference'] = references
return DatasetDict(dict(test=Dataset.from_dict(content)))
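
A minimal usage sketch of the new wrapper (illustrative only; the sample strings are made up):

# Hypothetical example: wrap judge inputs into a test-only DatasetDict.
from opencompass.datasets.lmeval import LMEvalDataset

data = LMEvalDataset.load(
    predictions=['Paris is the capital of France.', '42'],
    references=['Paris', '42'],
)
print(data['test'][0])
# {'prediction': 'Paris is the capital of France.', 'reference': 'Paris'}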

View File

@@ -58,7 +58,7 @@ class DatasetReader:
def __init__(self,
dataset: Union[Dataset, DatasetDict, str],
input_columns: Union[List[str], str],
output_column: str,
output_column: Optional[str],
input_template: Optional[PromptTemplate] = None,
output_template: Optional[PromptTemplate] = None,
train_split: str = 'train',
@@ -68,7 +68,9 @@ class DatasetReader:
self.input_columns = _check_type_list(input_columns, [List, str])
if isinstance(self.input_columns, str):
self.input_columns = self.input_columns.split()
self.output_column = _check_str(output_column)
self.output_column = None
if output_column:
self.output_column = _check_str(output_column)
train_range = _check_type_list(train_range, [None, int, float, str])
test_range = _check_type_list(test_range, [None, int, float, str])
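
Downstream, the LM-judge flow builds datasets with no ground-truth column at all; a hypothetical sketch of a reader config that this change now accepts:

# Hypothetical reader config: only model inputs, no ground-truth column.
reader_cfg = dict(
    input_columns=['prediction', 'reference'],
    output_column=None,  # previously this had to be a string
    train_split='test',
)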

View File

@@ -4,3 +4,4 @@ from .icl_base_evaluator import BaseEvaluator # noqa
from .icl_em_evaluator import EMEvaluator # noqa
from .icl_hf_evaluator import * # noqa
from .icl_toxic_evaluator import ToxicEvaluator # noqa
from .lm_evaluator import LMEvaluator # noqa

View File

@@ -0,0 +1,94 @@
import os.path as osp
from typing import Dict, List, Optional
import mmengine
from mmengine.config import ConfigDict
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.registry import ICL_PROMPT_TEMPLATES
from opencompass.utils import build_dataset_from_cfg, build_model_from_cfg
from opencompass.utils.logging import get_logger
from opencompass.utils.text_postprocessors import first_number_postprocess
from opencompass.utils.types import get_type_from_cfg
class LMEvaluator:
"""Evaluate output with language model.
Args:
prompt_template (ConfigDict): Prompt template configuration. Used to
prompt the language model for scores. Users can use two reserved
keywords, ``{prediction}`` and ``{reference}``, referring to
the prediction and optionally the reference answer.
judge_cfg (ConfigDict): The config of the language model serving as the judge.
output_path (str): The path to prediction output.
dataset_cfg (ConfigDict, optional): The config of the dataset to be
evaluated.
postprocessor (ConfigDict): The model prediction's postprocessor
config.
"""
def __init__(
self,
prompt_template: ConfigDict,
judge_cfg: ConfigDict,
output_path: str,
dataset_cfg: Optional[ConfigDict] = None,
postprocessor: ConfigDict = dict(type=first_number_postprocess)
) -> None:
self.output_path = output_path
out_dir, out_name = osp.split(output_path)
if not out_dir:
out_dir = './'
self.prompt_tmpl = ICL_PROMPT_TEMPLATES.build(prompt_template)
max_out_len = judge_cfg.get('max_out_len', None)
batch_size = judge_cfg.get('batch_size', None)
model = build_model_from_cfg(model_cfg=judge_cfg)
self.inferencer = GenInferencer(model,
max_out_len=max_out_len,
batch_size=batch_size,
output_json_filepath=out_dir,
output_json_filename=out_name)
self.postprocessor = get_type_from_cfg(postprocessor)
self.logger = get_logger()
self.dataset_cfg = dataset_cfg
def score(self, predictions, references: Optional[List] = None) -> Dict:
if self.dataset_cfg:
dataset = build_dataset_from_cfg(self.dataset_cfg)
dataset.reader.dataset['test'] = dataset.test.add_column(
'prediction', predictions)
dataset.reader.input_columns.append('prediction')
if references:
dataset.reader.input_columns.append('reference')
dataset.reader.dataset['test'] = dataset.test.add_column(
'reference', references)
else:
from opencompass.datasets.lmeval import LMEvalDataset
input_columns = ['prediction']
if references:
input_columns.append('reference')
dataset = LMEvalDataset(reader_cfg=dict(
input_columns=input_columns,
output_column=None,
train_split='test'),
predictions=predictions,
references=references)
retriever = ZeroRetriever(dataset)
self.inferencer.inference(retriever=retriever,
prompt_template=self.prompt_tmpl)
output = mmengine.load(self.output_path)
scores = []
for k, v in output.items():
score = self.postprocessor(v['prediction'])
output[k]['score'] = score
scores.append(score)
try:
output['score'] = sum(scores) / len(scores)
except Exception:
pass
return output
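
A hedged sketch of a per-dataset eval config that plugs LMEvaluator in as a subjective judge. The prompt wording is invented for illustration and assumes the standard PromptTemplate registered under that name; judge_cfg, dataset_cfg and output_path are injected at runtime by OpenICLEvalTask (see below), so they are not written by hand:

from opencompass.openicl.icl_evaluator import LMEvaluator

# Hypothetical eval config; the prompt text is a placeholder.
eval_cfg = dict(
    evaluator=dict(
        type=LMEvaluator,
        prompt_template=dict(
            type='PromptTemplate',
            template='Rate the following answer from 1 to 10.\n'
                     'Answer: {prediction}\n'
                     'Reference: {reference}\n'
                     'Rating:',
        ),
        # judge_cfg, dataset_cfg and output_path are filled in by
        # OpenICLEvalTask before the evaluator is built.
    ),
)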

View File

@@ -13,11 +13,16 @@ class BasePartitioner:
Args:
out_dir (str): The output directory of tasks.
keep_keys (List[str]): The keys to be carried over from the experiment
config to the task config.
"""
def __init__(self, out_dir: str):
def __init__(self,
out_dir: str,
keep_keys: List[str] = ['eval.runner.task.judge_cfg']):
self.logger = get_logger()
self.out_dir = out_dir
self.keep_keys = keep_keys
def __call__(self, cfg: ConfigDict) -> List[Dict]:
"""Generate tasks from config. Each task is defined as a
@@ -45,7 +50,26 @@ class BasePartitioner:
datasets = cfg['datasets']
work_dir = cfg['work_dir']
tasks = self.partition(models, datasets, work_dir, self.out_dir)
add_cfg = {}
for k in self.keep_keys:
try:
key_chain = k.split('.')
ori_ptr = cfg
tgt_ptr = add_cfg
for key in key_chain[:-1]:
ori_ptr = ori_ptr[key]
if key not in tgt_ptr:
tgt_ptr[key] = {}
tgt_ptr = tgt_ptr[key]
tgt_ptr[key_chain[-1]] = ori_ptr[key_chain[-1]]
except (KeyError, AttributeError):
self.logger.warning(f'Key {k} not found in config, ignored.')
tasks = self.partition(models,
datasets,
work_dir,
self.out_dir,
add_cfg=add_cfg)
self.logger.info(f'Partitioned into {len(tasks)} tasks.')
for i, task in enumerate(tasks):
@@ -54,8 +78,12 @@ class BasePartitioner:
return tasks
@abstractmethod
def partition(self, models: List[ConfigDict], datasets: List[ConfigDict],
work_dir: str, out_dir: str) -> List[Dict]:
def partition(self,
models: List[ConfigDict],
datasets: List[ConfigDict],
work_dir: str,
out_dir: str,
add_cfg: Dict = {}) -> List[Dict]:
"""Partition model-dataset pairs into tasks. Each task is defined as a
dict and will run independently as a unit. Its structure is as
follows:
@@ -67,6 +95,7 @@ class BasePartitioner:
'datasets': [[]], # a nested list of dataset configs, each
list corresponds to a model
'work_dir': '', # the work dir
**add_cfg # other keys to be added in the config
}
Args:
@@ -76,6 +105,8 @@ class BasePartitioner:
out_dir (str): The full output path for the task, intended for
Partitioners to check whether the task is finished via the
existence of the result file in this directory.
add_cfg (dict): Other common keys to be added to the task config,
used to share the same config among tasks. Defaults to {}.
Returns:
List[Dict]: A list of tasks.
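
In effect, each dotted entry in keep_keys is mirrored from the experiment config into every task config. A simplified standalone sketch of that copy (plain dicts, hypothetical values):

# Simplified sketch of the keep_keys copy performed by BasePartitioner.
def copy_dotted_key(cfg: dict, dotted_key: str, add_cfg: dict) -> None:
    keys = dotted_key.split('.')
    src, dst = cfg, add_cfg
    for key in keys[:-1]:
        src = src[key]
        dst = dst.setdefault(key, {})
    dst[keys[-1]] = src[keys[-1]]

cfg = {'eval': {'runner': {'task': {'judge_cfg': {'path': 'judge-model'}}}}}
add_cfg = {}
copy_dotted_key(cfg, 'eval.runner.task.judge_cfg', add_cfg)
# add_cfg now mirrors the judge_cfg subtree and is merged into every task.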

View File

@@ -15,11 +15,17 @@ class NaivePartitioner(BasePartitioner):
model-dataset pair.
Args:
config (ConfigDict): The full config dict.
out_dir (str): The output directory of tasks.
keep_keys (List[str]): The keys to be carried over from the experiment
config to the task config.
"""
def partition(self, models: List[ConfigDict], datasets: List[ConfigDict],
work_dir: str, out_dir: str) -> List[Dict]:
def partition(self,
models: List[ConfigDict],
datasets: List[ConfigDict],
work_dir: str,
out_dir: str,
add_cfg: Dict = {}) -> List[Dict]:
"""Partition model-dataset pairs into tasks. Each task is defined as a
dict and will run independently as a unit. Its structure is as
follows:
@@ -54,7 +60,8 @@ class NaivePartitioner(BasePartitioner):
task = Config({
'models': [model],
'datasets': [[dataset]],
'work_dir': work_dir
'work_dir': work_dir,
**add_cfg
})
tasks.append(task)
return tasks

View File

@@ -2,7 +2,7 @@ import copy
import math
import os.path as osp
from fnmatch import fnmatch
from typing import List, Tuple, Union
from typing import Dict, List, Tuple, Union
import mmengine
from mmengine.config import Config, ConfigDict
@@ -25,20 +25,27 @@ class SizePartitioner(BasePartitioner):
gen_task_coef (int): The dataset cost measurement coefficient for
generation tasks.
dataset_size_path (str): The path to the dataset size cache file.
keep_keys (List[str]): The keys to be carried over from the experiment
config to the task config.
"""
def __init__(self,
out_dir: str,
max_task_size: int = 40000,
gen_task_coef: int = 20,
dataset_size_path: str = '.cache/dataset_size.json'):
super().__init__(out_dir)
dataset_size_path: str = '.cache/dataset_size.json',
keep_keys: List[str] = ['eval.runner.task.judge_cfg']):
super().__init__(out_dir=out_dir, keep_keys=keep_keys)
self.max_task_size = max_task_size
self.gen_task_coef = gen_task_coef
self.dataset_size_path = dataset_size_path
def partition(self, models: List[ConfigDict], datasets: List[ConfigDict],
work_dir: str, out_dir: str) -> List[ConfigDict]:
def partition(self,
models: List[ConfigDict],
datasets: List[ConfigDict],
work_dir: str,
out_dir: str,
add_cfg: Dict = {}) -> List[ConfigDict]:
"""Partition model-dataset pairs into tasks. Each task is defined as a
dict and will run independently as a unit. Its structure is as
follows:
@@ -50,6 +57,7 @@ class SizePartitioner(BasePartitioner):
'datasets': [[]], # a nested list of dataset configs, each
list corresponds to a model
'work_dir': '', # the work dir
**add_cfg # other keys to be kept in the config
}
Args:
@@ -59,6 +67,8 @@ class SizePartitioner(BasePartitioner):
out_dir (str): The full output path for the task, intended for
Partitioners to check whether the task is finished via the
existence of the result file in this directory.
add_cfg (dict): Other common keys to be added to the task config,
used to share the same config among tasks. Defaults to {}.
Returns:
List[ConfigDict]: A list of tasks.
@@ -72,7 +82,8 @@ class SizePartitioner(BasePartitioner):
task = Config({
'models': [model],
'datasets': [[]],
'work_dir': work_dir
'work_dir': work_dir,
**add_cfg
})
num_data = 0
for dataset in datasets:
@@ -91,7 +102,8 @@ class SizePartitioner(BasePartitioner):
Config({
'models': [model],
'datasets': [[dataset_split]],
'work_dir': work_dir
'work_dir': work_dir,
**add_cfg
}))
else:
if num_data + dataset_size > self.max_task_size:
@@ -99,7 +111,8 @@ class SizePartitioner(BasePartitioner):
task = Config({
'models': [model],
'datasets': [[]],
'work_dir': work_dir
'work_dir': work_dir,
**add_cfg
})
num_data = 0
task['datasets'][0].append(dataset)
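
The surrounding loop packs datasets per model greedily: a dataset larger than max_task_size is split into tasks of its own, and smaller ones are accumulated until the budget would be exceeded. A rough sketch of the packing rule alone (hypothetical names and sizes; splitting and the gen_task_coef weighting are ignored):

# Rough sketch of the greedy packing rule, with made-up dataset sizes.
def pack(dataset_sizes, max_task_size=40000):
    tasks, current, num_data = [], [], 0
    for name, size in dataset_sizes:
        if num_data + size > max_task_size and current:
            tasks.append(current)
            current, num_data = [], 0
        current.append(name)
        num_data += size
    if current:
        tasks.append(current)
    return tasks

print(pack([('dataset_a', 30000), ('dataset_b', 15000), ('dataset_c', 5000)]))
# [['dataset_a'], ['dataset_b', 'dataset_c']]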

View File

@@ -63,11 +63,11 @@ class DLCRunner(BaseRunner):
status = [self._launch(task, random_sleep=False) for task in tasks]
return status
def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True):
def _launch(self, cfg: ConfigDict, random_sleep: bool = True):
"""Launch a single task.
Args:
task_cfg (ConfigDict): Task config.
cfg (ConfigDict): Task config.
random_sleep (bool): Whether to sleep for a random time before
running the command. This avoids cluster error when launching
multiple tasks at the same time. Default: True.
@@ -76,10 +76,7 @@ class DLCRunner(BaseRunner):
tuple[str, int]: Task name and exit code.
"""
task_type = self.task_cfg.type
if isinstance(self.task_cfg.type, str):
task_type = TASKS.get(task_type)
task = task_type(task_cfg)
task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type']))
num_gpus = task.num_gpus
task_name = task.name
@@ -87,7 +84,7 @@ class DLCRunner(BaseRunner):
mmengine.mkdir_or_exist('tmp/')
param_file = f'tmp/{os.getpid()}_params.py'
try:
task_cfg.dump(param_file)
cfg.dump(param_file)
# Build up DLC command
pwd = os.getcwd()

View File

@@ -57,7 +57,7 @@ class LocalRunner(BaseRunner):
status = []
if self.debug:
for task in tasks:
task = TASKS.build(dict(type=self.task_cfg.type, cfg=task))
task = TASKS.build(dict(cfg=task, type=self.task_cfg['type']))
task_name = task.name
# get cmd
mmengine.mkdir_or_exist('tmp/')
@@ -94,7 +94,7 @@ class LocalRunner(BaseRunner):
lock = Lock()
def submit(task, index):
task = TASKS.build(dict(type=self.task_cfg.type, cfg=task))
task = TASKS.build(dict(cfg=task, type=self.task_cfg['type']))
num_gpus = task.num_gpus
assert len(gpus) >= num_gpus

View File

@@ -69,11 +69,11 @@ class SlurmRunner(BaseRunner):
status = [self._launch(task, random_sleep=False) for task in tasks]
return status
def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True):
def _launch(self, cfg: ConfigDict, random_sleep: bool = True):
"""Launch a single task.
Args:
task_cfg (ConfigDict): Task config.
cfg (ConfigDict): Task config.
random_sleep (bool): Whether to sleep for a random time before
running the command. This avoids cluster error when launching
multiple tasks at the same time. Default: True.
@@ -81,10 +81,7 @@ class SlurmRunner(BaseRunner):
Returns:
tuple[str, int]: Task name and exit code.
"""
task_type = self.task_cfg.type
if isinstance(self.task_cfg.type, str):
task_type = TASKS.get(task_type)
task = task_type(task_cfg)
task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type']))
num_gpus = task.num_gpus
task_name = task.name
@@ -92,7 +89,7 @@ class SlurmRunner(BaseRunner):
mmengine.mkdir_or_exist('tmp/')
param_file = f'tmp/{os.getpid()}_params.py'
try:
task_cfg.dump(param_file)
cfg.dump(param_file)
# Build up slurm command
tmpl = 'srun'

View File

@@ -1,6 +1,8 @@
import argparse
import copy
import fnmatch
import os.path as osp
import random
import time
from collections import Counter
from inspect import signature
@@ -10,12 +12,14 @@ import mmengine
from mmengine.config import Config, ConfigDict
from mmengine.utils import mkdir_or_exist
from opencompass.openicl.icl_evaluator.lm_evaluator import LMEvaluator
from opencompass.registry import (ICL_EVALUATORS, MODELS, TASKS,
TEXT_POSTPROCESSORS)
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
get_infer_output_path, get_logger,
task_abbr_from_cfg)
from opencompass.utils.types import get_type_from_cfg
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
@@ -24,6 +28,9 @@ class OpenICLEvalTask(BaseTask):
This task is used to evaluate the metric between predictions and
references.
Args:
cfg (ConfigDict): The configuration of the entire evaluation task.
"""
name_prefix = 'OpenICLEval'
@@ -32,12 +39,30 @@ class OpenICLEvalTask(BaseTask):
def __init__(self, cfg: ConfigDict):
super().__init__(cfg)
self.num_gpus = 0
self.logger = get_logger()
judge_cfg = cfg.eval.runner.task.get('judge_cfg', {})
run_cfg = judge_cfg.get('run_cfg', {})
self.num_gpus = run_cfg.get('num_gpus', 0)
self.num_procs = run_cfg.get('num_procs', 1)
self.judge_cfg = copy.deepcopy(judge_cfg)
def get_command(self, cfg_path, template):
"""Get the command template for the task.
Args:
cfg_path (str): The path to the config file of the task.
template (str): The template which has a '{task_cmd}' placeholder
to format the command.
"""
script_path = __file__
command = f'python3 {script_path} {cfg_path}'
if self.num_gpus > 0:
port = random.randint(12000, 32000)
command = (f'torchrun --master_port={port} '
f'--nproc_per_node {self.num_procs} '
f'{script_path} {cfg_path}')
else:
command = f'python {script_path} {cfg_path}'
return template.format(task_cmd=command)
def run(self):
@@ -94,6 +119,10 @@ class OpenICLEvalTask(BaseTask):
# Get sc_size if use Self-Consistency
sc_size = self.eval_cfg.get('sc_size')
# Get out_path
out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'results'))
if not osp.exists(osp.realpath(filename)) and not osp.exists(
osp.realpath(partial_filename)):
result = {'error': 'No predictions found.'}
@@ -160,9 +189,18 @@ class OpenICLEvalTask(BaseTask):
Counter(s).most_common(1)[0][0] for s in pred_strs
]
if get_type_from_cfg(self.eval_cfg['evaluator']) == LMEvaluator:
if not self.judge_cfg:
raise ValueError('Using LMEvaluator in dataset, but '
'missing "eval.runner.task.judge_cfg" '
'as the judge configuration.')
self.eval_cfg['evaluator']['judge_cfg'] = self.judge_cfg
self.eval_cfg['evaluator']['dataset_cfg'] = self.dataset_cfg
self.eval_cfg['evaluator']['output_path'] = out_path
icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
preds['predictions'] = pred_strs
preds['references'] = test_set[self.output_column]
preds['references'] = (test_set[self.output_column]
if self.output_column else None)
preds = {
k: preds[k]
for k in signature(icl_evaluator.score).parameters
@@ -177,10 +215,12 @@ class OpenICLEvalTask(BaseTask):
self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}: {result}')
# Save result
out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'results'))
mkdir_or_exist(osp.split(out_path)[0])
mmengine.dump(result, out_path)
mmengine.dump(result,
open(out_path, 'w', encoding='utf-8'),
file_format='json',
ensure_ascii=False,
indent=4)
def _extract_role_pred(self, s: str, begin_str: Optional[str],
end_str: Optional[str]) -> str:
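
Putting the pieces together, a hedged sketch of the experiment-level config these changes expect: the judge model is declared once under eval.runner.task.judge_cfg, carried into every task by the partitioner's keep_keys, and consumed by OpenICLEvalTask and LMEvaluator. Model type, path and resources are placeholders, and the registry type strings are written out for brevity (configs usually import and reference the classes directly):

# Hypothetical experiment config snippet for LM-as-judge evaluation.
eval = dict(
    partitioner=dict(type='NaivePartitioner'),
    runner=dict(
        type='LocalRunner',
        task=dict(
            type='OpenICLEvalTask',
            judge_cfg=dict(
                type='HuggingFaceCausalLM',        # placeholder judge model
                path='some-org/some-judge-model',  # placeholder path
                max_out_len=512,
                batch_size=8,
                run_cfg=dict(num_gpus=1, num_procs=1),
            ),
        ),
    ),
)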

View File

@@ -5,7 +5,7 @@ from mmengine.config import ConfigDict
from opencompass.registry import LOAD_DATASET, MODELS
def build_dataset_from_cfg(dataset_cfg: ConfigDict) -> ConfigDict:
def build_dataset_from_cfg(dataset_cfg: ConfigDict):
dataset_cfg = copy.deepcopy(dataset_cfg)
dataset_cfg.pop('infer_cfg', None)
dataset_cfg.pop('eval_cfg', None)
@@ -13,7 +13,7 @@ def build_dataset_from_cfg(dataset_cfg: ConfigDict) -> ConfigDict:
return LOAD_DATASET.build(dataset_cfg)
def build_model_from_cfg(model_cfg: ConfigDict) -> ConfigDict:
def build_model_from_cfg(model_cfg: ConfigDict):
model_cfg = copy.deepcopy(model_cfg)
model_cfg.pop('run_cfg', None)
model_cfg.pop('max_out_len', None)

View File

@@ -86,3 +86,15 @@ def last_option_postprocess(text: str, options: str) -> str:
if match:
return match[-1]
return ''
def first_number_postprocess(text: str) -> float:
"""Return the first number in a string."""
# regex pattern to match numbers (both integers and decimals)
pattern = r'(-?\d*\.?\d+)'
# search the string for the pattern
match = re.search(pattern, text)
# if a match is found, return it. Otherwise, return None.
return float(match.group(1)) if match else None
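
For example, when the judge replies with a short free-form verdict, this pulls out the leading score (the input strings are made up):

# Illustrative calls; the inputs are made up.
first_number_postprocess('Score: 8. The answer is mostly correct.')  # -> 8.0
first_number_postprocess('I would rate this -0.5 out of 10.')        # -> -0.5
first_number_postprocess('No digits here.')                          # -> None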

View File

@@ -1,6 +1,22 @@
from typing import Dict, List, Union
from typing import Any, Dict, List, Union
from datasets import Dataset, DatasetDict
from mmengine.config import Config
from opencompass.registry import TASKS
def get_type_from_cfg(cfg: Union[Config, Dict]) -> Any:
"""Get the object type given MMEngine's Config.
It loads the "type" field and returns the corresponding object type.
"""
type = cfg['type']
if isinstance(type, str):
# FIXME: This has nothing to do with any specific registry, to be fixed
# in MMEngine
type = TASKS.get(type)
return type
def _check_type_list(obj, typelist: List):