From 5a9539f375db865e3c3b38a575d30be3dbc9b3c8 Mon Sep 17 00:00:00 2001
From: Hubert <42952108+yingfhu@users.noreply.github.com>
Date: Fri, 11 Aug 2023 11:19:26 +0800
Subject: [PATCH] [Feat] add safety to collections (#185)

* [Feat] add safety to collections

* minor fix
---
 ...ilcomments_ppl.py => civilcomments_clp.py} |  2 +-
 ..._6a2561.py => civilcomments_clp_6a2561.py} |  0
 ..._a3c5fd.py => civilcomments_clp_a3c5fd.py} |  0
 configs/datasets/collections/base_medium.py   |  4 +++
 configs/datasets/collections/chat_medium.py   |  4 +++
 ...ngual_ppl.py => jigsawmultilingual_clp.py} |  2 +-
 ...ae.py => jigsawmultilingual_clp_1af0ae.py} |  0
 ...d8.py => jigsawmultilingual_clp_fe50d8.py} |  0
 .../realtoxicprompts/realtoxicprompts_gen.py  |  2 +-
 .../truthfulqa/truthfulqa_gen_1e7d8d.py       |  5 ++--
 .../truthfulqa/truthfulqa_gen_5ddc62.py       |  1 +
 configs/summarizers/medium.py                 | 25 +++++++++++++------
 .../icl_inferencer/icl_base_inferencer.py     | 17 +++++++++++++
 .../icl_inferencer/icl_clp_inferencer.py      | 21 +++++++++++++---
 14 files changed, 66 insertions(+), 17 deletions(-)
 rename configs/datasets/civilcomments/{civilcomments_ppl.py => civilcomments_clp.py} (54%)
 rename configs/datasets/civilcomments/{civilcomments_ppl_6a2561.py => civilcomments_clp_6a2561.py} (100%)
 rename configs/datasets/civilcomments/{civilcomments_ppl_a3c5fd.py => civilcomments_clp_a3c5fd.py} (100%)
 rename configs/datasets/jigsawmultilingual/{jigsawmultilingual_ppl.py => jigsawmultilingual_clp.py} (57%)
 rename configs/datasets/jigsawmultilingual/{jigsawmultilingual_ppl_1af0ae.py => jigsawmultilingual_clp_1af0ae.py} (100%)
 rename configs/datasets/jigsawmultilingual/{jigsawmultilingual_ppl_fe50d8.py => jigsawmultilingual_clp_fe50d8.py} (100%)

diff --git a/configs/datasets/civilcomments/civilcomments_ppl.py b/configs/datasets/civilcomments/civilcomments_clp.py
similarity index 54%
rename from configs/datasets/civilcomments/civilcomments_ppl.py
rename to configs/datasets/civilcomments/civilcomments_clp.py
index 99ed3abf..efcf40b0 100644
--- a/configs/datasets/civilcomments/civilcomments_ppl.py
+++ b/configs/datasets/civilcomments/civilcomments_clp.py
@@ -1,4 +1,4 @@
 from mmengine.config import read_base
 
 with read_base():
-    from .civilcomments_ppl_6a2561 import civilcomments_datasets  # noqa: F401, F403
+    from .civilcomments_clp_a3c5fd import civilcomments_datasets  # noqa: F401, F403
diff --git a/configs/datasets/civilcomments/civilcomments_ppl_6a2561.py b/configs/datasets/civilcomments/civilcomments_clp_6a2561.py
similarity index 100%
rename from configs/datasets/civilcomments/civilcomments_ppl_6a2561.py
rename to configs/datasets/civilcomments/civilcomments_clp_6a2561.py
diff --git a/configs/datasets/civilcomments/civilcomments_ppl_a3c5fd.py b/configs/datasets/civilcomments/civilcomments_clp_a3c5fd.py
similarity index 100%
rename from configs/datasets/civilcomments/civilcomments_ppl_a3c5fd.py
rename to configs/datasets/civilcomments/civilcomments_clp_a3c5fd.py
diff --git a/configs/datasets/collections/base_medium.py b/configs/datasets/collections/base_medium.py
index d3caf379..9a9962f3 100644
--- a/configs/datasets/collections/base_medium.py
+++ b/configs/datasets/collections/base_medium.py
@@ -53,5 +53,9 @@ with read_base():
     from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
     from ..flores.flores_gen_806ede import flores_datasets
     from ..crowspairs.crowspairs_ppl_e811e1 import crowspairs_datasets
+    from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
+    from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
+    from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
+    from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets
 
 datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
diff --git a/configs/datasets/collections/chat_medium.py b/configs/datasets/collections/chat_medium.py
index dca077bb..6b63538d 100644
--- a/configs/datasets/collections/chat_medium.py
+++ b/configs/datasets/collections/chat_medium.py
@@ -53,5 +53,9 @@ with read_base():
     from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
     from ..flores.flores_gen_806ede import flores_datasets
     from ..crowspairs.crowspairs_gen_21f7cb import crowspairs_datasets
+    from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
+    from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
+    from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
+    from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets
 
 datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
diff --git a/configs/datasets/jigsawmultilingual/jigsawmultilingual_ppl.py b/configs/datasets/jigsawmultilingual/jigsawmultilingual_clp.py
similarity index 57%
rename from configs/datasets/jigsawmultilingual/jigsawmultilingual_ppl.py
rename to configs/datasets/jigsawmultilingual/jigsawmultilingual_clp.py
index 3300888c..99caa011 100644
--- a/configs/datasets/jigsawmultilingual/jigsawmultilingual_ppl.py
+++ b/configs/datasets/jigsawmultilingual/jigsawmultilingual_clp.py
@@ -1,4 +1,4 @@
 from mmengine.config import read_base
 
 with read_base():
-    from .jigsawmultilingual_ppl_fe50d8 import jigsawmultilingual_datasets  # noqa: F401, F403
+    from .jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets  # noqa: F401, F403
diff --git a/configs/datasets/jigsawmultilingual/jigsawmultilingual_ppl_1af0ae.py b/configs/datasets/jigsawmultilingual/jigsawmultilingual_clp_1af0ae.py
similarity index 100%
rename from configs/datasets/jigsawmultilingual/jigsawmultilingual_ppl_1af0ae.py
rename to configs/datasets/jigsawmultilingual/jigsawmultilingual_clp_1af0ae.py
diff --git a/configs/datasets/jigsawmultilingual/jigsawmultilingual_ppl_fe50d8.py b/configs/datasets/jigsawmultilingual/jigsawmultilingual_clp_fe50d8.py
similarity index 100%
rename from configs/datasets/jigsawmultilingual/jigsawmultilingual_ppl_fe50d8.py
rename to configs/datasets/jigsawmultilingual/jigsawmultilingual_clp_fe50d8.py
diff --git a/configs/datasets/realtoxicprompts/realtoxicprompts_gen.py b/configs/datasets/realtoxicprompts/realtoxicprompts_gen.py
index 16a9f924..b2ac7db1 100644
--- a/configs/datasets/realtoxicprompts/realtoxicprompts_gen.py
+++ b/configs/datasets/realtoxicprompts/realtoxicprompts_gen.py
@@ -1,4 +1,4 @@
 from mmengine.config import read_base
 
 with read_base():
-    from .realtoxicprompts_gen_ac723c import realtoxicprompts_datasets  # noqa: F401, F403
+    from .realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets  # noqa: F401, F403
diff --git a/configs/datasets/truthfulqa/truthfulqa_gen_1e7d8d.py b/configs/datasets/truthfulqa/truthfulqa_gen_1e7d8d.py
index cc4b959b..269ab948 100644
--- a/configs/datasets/truthfulqa/truthfulqa_gen_1e7d8d.py
+++ b/configs/datasets/truthfulqa/truthfulqa_gen_1e7d8d.py
@@ -11,9 +11,7 @@ truthfulqa_reader_cfg = dict(
 
 # TODO: allow empty output-column
 truthfulqa_infer_cfg = dict(
-    prompt_template=dict(
-        type=PromptTemplate,
-        template='{question}'),
+    prompt_template=dict(type=PromptTemplate, template='{question}'),
     retriever=dict(type=ZeroRetriever),
     inferencer=dict(type=GenInferencer))
 
@@ -31,6 +29,7 @@ truthfulqa_eval_cfg = dict(
 
 truthfulqa_datasets = [
     dict(
+        abbr='truthful_qa',
         type=TruthfulQADataset,
         path='truthful_qa',
         name='generation',
diff --git a/configs/datasets/truthfulqa/truthfulqa_gen_5ddc62.py b/configs/datasets/truthfulqa/truthfulqa_gen_5ddc62.py
index a7d6c7cb..47227cc6 100644
--- a/configs/datasets/truthfulqa/truthfulqa_gen_5ddc62.py
+++ b/configs/datasets/truthfulqa/truthfulqa_gen_5ddc62.py
@@ -31,6 +31,7 @@ truthfulqa_eval_cfg = dict(
 
 truthfulqa_datasets = [
     dict(
+        abbr='truthful_qa',
         type=TruthfulQADataset,
         path='truthful_qa',
         name='generation',
diff --git a/configs/summarizers/medium.py b/configs/summarizers/medium.py
index 4772e34f..8b652ccf 100644
--- a/configs/summarizers/medium.py
+++ b/configs/summarizers/medium.py
@@ -10,15 +10,15 @@ with read_base():
     from .groups.jigsaw_multilingual import jigsaw_multilingual_summary_groups
 
 summarizer = dict(
-    dataset_abbrs = [
-        '--------- 考试 Exam ---------', # category
+    dataset_abbrs=[
+        '--------- 考试 Exam ---------',  # category
         # 'Mixed', # subcategory
         "ceval",
         'agieval',
         'mmlu',
         "GaokaoBench",
         'ARC-c',
-        '--------- 语言 Language ---------', # category
+        '--------- 语言 Language ---------',  # category
         # '字词释义', # subcategory
         'WiC',
         'summedits',
@@ -33,14 +33,14 @@ summarizer = dict(
         'winogrande',
         # '翻译', # subcategory
         'flores_100',
-        '--------- 知识 Knowledge ---------', # category
+        '--------- 知识 Knowledge ---------',  # category
         # '知识问答', # subcategory
         'BoolQ',
         'commonsense_qa',
         'nq',
         'triviaqa',
         # '多语种问答', # subcategory
-        '--------- 推理 Reasoning ---------', # category
+        '--------- 推理 Reasoning ---------',  # category
         # '文本蕴含', # subcategory
         'cmnli',
         'ocnli',
@@ -67,7 +67,7 @@ summarizer = dict(
         'mbpp',
         # '综合推理', # subcategory
         "bbh",
-        '--------- 理解 Understanding ---------', # category
+        '--------- 理解 Understanding ---------',  # category
         # '阅读理解', # subcategory
         'C3',
         'CMRC_dev',
@@ -84,11 +84,20 @@ summarizer = dict(
         'eprstmt-dev',
         'lambada',
         'tnews-dev',
-        '--------- 安全 Safety ---------', # category
+        '--------- 安全 Safety ---------',  # category
         # '偏见', # subcategory
         'crows_pairs',
+        # '有毒性(判别)', # subcategory
+        'civil_comments',
+        # '有毒性(判别)多语言', # subcategory
+        'jigsaw_multilingual',
+        # '有毒性(生成)', # subcategory
+        'real-toxicity-prompts',
+        # '真实性/有用性', # subcategory
+        'truthful_qa',
     ],
-    summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []),
+    summary_groups=sum(
+        [v for k, v in locals().items() if k.endswith("_summary_groups")], []),
     prompt_db=dict(
         database_path='configs/datasets/log.json',
         config_dir='configs/datasets',
diff --git a/opencompass/openicl/icl_inferencer/icl_base_inferencer.py b/opencompass/openicl/icl_inferencer/icl_base_inferencer.py
index 70f80dfe..25ba9401 100644
--- a/opencompass/openicl/icl_inferencer/icl_base_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_base_inferencer.py
@@ -147,6 +147,23 @@ class PPLInferencerOutputHandler:
         self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
         self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl
 
+
+class CLPInferencerOutputHandler:
+    results_dict = {}
+
+    def __init__(self) -> None:
+        self.results_dict = {}
+
+    def write_to_json(self, save_dir: str, filename: str):
+        """Dump the result to a json file."""
+        dump_results_dict(self.results_dict, Path(save_dir) / filename)
+
+    def save_ice(self, ice):
+        for idx, example in enumerate(ice):
+            if str(idx) not in self.results_dict.keys():
+                self.results_dict[str(idx)] = {}
+            self.results_dict[str(idx)]['in-context examples'] = example
+
     def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
         if str(idx) not in self.results_dict.keys():
             self.results_dict[str(idx)] = {}
diff --git a/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py b/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
index 4c622d26..b369738b 100644
--- a/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
@@ -13,7 +13,7 @@ from opencompass.registry import ICL_INFERENCERS
 from ..icl_prompt_template import PromptTemplate
 from ..icl_retriever import BaseRetriever
 from ..utils import get_logger
-from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
+from .icl_base_inferencer import BaseInferencer, CLPInferencerOutputHandler
 
 logger = get_logger(__name__)
 
@@ -79,7 +79,7 @@ class CLPInferencer(BaseInferencer):
                   output_json_filename: Optional[str] = None,
                   normalizing_str: Optional[str] = None) -> List:
         # 1. Preparation for output logs
-        output_handler = PPLInferencerOutputHandler()
+        output_handler = CLPInferencerOutputHandler()
 
         ice = []
 
@@ -88,6 +88,20 @@ class CLPInferencer(BaseInferencer):
         if output_json_filename is None:
             output_json_filename = self.output_json_filename
 
+        # CLP cannot run conditional log probability inference for API
+        # models unless the model exposes such an option, which would need
+        # a specific implementation; open an issue if you encounter this case.
+        if self.model.is_api:
+            err_msg = 'API model is not supported for conditional log '\
+                'probability inference; skipping this experiment.'
+            # Write an empty result file so this model is not always rerun.
+            if self.is_main_process:
+                os.makedirs(output_json_filepath, exist_ok=True)
+                output_handler.results_dict = {'error': err_msg}
+                output_handler.write_to_json(output_json_filepath,
+                                             output_json_filename)
+            raise ValueError(err_msg)
+
         # 2. Get results of retrieval process
         if self.fix_id_list:
             ice_idx_list = retriever.retrieve(self.fix_id_list)
@@ -117,7 +131,7 @@ class CLPInferencer(BaseInferencer):
             choice_ids = [self.model.tokenizer.encode(c) for c in choices]
             if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer':  # noqa
                 choice_ids = [c[2:] for c in choice_ids]
-            else:
+            elif hasattr(self.model.tokenizer, 'add_bos_token'):
                 if self.model.tokenizer.add_bos_token:
                     choice_ids = [c[1:] for c in choice_ids]
                 if self.model.tokenizer.add_eos_token:
@@ -135,6 +149,7 @@ class CLPInferencer(BaseInferencer):
                     ice[idx],
                     ice_template=ice_template,
                     prompt_template=prompt_template)
+                prompt = self.model.parse_template(prompt, mode='ppl')
                 if self.max_seq_len is not None:
                     prompt_token_num = get_token_len(prompt)
                     # add one because additional token will be added in the end
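
Below is a minimal usage sketch, not part of the patch itself, showing how the updated medium collection (which now pulls in civil_comments, jigsaw_multilingual, real-toxicity-prompts, and truthful_qa) could be evaluated. The file name eval_safety_medium.py, the hf_llama_7b model import, and the work_dir value are illustrative assumptions; only the collection and summarizer imports refer to files touched by this patch.

# configs/eval_safety_medium.py (hypothetical example config)
from mmengine.config import read_base

with read_base():
    # Collection extended by this patch; `datasets` now includes the
    # safety datasets listed above.
    from .datasets.collections.base_medium import datasets
    # Summarizer extended by this patch; reports the new safety abbrs.
    from .summarizers.medium import summarizer
    # Placeholder model config; substitute any model config that exists
    # under configs/models/.
    from .models.hf_llama_7b import models

# Where evaluation results and prediction logs are written (illustrative).
work_dir = './outputs/safety_medium/'

Assuming the standard OpenCompass entry point, this config could then be run with: python run.py configs/eval_safety_medium.py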