[Feature] Add logger info and remove dataset bugs (#61)

* Add logger info and remove dataset bugs

* fix typo
This commit is contained in:
Leymore 2023-07-17 14:26:30 +08:00 committed by GitHub
parent 77a1cc4486
commit 1326aff77e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 17 additions and 7 deletions

View File

@@ -21,7 +21,7 @@ chid_infer_cfg = dict(
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer))
chid_eval_cfg = dict(evaluator=dict(type=AccEvaluator), pred_role="BOT")
chid_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
chid_datasets = [
dict(

View File

@@ -14,7 +14,7 @@ class GaokaoBenchDataset(BaseDataset):
@staticmethod
def load(path: str):
with open(path) as f:
with open(path, encoding='utf-8') as f:
data = json.load(f)
return Dataset.from_list(data['example'])

View File

@@ -6,11 +6,10 @@ from datasets import Dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from ..base import BaseDataset
from .math_equivalence import is_equiv
from .post_process import parse_math_answer
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class AGIEvalDataset(BaseDataset):
@@ -40,7 +39,7 @@ class AGIEvalDataset_v2(BaseDataset):
def load(path: str, name: str, setting_name: str):
assert setting_name in 'zero-shot', 'only support zero-shot setting'
filename = osp.join(path, name + '.jsonl')
with open(filename) as f:
with open(filename, encoding='utf-8') as f:
_data = [json.loads(line.strip()) for line in f]
data = []
for _d in _data:

View File

@@ -17,7 +17,7 @@ class MMLUDataset(BaseDataset):
for split in ['dev', 'test']:
raw_data = []
filename = osp.join(path, split, f'{name}_{split}.csv')
with open(filename) as f:
with open(filename, encoding='utf-8') as f:
reader = csv.reader(f)
for row in reader:
assert len(row) == 6

View File

@@ -125,6 +125,8 @@ class OpenICLEvalTask(BaseTask):
self.logger.error(
f'Task {task_abbr_from_cfg(self.cfg)}: {result["error"]}')
return
else:
self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}: {result}')
# Save result
out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,

View File

@@ -11,7 +11,8 @@ from opencompass.registry import (ICL_INFERENCERS, ICL_PROMPT_TEMPLATES,
ICL_RETRIEVERS, TASKS)
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
get_infer_output_path, get_logger)
get_infer_output_path, get_logger,
task_abbr_from_cfg)
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
@@ -30,6 +31,7 @@ class OpenICLInferTask(BaseTask):
run_cfg = self.model_cfgs[0].get('run_cfg', {})
self.num_gpus = run_cfg.get('num_gpus', 0)
self.num_procs = run_cfg.get('num_procs', 1)
self.logger = get_logger()
def get_command(self, cfg_path, template):
"""Get the command template for the task.
@@ -51,6 +53,7 @@ class OpenICLInferTask(BaseTask):
return template.format(task_cmd=command)
def run(self):
self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}')
for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
self.max_out_len = model_cfg.get('max_out_len', None)
self.batch_size = model_cfg.get('batch_size', None)
@@ -61,6 +64,10 @@ class OpenICLInferTask(BaseTask):
self.dataset_cfg = dataset_cfg
self.infer_cfg = self.dataset_cfg['infer_cfg']
self.dataset = build_dataset_from_cfg(self.dataset_cfg)
self.sub_cfg = {
'models': [self.model_cfg],
'datasets': [[self.dataset_cfg]],
}
out_path = get_infer_output_path(
self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'predictions'))
@@ -69,6 +76,8 @@ class OpenICLInferTask(BaseTask):
self._inference()
def _inference(self):
self.logger.info(
f'Start inferencing {task_abbr_from_cfg(self.sub_cfg)}')
assert hasattr(self.infer_cfg, 'ice_template') or hasattr(self.infer_cfg, 'prompt_template'), \
'Both ice_template and prompt_template cannot be None simultaneously.' # noqa: E501