[Feat] add safety to collections (#185)

* [Feat] add safety to collections

* minor fix
This commit is contained in:
Hubert 2023-08-11 11:19:26 +08:00 committed by GitHub
parent f4c70ba6c3
commit 5a9539f375
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 66 additions and 17 deletions

View File

@ -1,4 +1,4 @@
from mmengine.config import read_base from mmengine.config import read_base
with read_base(): with read_base():
from .civilcomments_ppl_6a2561 import civilcomments_datasets # noqa: F401, F403 from .civilcomments_clp_a3c5fd import civilcomments_datasets # noqa: F401, F403

View File

@ -53,5 +53,9 @@ with read_base():
from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from ..flores.flores_gen_806ede import flores_datasets from ..flores.flores_gen_806ede import flores_datasets
from ..crowspairs.crowspairs_ppl_e811e1 import crowspairs_datasets from ..crowspairs.crowspairs_ppl_e811e1 import crowspairs_datasets
from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

View File

@ -53,5 +53,9 @@ with read_base():
from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from ..flores.flores_gen_806ede import flores_datasets from ..flores.flores_gen_806ede import flores_datasets
from ..crowspairs.crowspairs_gen_21f7cb import crowspairs_datasets from ..crowspairs.crowspairs_gen_21f7cb import crowspairs_datasets
from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

View File

@ -1,4 +1,4 @@
from mmengine.config import read_base from mmengine.config import read_base
with read_base(): with read_base():
from .jigsawmultilingual_ppl_fe50d8 import jigsawmultilingual_datasets # noqa: F401, F403 from .jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets # noqa: F401, F403

View File

@ -1,4 +1,4 @@
from mmengine.config import read_base from mmengine.config import read_base
with read_base(): with read_base():
from .realtoxicprompts_gen_ac723c import realtoxicprompts_datasets # noqa: F401, F403 from .realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets # noqa: F401, F403

View File

@ -11,9 +11,7 @@ truthfulqa_reader_cfg = dict(
# TODO: allow empty output-column # TODO: allow empty output-column
truthfulqa_infer_cfg = dict( truthfulqa_infer_cfg = dict(
prompt_template=dict( prompt_template=dict(type=PromptTemplate, template='{question}'),
type=PromptTemplate,
template='{question}'),
retriever=dict(type=ZeroRetriever), retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer)) inferencer=dict(type=GenInferencer))
@ -31,6 +29,7 @@ truthfulqa_eval_cfg = dict(
truthfulqa_datasets = [ truthfulqa_datasets = [
dict( dict(
abbr='truthful_qa',
type=TruthfulQADataset, type=TruthfulQADataset,
path='truthful_qa', path='truthful_qa',
name='generation', name='generation',

View File

@ -31,6 +31,7 @@ truthfulqa_eval_cfg = dict(
truthfulqa_datasets = [ truthfulqa_datasets = [
dict( dict(
abbr='truthful_qa',
type=TruthfulQADataset, type=TruthfulQADataset,
path='truthful_qa', path='truthful_qa',
name='generation', name='generation',

View File

@ -10,15 +10,15 @@ with read_base():
from .groups.jigsaw_multilingual import jigsaw_multilingual_summary_groups from .groups.jigsaw_multilingual import jigsaw_multilingual_summary_groups
summarizer = dict( summarizer = dict(
dataset_abbrs = [ dataset_abbrs=[
'--------- 考试 Exam ---------', # category '--------- 考试 Exam ---------', # category
# 'Mixed', # subcategory # 'Mixed', # subcategory
"ceval", "ceval",
'agieval', 'agieval',
'mmlu', 'mmlu',
"GaokaoBench", "GaokaoBench",
'ARC-c', 'ARC-c',
'--------- 语言 Language ---------', # category '--------- 语言 Language ---------', # category
# '字词释义', # subcategory # '字词释义', # subcategory
'WiC', 'WiC',
'summedits', 'summedits',
@ -33,14 +33,14 @@ summarizer = dict(
'winogrande', 'winogrande',
# '翻译', # subcategory # '翻译', # subcategory
'flores_100', 'flores_100',
'--------- 知识 Knowledge ---------', # category '--------- 知识 Knowledge ---------', # category
# '知识问答', # subcategory # '知识问答', # subcategory
'BoolQ', 'BoolQ',
'commonsense_qa', 'commonsense_qa',
'nq', 'nq',
'triviaqa', 'triviaqa',
# '多语种问答', # subcategory # '多语种问答', # subcategory
'--------- 推理 Reasoning ---------', # category '--------- 推理 Reasoning ---------', # category
# '文本蕴含', # subcategory # '文本蕴含', # subcategory
'cmnli', 'cmnli',
'ocnli', 'ocnli',
@ -67,7 +67,7 @@ summarizer = dict(
'mbpp', 'mbpp',
# '综合推理', # subcategory # '综合推理', # subcategory
"bbh", "bbh",
'--------- 理解 Understanding ---------', # category '--------- 理解 Understanding ---------', # category
# '阅读理解', # subcategory # '阅读理解', # subcategory
'C3', 'C3',
'CMRC_dev', 'CMRC_dev',
@ -84,11 +84,20 @@ summarizer = dict(
'eprstmt-dev', 'eprstmt-dev',
'lambada', 'lambada',
'tnews-dev', 'tnews-dev',
'--------- 安全 Safety ---------', # category '--------- 安全 Safety ---------', # category
# '偏见', # subcategory # '偏见', # subcategory
'crows_pairs', 'crows_pairs',
# '有毒性(判别)', # subcategory
'civil_comments',
# '有毒性(判别)多语言', # subcategory
'jigsaw_multilingual',
# '有毒性(生成)', # subcategory
'real-toxicity-prompts',
# '真实性/有用性', # subcategory
'truthful_qa',
], ],
summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []), summary_groups=sum(
[v for k, v in locals().items() if k.endswith("_summary_groups")], []),
prompt_db=dict( prompt_db=dict(
database_path='configs/datasets/log.json', database_path='configs/datasets/log.json',
config_dir='configs/datasets', config_dir='configs/datasets',

View File

@ -147,6 +147,23 @@ class PPLInferencerOutputHandler:
self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl
class CLPInferencerOutputHandler:
results_dict = {}
def __init__(self) -> None:
self.results_dict = {}
def write_to_json(self, save_dir: str, filename: str):
"""Dump the result to a json file."""
dump_results_dict(self.results_dict, Path(save_dir) / filename)
def save_ice(self, ice):
for idx, example in enumerate(ice):
if str(idx) not in self.results_dict.keys():
self.results_dict[str(idx)] = {}
self.results_dict[str(idx)]['in-context examples'] = example
def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices): def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
if str(idx) not in self.results_dict.keys(): if str(idx) not in self.results_dict.keys():
self.results_dict[str(idx)] = {} self.results_dict[str(idx)] = {}

View File

@ -13,7 +13,7 @@ from opencompass.registry import ICL_INFERENCERS
from ..icl_prompt_template import PromptTemplate from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever from ..icl_retriever import BaseRetriever
from ..utils import get_logger from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler from .icl_base_inferencer import BaseInferencer, CLPInferencerOutputHandler
logger = get_logger(__name__) logger = get_logger(__name__)
@ -79,7 +79,7 @@ class CLPInferencer(BaseInferencer):
output_json_filename: Optional[str] = None, output_json_filename: Optional[str] = None,
normalizing_str: Optional[str] = None) -> List: normalizing_str: Optional[str] = None) -> List:
# 1. Preparation for output logs # 1. Preparation for output logs
output_handler = PPLInferencerOutputHandler() output_handler = CLPInferencerOutputHandler()
ice = [] ice = []
@ -88,6 +88,20 @@ class CLPInferencer(BaseInferencer):
if output_json_filename is None: if output_json_filename is None:
output_json_filename = self.output_json_filename output_json_filename = self.output_json_filename
# CLP cannot infer with log probability for api models
# unless model provided such options which needs specific
# implementation, open an issue if you encounter the case.
if self.model.is_api:
# Write empty file in case always rerun for this model
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
err_msg = 'API model is not supported for conditional log '\
'probability inference and skip this exp.'
output_handler.results_dict = {'error': err_msg}
output_handler.write_to_json(output_json_filepath,
output_json_filename)
raise ValueError(err_msg)
# 2. Get results of retrieval process # 2. Get results of retrieval process
if self.fix_id_list: if self.fix_id_list:
ice_idx_list = retriever.retrieve(self.fix_id_list) ice_idx_list = retriever.retrieve(self.fix_id_list)
@ -117,7 +131,7 @@ class CLPInferencer(BaseInferencer):
choice_ids = [self.model.tokenizer.encode(c) for c in choices] choice_ids = [self.model.tokenizer.encode(c) for c in choices]
if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer': # noqa if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer': # noqa
choice_ids = [c[2:] for c in choice_ids] choice_ids = [c[2:] for c in choice_ids]
else: elif hasattr(self.model.tokenizer, 'add_bos_token'):
if self.model.tokenizer.add_bos_token: if self.model.tokenizer.add_bos_token:
choice_ids = [c[1:] for c in choice_ids] choice_ids = [c[1:] for c in choice_ids]
if self.model.tokenizer.add_eos_token: if self.model.tokenizer.add_eos_token:
@ -135,6 +149,7 @@ class CLPInferencer(BaseInferencer):
ice[idx], ice[idx],
ice_template=ice_template, ice_template=ice_template,
prompt_template=prompt_template) prompt_template=prompt_template)
prompt = self.model.parse_template(prompt, mode='ppl')
if self.max_seq_len is not None: if self.max_seq_len is not None:
prompt_token_num = get_token_len(prompt) prompt_token_num = get_token_len(prompt)
# add one because additional token will be added in the end # add one because additional token will be added in the end