From a1ea3c094a1675962e5b142d73330c498fbe8eac Mon Sep 17 00:00:00 2001 From: Tong Gao Date: Fri, 22 Sep 2023 15:42:31 +0800 Subject: [PATCH] [Sync] Initial support of subjective evaluation (#421) Co-authored-by: Leymore --- opencompass/datasets/lmeval.py | 17 ++++ opencompass/openicl/icl_dataset_reader.py | 6 +- opencompass/openicl/icl_evaluator/__init__.py | 1 + .../openicl/icl_evaluator/lm_evaluator.py | 94 +++++++++++++++++++ opencompass/partitioners/base.py | 39 +++++++- opencompass/partitioners/naive.py | 15 ++- opencompass/partitioners/size.py | 29 ++++-- opencompass/runners/dlc.py | 11 +-- opencompass/runners/local.py | 4 +- opencompass/runners/slurm.py | 11 +-- opencompass/tasks/openicl_eval.py | 52 ++++++++-- opencompass/utils/build.py | 4 +- opencompass/utils/text_postprocessors.py | 12 +++ opencompass/utils/types.py | 18 +++- 14 files changed, 270 insertions(+), 43 deletions(-) create mode 100644 opencompass/datasets/lmeval.py create mode 100644 opencompass/openicl/icl_evaluator/lm_evaluator.py diff --git a/opencompass/datasets/lmeval.py b/opencompass/datasets/lmeval.py new file mode 100644 index 00000000..08fdb938 --- /dev/null +++ b/opencompass/datasets/lmeval.py @@ -0,0 +1,17 @@ +from typing import List, Optional + +from datasets import Dataset, DatasetDict + +from opencompass.datasets import BaseDataset + + +class LMEvalDataset(BaseDataset): + """A dataset wrapper around the evaluator inputs, designed for + OpenCompass's internal use.""" + + @staticmethod + def load(predictions: List, references: Optional[List] = None): + content = {'prediction': predictions} + if references: + content['reference'] = references + return DatasetDict(dict(test=Dataset.from_dict(content))) diff --git a/opencompass/openicl/icl_dataset_reader.py b/opencompass/openicl/icl_dataset_reader.py index 46388dd8..5aa4d1c3 100644 --- a/opencompass/openicl/icl_dataset_reader.py +++ b/opencompass/openicl/icl_dataset_reader.py @@ -58,7 +58,7 @@ class DatasetReader: def __init__(self, dataset: Union[Dataset, DatasetDict, str], input_columns: Union[List[str], str], - output_column: str, + output_column: Optional[str], input_template: Optional[PromptTemplate] = None, output_template: Optional[PromptTemplate] = None, train_split: str = 'train', @@ -68,7 +68,9 @@ class DatasetReader: self.input_columns = _check_type_list(input_columns, [List, str]) if isinstance(self.input_columns, str): self.input_columns = self.input_columns.split() - self.output_column = _check_str(output_column) + self.output_column = None + if output_column: + self.output_column = _check_str(output_column) train_range = _check_type_list(train_range, [None, int, float, str]) test_range = _check_type_list(test_range, [None, int, float, str]) diff --git a/opencompass/openicl/icl_evaluator/__init__.py b/opencompass/openicl/icl_evaluator/__init__.py index e24f8eca..137b192b 100644 --- a/opencompass/openicl/icl_evaluator/__init__.py +++ b/opencompass/openicl/icl_evaluator/__init__.py @@ -4,3 +4,4 @@ from .icl_base_evaluator import BaseEvaluator # noqa from .icl_em_evaluator import EMEvaluator # noqa from .icl_hf_evaluator import * # noqa from .icl_toxic_evaluator import ToxicEvaluator # noqa +from .lm_evaluator import LMEvaluator # noqa diff --git a/opencompass/openicl/icl_evaluator/lm_evaluator.py b/opencompass/openicl/icl_evaluator/lm_evaluator.py new file mode 100644 index 00000000..df586ccd --- /dev/null +++ b/opencompass/openicl/icl_evaluator/lm_evaluator.py @@ -0,0 +1,94 @@ +import os.path as osp +from typing import Dict, List, Optional + 
+import mmengine +from mmengine.config import ConfigDict + +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.registry import ICL_PROMPT_TEMPLATES +from opencompass.utils import build_dataset_from_cfg, build_model_from_cfg +from opencompass.utils.logging import get_logger +from opencompass.utils.text_postprocessors import first_number_postprocess +from opencompass.utils.types import get_type_from_cfg + + +class LMEvaluator: + """Evaluate output with language model. + + Args: + prompt_template (ConfigDict): Prompt template configuration. Used to + prompt the language model for scores. User can use two reserved + keywords, ``{prediction}`` and ``{reference}``, referring to + the prediction and optionally the reference answer. + judge_cfg (ConfigDict): The config of language model as a judge. + output_path (str): The path to prediction output. + dataset_cfg (ConfigDict, optional): The config of the dataset to be + evaluated. + postprocessor (ConfigDict): The model prediction's postprocessor + config. + """ + + def __init__( + self, + prompt_template: ConfigDict, + judge_cfg: ConfigDict, + output_path: str, + dataset_cfg: Optional[ConfigDict] = None, + postprocessor: ConfigDict = dict(type=first_number_postprocess) + ) -> None: + self.output_path = output_path + out_dir, out_name = osp.split(output_path) + if not out_dir: + out_dir = './' + + self.prompt_tmpl = ICL_PROMPT_TEMPLATES.build(prompt_template) + + max_out_len = judge_cfg.get('max_out_len', None) + batch_size = judge_cfg.get('batch_size', None) + model = build_model_from_cfg(model_cfg=judge_cfg) + self.inferencer = GenInferencer(model, + max_out_len=max_out_len, + batch_size=batch_size, + output_json_filepath=out_dir, + output_json_filename=out_name) + self.postprocessor = get_type_from_cfg(postprocessor) + self.logger = get_logger() + self.dataset_cfg = dataset_cfg + + def score(self, predictions, references: Optional[List] = None) -> Dict: + if self.dataset_cfg: + dataset = build_dataset_from_cfg(self.dataset_cfg) + dataset.reader.dataset['test'] = dataset.test.add_column( + 'prediction', predictions) + dataset.reader.input_columns.append('prediction') + if references: + dataset.reader.input_columns.append('reference') + dataset.reader.dataset['test'] = dataset.test.add_column( + 'reference', references) + else: + from opencompass.datasets.lmeval import LMEvalDataset + input_columns = ['prediction'] + if references: + input_columns.append('reference') + dataset = LMEvalDataset(reader_cfg=dict( + input_columns=input_columns, + output_column=None, + train_split='test'), + predictions=predictions, + references=references) + retriever = ZeroRetriever(dataset) + self.inferencer.inference(retriever=retriever, + prompt_template=self.prompt_tmpl) + + output = mmengine.load(self.output_path) + scores = [] + for k, v in output.items(): + score = self.postprocessor(v['prediction']) + output[k]['score'] = score + scores.append(score) + try: + output['score'] = sum(scores) / len(scores) + except Exception: + pass + return output diff --git a/opencompass/partitioners/base.py b/opencompass/partitioners/base.py index eac39f9e..cd523faa 100644 --- a/opencompass/partitioners/base.py +++ b/opencompass/partitioners/base.py @@ -13,11 +13,16 @@ class BasePartitioner: Args: out_dir (str): The output directory of tasks. + keep_keys (List[str]): The keys to be kept from the experiment config + to the task config. 
""" - def __init__(self, out_dir: str): + def __init__(self, + out_dir: str, + keep_keys: List[str] = ['eval.runner.task.judge_cfg']): self.logger = get_logger() self.out_dir = out_dir + self.keep_keys = keep_keys def __call__(self, cfg: ConfigDict) -> List[Dict]: """Generate tasks from config. Each task is defined as a @@ -45,7 +50,26 @@ class BasePartitioner: datasets = cfg['datasets'] work_dir = cfg['work_dir'] - tasks = self.partition(models, datasets, work_dir, self.out_dir) + add_cfg = {} + for k in self.keep_keys: + try: + key_chain = k.split('.') + ori_ptr = cfg + tgt_ptr = add_cfg + for key in key_chain[:-1]: + ori_ptr = ori_ptr[key] + if key not in tgt_ptr: + tgt_ptr[key] = {} + tgt_ptr = tgt_ptr[key] + tgt_ptr[key_chain[-1]] = ori_ptr[key_chain[-1]] + except AttributeError: + self.logger.warning(f'Key {k} not found in config, ignored.') + + tasks = self.partition(models, + datasets, + work_dir, + self.out_dir, + add_cfg=add_cfg) self.logger.info(f'Partitioned into {len(tasks)} tasks.') for i, task in enumerate(tasks): @@ -54,8 +78,12 @@ class BasePartitioner: return tasks @abstractmethod - def partition(self, models: List[ConfigDict], datasets: List[ConfigDict], - work_dir: str, out_dir: str) -> List[Dict]: + def partition(self, + models: List[ConfigDict], + datasets: List[ConfigDict], + work_dir: str, + out_dir: str, + add_cfg: Dict = {}) -> List[Dict]: """Partition model-dataset pairs into tasks. Each task is defined as a dict and will run independently as a unit. Its structure is as follows: @@ -67,6 +95,7 @@ class BasePartitioner: 'datasets': [[]], # a nested list of dataset configs, each list corresponds to a model 'work_dir': '', # the work dir + **add_cfg # other keys to be added in the config } Args: @@ -76,6 +105,8 @@ class BasePartitioner: out_dir (str): The full output path for the task, intended for Partitioners to check whether the task is finished via the existency of result file in this directory. + add_cfg (dict): Other common keys to be added in the task config, + used to share the same config among tasks. Defaults to {}. Returns: List[Dict]: A list of tasks. diff --git a/opencompass/partitioners/naive.py b/opencompass/partitioners/naive.py index 92447d76..42bfcf57 100644 --- a/opencompass/partitioners/naive.py +++ b/opencompass/partitioners/naive.py @@ -15,11 +15,17 @@ class NaivePartitioner(BasePartitioner): model-dataset pair. Args: - config (ConfigDict): The full config dict. + out_dir (str): The output directory of tasks. + keep_keys (List[str]): The keys to be kept from the experiment config + to the task config. """ - def partition(self, models: List[ConfigDict], datasets: List[ConfigDict], - work_dir: str, out_dir: str) -> List[Dict]: + def partition(self, + models: List[ConfigDict], + datasets: List[ConfigDict], + work_dir: str, + out_dir: str, + add_cfg: Dict = {}) -> List[Dict]: """Partition model-dataset pairs into tasks. Each task is defined as a dict and will run independently as a unit. 
Its structure is as follows: @@ -54,7 +60,8 @@ class NaivePartitioner(BasePartitioner): task = Config({ 'models': [model], 'datasets': [[dataset]], - 'work_dir': work_dir + 'work_dir': work_dir, + **add_cfg }) tasks.append(task) return tasks diff --git a/opencompass/partitioners/size.py b/opencompass/partitioners/size.py index 7b94f915..3bbd17fa 100644 --- a/opencompass/partitioners/size.py +++ b/opencompass/partitioners/size.py @@ -2,7 +2,7 @@ import copy import math import os.path as osp from fnmatch import fnmatch -from typing import List, Tuple, Union +from typing import Dict, List, Tuple, Union import mmengine from mmengine.config import Config, ConfigDict @@ -25,20 +25,27 @@ class SizePartitioner(BasePartitioner): gen_task_coef (int): The dataset cost measurement coefficient for generation tasks. dataset_size_path (str): The path to the dataset size cache file. + keep_keys (list[str]): The keys to be kept from the experiment config + to the task config. """ def __init__(self, out_dir: str, max_task_size: int = 40000, gen_task_coef: int = 20, - dataset_size_path: str = '.cache/dataset_size.json'): - super().__init__(out_dir) + dataset_size_path: str = '.cache/dataset_size.json', + keep_keys: List[str] = ['eval.runner.task.judge_cfg']): + super().__init__(out_dir=out_dir, keep_keys=keep_keys) self.max_task_size = max_task_size self.gen_task_coef = gen_task_coef self.dataset_size_path = dataset_size_path - def partition(self, models: List[ConfigDict], datasets: List[ConfigDict], - work_dir: str, out_dir: str) -> List[ConfigDict]: + def partition(self, + models: List[ConfigDict], + datasets: List[ConfigDict], + work_dir: str, + out_dir: str, + add_cfg: Dict = {}) -> List[ConfigDict]: """Partition model-dataset pairs into tasks. Each task is defined as a dict and will run independently as a unit. Its structure is as follows: @@ -50,6 +57,7 @@ class SizePartitioner(BasePartitioner): 'datasets': [[]], # a nested list of dataset configs, each list corresponds to a model 'work_dir': '', # the work dir + **add_cfg # other keys to be kept in the config } Args: @@ -59,6 +67,8 @@ class SizePartitioner(BasePartitioner): out_dir (str): The full output path for the task, intended for Partitioners to check whether the task is finished via the existency of result file in this directory. + add_cfg (dict): Other common keys to be added in the task config, + used to share the same config among tasks. Defaults to {}. Returns: List[ConfigDict]: A list of tasks. 
@@ -72,7 +82,8 @@ class SizePartitioner(BasePartitioner): task = Config({ 'models': [model], 'datasets': [[]], - 'work_dir': work_dir + 'work_dir': work_dir, + **add_cfg }) num_data = 0 for dataset in datasets: @@ -91,7 +102,8 @@ class SizePartitioner(BasePartitioner): Config({ 'models': [model], 'datasets': [[dataset_split]], - 'work_dir': work_dir + 'work_dir': work_dir, + **add_cfg })) else: if num_data + dataset_size > self.max_task_size: @@ -99,7 +111,8 @@ class SizePartitioner(BasePartitioner): task = Config({ 'models': [model], 'datasets': [[]], - 'work_dir': work_dir + 'work_dir': work_dir, + **add_cfg }) num_data = 0 task['datasets'][0].append(dataset) diff --git a/opencompass/runners/dlc.py b/opencompass/runners/dlc.py index 42b4b30e..d179a6f4 100644 --- a/opencompass/runners/dlc.py +++ b/opencompass/runners/dlc.py @@ -63,11 +63,11 @@ class DLCRunner(BaseRunner): status = [self._launch(task, random_sleep=False) for task in tasks] return status - def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True): + def _launch(self, cfg: ConfigDict, random_sleep: bool = True): """Launch a single task. Args: - task_cfg (ConfigDict): Task config. + cfg (ConfigDict): Task config. random_sleep (bool): Whether to sleep for a random time before running the command. This avoids cluster error when launching multiple tasks at the same time. Default: True. @@ -76,10 +76,7 @@ class DLCRunner(BaseRunner): tuple[str, int]: Task name and exit code. """ - task_type = self.task_cfg.type - if isinstance(self.task_cfg.type, str): - task_type = TASKS.get(task_type) - task = task_type(task_cfg) + task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type'])) num_gpus = task.num_gpus task_name = task.name @@ -87,7 +84,7 @@ class DLCRunner(BaseRunner): mmengine.mkdir_or_exist('tmp/') param_file = f'tmp/{os.getpid()}_params.py' try: - task_cfg.dump(param_file) + cfg.dump(param_file) # Build up DLC command pwd = os.getcwd() diff --git a/opencompass/runners/local.py b/opencompass/runners/local.py index 2f9fed67..81eede41 100644 --- a/opencompass/runners/local.py +++ b/opencompass/runners/local.py @@ -57,7 +57,7 @@ class LocalRunner(BaseRunner): status = [] if self.debug: for task in tasks: - task = TASKS.build(dict(type=self.task_cfg.type, cfg=task)) + task = TASKS.build(dict(cfg=task, type=self.task_cfg['type'])) task_name = task.name # get cmd mmengine.mkdir_or_exist('tmp/') @@ -94,7 +94,7 @@ class LocalRunner(BaseRunner): lock = Lock() def submit(task, index): - task = TASKS.build(dict(type=self.task_cfg.type, cfg=task)) + task = TASKS.build(dict(cfg=task, type=self.task_cfg['type'])) num_gpus = task.num_gpus assert len(gpus) >= num_gpus diff --git a/opencompass/runners/slurm.py b/opencompass/runners/slurm.py index 5ceb5bad..363a21a9 100644 --- a/opencompass/runners/slurm.py +++ b/opencompass/runners/slurm.py @@ -69,11 +69,11 @@ class SlurmRunner(BaseRunner): status = [self._launch(task, random_sleep=False) for task in tasks] return status - def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True): + def _launch(self, cfg: ConfigDict, random_sleep: bool = True): """Launch a single task. Args: - task_cfg (ConfigDict): Task config. + cfg (ConfigDict): Task config. random_sleep (bool): Whether to sleep for a random time before running the command. This avoids cluster error when launching multiple tasks at the same time. Default: True. @@ -81,10 +81,7 @@ class SlurmRunner(BaseRunner): Returns: tuple[str, int]: Task name and exit code. 
""" - task_type = self.task_cfg.type - if isinstance(self.task_cfg.type, str): - task_type = TASKS.get(task_type) - task = task_type(task_cfg) + task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type'])) num_gpus = task.num_gpus task_name = task.name @@ -92,7 +89,7 @@ class SlurmRunner(BaseRunner): mmengine.mkdir_or_exist('tmp/') param_file = f'tmp/{os.getpid()}_params.py' try: - task_cfg.dump(param_file) + cfg.dump(param_file) # Build up slurm command tmpl = 'srun' diff --git a/opencompass/tasks/openicl_eval.py b/opencompass/tasks/openicl_eval.py index 807a75d0..424c713e 100644 --- a/opencompass/tasks/openicl_eval.py +++ b/opencompass/tasks/openicl_eval.py @@ -1,6 +1,8 @@ import argparse +import copy import fnmatch import os.path as osp +import random import time from collections import Counter from inspect import signature @@ -10,12 +12,14 @@ import mmengine from mmengine.config import Config, ConfigDict from mmengine.utils import mkdir_or_exist +from opencompass.openicl.icl_evaluator.lm_evaluator import LMEvaluator from opencompass.registry import (ICL_EVALUATORS, MODELS, TASKS, TEXT_POSTPROCESSORS) from opencompass.tasks.base import BaseTask from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg, get_infer_output_path, get_logger, task_abbr_from_cfg) +from opencompass.utils.types import get_type_from_cfg @TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run @@ -24,6 +28,9 @@ class OpenICLEvalTask(BaseTask): This task is used to evaluate the metric between predictions and references. + + Args: + cfg (ConfigDict): The configuration of the entire evaluation task. """ name_prefix = 'OpenICLEval' @@ -32,12 +39,30 @@ class OpenICLEvalTask(BaseTask): def __init__(self, cfg: ConfigDict): super().__init__(cfg) - self.num_gpus = 0 self.logger = get_logger() + judge_cfg = cfg.eval.runner.task.get('judge_cfg', {}) + run_cfg = judge_cfg.get('run_cfg', {}) + self.num_gpus = run_cfg.get('num_gpus', 0) + self.num_procs = run_cfg.get('num_procs', 1) + self.judge_cfg = copy.deepcopy(judge_cfg) def get_command(self, cfg_path, template): + """Get the command template for the task. + + Args: + cfg_path (str): The path to the config file of the task. + template (str): The template which have '{task_cmd}' to format + the command. 
+ """ script_path = __file__ - command = f'python3 {script_path} {cfg_path}' + if self.num_gpus > 0: + port = random.randint(12000, 32000) + command = (f'torchrun --master_port={port} ' + f'--nproc_per_node {self.num_procs} ' + f'{script_path} {cfg_path}') + else: + command = f'python {script_path} {cfg_path}' + return template.format(task_cmd=command) def run(self): @@ -94,6 +119,10 @@ class OpenICLEvalTask(BaseTask): # Get sc_size if use Self-Consistency sc_size = self.eval_cfg.get('sc_size') + # Get out_path + out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg, + osp.join(self.work_dir, 'results')) + if not osp.exists(osp.realpath(filename)) and not osp.exists( osp.realpath(partial_filename)): result = {'error': 'No predictions found.'} @@ -160,9 +189,18 @@ class OpenICLEvalTask(BaseTask): Counter(s).most_common(1)[0][0] for s in pred_strs ] + if get_type_from_cfg(self.eval_cfg['evaluator']) == LMEvaluator: + if not self.judge_cfg: + raise ValueError('Using LMEvaluator in dataset, but ' + 'missing "eval.runner.task.judge_cfg" ' + 'as the judge configuration.') + self.eval_cfg['evaluator']['judge_cfg'] = self.judge_cfg + self.eval_cfg['evaluator']['dataset_cfg'] = self.dataset_cfg + self.eval_cfg['evaluator']['output_path'] = out_path icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator']) preds['predictions'] = pred_strs - preds['references'] = test_set[self.output_column] + preds['references'] = (test_set[self.output_column] + if self.output_column else None) preds = { k: preds[k] for k in signature(icl_evaluator.score).parameters @@ -177,10 +215,12 @@ class OpenICLEvalTask(BaseTask): self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}: {result}') # Save result - out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg, - osp.join(self.work_dir, 'results')) mkdir_or_exist(osp.split(out_path)[0]) - mmengine.dump(result, out_path) + mmengine.dump(result, + open(out_path, 'w', encoding='utf-8'), + file_format='json', + ensure_ascii=False, + indent=4) def _extract_role_pred(self, s: str, begin_str: Optional[str], end_str: Optional[str]) -> str: diff --git a/opencompass/utils/build.py b/opencompass/utils/build.py index 7d6fe132..b27a133d 100644 --- a/opencompass/utils/build.py +++ b/opencompass/utils/build.py @@ -5,7 +5,7 @@ from mmengine.config import ConfigDict from opencompass.registry import LOAD_DATASET, MODELS -def build_dataset_from_cfg(dataset_cfg: ConfigDict) -> ConfigDict: +def build_dataset_from_cfg(dataset_cfg: ConfigDict): dataset_cfg = copy.deepcopy(dataset_cfg) dataset_cfg.pop('infer_cfg', None) dataset_cfg.pop('eval_cfg', None) @@ -13,7 +13,7 @@ def build_dataset_from_cfg(dataset_cfg: ConfigDict) -> ConfigDict: return LOAD_DATASET.build(dataset_cfg) -def build_model_from_cfg(model_cfg: ConfigDict) -> ConfigDict: +def build_model_from_cfg(model_cfg: ConfigDict): model_cfg = copy.deepcopy(model_cfg) model_cfg.pop('run_cfg', None) model_cfg.pop('max_out_len', None) diff --git a/opencompass/utils/text_postprocessors.py b/opencompass/utils/text_postprocessors.py index c6db0ba9..ec668f4d 100644 --- a/opencompass/utils/text_postprocessors.py +++ b/opencompass/utils/text_postprocessors.py @@ -86,3 +86,15 @@ def last_option_postprocess(text: str, options: str) -> str: if match: return match[-1] return '' + + +def first_number_postprocess(text: str) -> float: + """Return the first number in a string.""" + # regex pattern to match numbers (both integers and decimals) + pattern = r'(-?\d*\.?\d+)' + + # search the string for the pattern + match = 
re.search(pattern, text)
+
+    # if a match is found, return it. Otherwise, return None.
+    return float(match.group(1)) if match else None
diff --git a/opencompass/utils/types.py b/opencompass/utils/types.py
index 868040f1..ea5476b6 100644
--- a/opencompass/utils/types.py
+++ b/opencompass/utils/types.py
@@ -1,6 +1,22 @@
-from typing import Dict, List, Union
+from typing import Any, Dict, List, Union
 
 from datasets import Dataset, DatasetDict
+from mmengine.config import Config
+
+from opencompass.registry import TASKS
+
+
+def get_type_from_cfg(cfg: Union[Config, Dict]) -> Any:
+    """Get the object type given MMEngine's Config.
+
+    It loads the "type" field and returns the corresponding object type.
+    """
+    type = cfg['type']
+    if isinstance(type, str):
+        # FIXME: This has nothing to do with any specific registry, to be fixed
+        # in MMEngine
+        type = TASKS.get(type)
+    return type
 
 
 def _check_type_list(obj, typelist: List):
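
Taken together, the changes above let a dataset be scored by a judge language model instead of a rule-based metric. The sketch below shows one way the new pieces are expected to be wired up in a user config; it is an illustrative assumption rather than part of this commit, and the judge model path, prompt wording, and worker counts are placeholders.

# Illustrative end-to-end config sketch; names and values below are assumptions.
from opencompass.models import HuggingFaceCausalLM
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLEvalTask

# Judge model config. build_model_from_cfg() drops run_cfg/max_out_len before
# building the model, and OpenICLEvalTask reads run_cfg to decide between a
# plain python launch and torchrun.
judge_cfg = dict(
    type=HuggingFaceCausalLM,
    path='path/to/judge-model',  # hypothetical checkpoint
    max_out_len=512,
    batch_size=8,
    run_cfg=dict(num_gpus=1, num_procs=1),
)

# Dataset side: use the LM judge as the evaluator. judge_cfg, dataset_cfg and
# output_path are injected by OpenICLEvalTask at runtime, and the default
# first_number_postprocess pulls the numeric score out of the judge's reply.
subjective_dataset = dict(
    # ... abbr / type / reader_cfg / infer_cfg of an ordinary dataset ...
    eval_cfg=dict(evaluator=dict(
        type=LMEvaluator,
        prompt_template=dict(
            type=PromptTemplate,
            template='Score the answer from 1 to 10.\n'
                     'Answer: {prediction}\nReference: {reference}\nScore:'),
    )),
)

# Experiment side: judge_cfg sits under eval.runner.task so partitioners can
# copy it into every evaluation task.
eval = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        max_num_workers=8,
        task=dict(type=OpenICLEvalTask, judge_cfg=judge_cfg),
    ),
)

Keeping judge_cfg under eval.runner.task, rather than on each dataset, lets a partitioner forward one shared judge configuration into every evaluation task through its keep_keys list, which defaults to ['eval.runner.task.judge_cfg'].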