mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

[Feature] Add logger info and remove dataset bugs (#61)

* Add logger info and remove dataset bugs
* fix typo
This commit is contained in:
parent 77a1cc4486
commit 1326aff77e
@@ -21,7 +21,7 @@ chid_infer_cfg = dict(
     retriever=dict(type=ZeroRetriever),
     inferencer=dict(type=PPLInferencer))
 
-chid_eval_cfg = dict(evaluator=dict(type=AccEvaluator), pred_role="BOT")
+chid_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
 
 chid_datasets = [
     dict(
|
@@ -14,7 +14,7 @@ class GaokaoBenchDataset(BaseDataset):
 
     @staticmethod
     def load(path: str):
-        with open(path) as f:
+        with open(path, encoding='utf-8') as f:
             data = json.load(f)
             return Dataset.from_list(data['example'])
 
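
Note on the encoding fix above (and the identical changes in the AGIEval and MMLU loaders below): without an explicit encoding, open() falls back to the platform's locale encoding, which is often not UTF-8 on Windows and corrupts the Chinese benchmark files. A minimal sketch of the loading pattern, with a hypothetical helper name and assuming a UTF-8 JSON file that uses the same 'example' key as GaokaoBenchDataset:

    import json
    from datasets import Dataset

    def load_gaokao_style_json(path: str) -> Dataset:
        # Explicit UTF-8 keeps the read independent of locale.getpreferredencoding().
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        # GAOKAO-Bench style files store their items under the 'example' key.
        return Dataset.from_list(data['example'])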
|
@@ -6,11 +6,10 @@ from datasets import Dataset
 from opencompass.openicl.icl_evaluator import BaseEvaluator
 from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
 
+from ..base import BaseDataset
 from .math_equivalence import is_equiv
 from .post_process import parse_math_answer
 
-from ..base import BaseDataset
-
 
 @LOAD_DATASET.register_module()
 class AGIEvalDataset(BaseDataset):
@@ -40,7 +39,7 @@ class AGIEvalDataset_v2(BaseDataset):
     def load(path: str, name: str, setting_name: str):
         assert setting_name in 'zero-shot', 'only support zero-shot setting'
         filename = osp.join(path, name + '.jsonl')
-        with open(filename) as f:
+        with open(filename, encoding='utf-8') as f:
             _data = [json.loads(line.strip()) for line in f]
         data = []
         for _d in _data:
|
@@ -17,7 +17,7 @@ class MMLUDataset(BaseDataset):
         for split in ['dev', 'test']:
             raw_data = []
             filename = osp.join(path, split, f'{name}_{split}.csv')
-            with open(filename) as f:
+            with open(filename, encoding='utf-8') as f:
                 reader = csv.reader(f)
                 for row in reader:
                     assert len(row) == 6
|
@@ -125,6 +125,8 @@ class OpenICLEvalTask(BaseTask):
             self.logger.error(
                 f'Task {task_abbr_from_cfg(self.cfg)}: {result["error"]}')
             return
+        else:
+            self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}: {result}')
 
         # Save result
         out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
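
The two added lines complete the evaluation-side logging: failures were already reported at ERROR level, and successful results are now echoed at INFO level, both prefixed with the task abbreviation so runs over many model/dataset pairs stay attributable. A rough sketch of the surrounding branch; only the keys and log calls visible in the diff are taken from the source, the skeleton around them is an assumption:

    # Illustrative skeleton only; `result` is whatever dict the evaluator returned.
    if 'error' in result:
        self.logger.error(
            f'Task {task_abbr_from_cfg(self.cfg)}: {result["error"]}')
        return
    else:
        self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}: {result}')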
|
@@ -11,7 +11,8 @@ from opencompass.registry import (ICL_INFERENCERS, ICL_PROMPT_TEMPLATES,
                                   ICL_RETRIEVERS, TASKS)
 from opencompass.tasks.base import BaseTask
 from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
-                               get_infer_output_path, get_logger)
+                               get_infer_output_path, get_logger,
+                               task_abbr_from_cfg)
 
 
 @TASKS.register_module(force=(__name__ == '__main__'))  # A hack for script run
@@ -30,6 +31,7 @@ class OpenICLInferTask(BaseTask):
         run_cfg = self.model_cfgs[0].get('run_cfg', {})
         self.num_gpus = run_cfg.get('num_gpus', 0)
         self.num_procs = run_cfg.get('num_procs', 1)
+        self.logger = get_logger()
 
     def get_command(self, cfg_path, template):
         """Get the command template for the task.
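
The task now keeps a logger handle on the instance, so the run() and _inference() messages added below have somewhere to go. A minimal stand-in for get_logger, assuming it behaves like a lazily created, process-wide singleton; the real implementation lives in opencompass.utils and is not part of this diff:

    import logging

    _LOGGER = None

    def get_logger() -> logging.Logger:
        """Hypothetical sketch: one shared logger so every task writes to the
        same handlers rather than configuring logging on its own."""
        global _LOGGER
        if _LOGGER is None:
            _LOGGER = logging.getLogger('opencompass')  # logger name is an assumption
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
            _LOGGER.addHandler(handler)
            _LOGGER.setLevel(logging.INFO)
        return _LOGGER
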
@@ -51,6 +53,7 @@ class OpenICLInferTask(BaseTask):
         return template.format(task_cmd=command)
 
     def run(self):
+        self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}')
         for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
             self.max_out_len = model_cfg.get('max_out_len', None)
             self.batch_size = model_cfg.get('batch_size', None)
@@ -61,6 +64,10 @@ class OpenICLInferTask(BaseTask):
                 self.dataset_cfg = dataset_cfg
                 self.infer_cfg = self.dataset_cfg['infer_cfg']
                 self.dataset = build_dataset_from_cfg(self.dataset_cfg)
+                self.sub_cfg = {
+                    'models': [self.model_cfg],
+                    'datasets': [[self.dataset_cfg]],
+                }
                 out_path = get_infer_output_path(
                     self.model_cfg, self.dataset_cfg,
                     osp.join(self.work_dir, 'predictions'))
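
self.sub_cfg narrows the full task config down to the single (model, dataset) pair currently being inferred, so the "Start inferencing" message added in the next hunk can name exactly that pair. A simplified, hypothetical stand-in for task_abbr_from_cfg, written only to show why the dict is shaped as {'models': [...], 'datasets': [[...]]}; the real helper is in opencompass.utils:

    def task_abbr_from_cfg_sketch(cfg: dict) -> str:
        """Hypothetical simplification: join model and dataset abbreviations."""
        models = [m.get('abbr', 'model') for m in cfg['models']]
        datasets = [d.get('abbr', 'dataset')
                    for group in cfg['datasets'] for d in group]
        return f"{','.join(models)}/{','.join(datasets)}"
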
@@ -69,6 +76,8 @@ class OpenICLInferTask(BaseTask):
                 self._inference()
 
     def _inference(self):
+        self.logger.info(
+            f'Start inferencing {task_abbr_from_cfg(self.sub_cfg)}')
 
         assert hasattr(self.infer_cfg, 'ice_template') or hasattr(self.infer_cfg, 'prompt_template'), \
             'Both ice_template and prompt_template cannot be None simultaneously.'  # noqa: E501
|