[Feat] add safety to collections (#185)

* [Feat] add safety to collections

* minor fix
This commit is contained in:
Hubert 2023-08-11 11:19:26 +08:00 committed by GitHub
parent f4c70ba6c3
commit 5a9539f375
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 66 additions and 17 deletions

View File

@ -1,4 +1,4 @@
from mmengine.config import read_base
with read_base():
from .civilcomments_ppl_6a2561 import civilcomments_datasets # noqa: F401, F403
from .civilcomments_clp_a3c5fd import civilcomments_datasets # noqa: F401, F403

View File

@ -53,5 +53,9 @@ with read_base():
from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from ..flores.flores_gen_806ede import flores_datasets
from ..crowspairs.crowspairs_ppl_e811e1 import crowspairs_datasets
from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

View File

@ -53,5 +53,9 @@ with read_base():
from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from ..flores.flores_gen_806ede import flores_datasets
from ..crowspairs.crowspairs_gen_21f7cb import crowspairs_datasets
from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

View File

@ -1,4 +1,4 @@
from mmengine.config import read_base
with read_base():
from .jigsawmultilingual_ppl_fe50d8 import jigsawmultilingual_datasets # noqa: F401, F403
from .jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets # noqa: F401, F403

View File

@ -1,4 +1,4 @@
from mmengine.config import read_base
with read_base():
from .realtoxicprompts_gen_ac723c import realtoxicprompts_datasets # noqa: F401, F403
from .realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets # noqa: F401, F403

View File

@ -11,9 +11,7 @@ truthfulqa_reader_cfg = dict(
# TODO: allow empty output-column
truthfulqa_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template='{question}'),
prompt_template=dict(type=PromptTemplate, template='{question}'),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer))
@ -31,6 +29,7 @@ truthfulqa_eval_cfg = dict(
truthfulqa_datasets = [
dict(
abbr='truthful_qa',
type=TruthfulQADataset,
path='truthful_qa',
name='generation',

View File

@ -31,6 +31,7 @@ truthfulqa_eval_cfg = dict(
truthfulqa_datasets = [
dict(
abbr='truthful_qa',
type=TruthfulQADataset,
path='truthful_qa',
name='generation',

View File

@ -10,7 +10,7 @@ with read_base():
from .groups.jigsaw_multilingual import jigsaw_multilingual_summary_groups
summarizer = dict(
dataset_abbrs = [
dataset_abbrs=[
'--------- 考试 Exam ---------', # category
# 'Mixed', # subcategory
"ceval",
@ -87,8 +87,17 @@ summarizer = dict(
'--------- 安全 Safety ---------', # category
# '偏见', # subcategory
'crows_pairs',
# '有毒性(判别)', # subcategory
'civil_comments',
# '有毒性(判别)多语言', # subcategory
'jigsaw_multilingual',
# '有毒性(生成)', # subcategory
'real-toxicity-prompts',
# '真实性/有用性', # subcategory
'truthful_qa',
],
summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []),
summary_groups=sum(
[v for k, v in locals().items() if k.endswith("_summary_groups")], []),
prompt_db=dict(
database_path='configs/datasets/log.json',
config_dir='configs/datasets',

View File

@ -147,6 +147,23 @@ class PPLInferencerOutputHandler:
self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl
class CLPInferencerOutputHandler:
results_dict = {}
def __init__(self) -> None:
self.results_dict = {}
def write_to_json(self, save_dir: str, filename: str):
"""Dump the result to a json file."""
dump_results_dict(self.results_dict, Path(save_dir) / filename)
def save_ice(self, ice):
for idx, example in enumerate(ice):
if str(idx) not in self.results_dict.keys():
self.results_dict[str(idx)] = {}
self.results_dict[str(idx)]['in-context examples'] = example
def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
if str(idx) not in self.results_dict.keys():
self.results_dict[str(idx)] = {}

View File

@ -13,7 +13,7 @@ from opencompass.registry import ICL_INFERENCERS
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
from .icl_base_inferencer import BaseInferencer, CLPInferencerOutputHandler
logger = get_logger(__name__)
@ -79,7 +79,7 @@ class CLPInferencer(BaseInferencer):
output_json_filename: Optional[str] = None,
normalizing_str: Optional[str] = None) -> List:
# 1. Preparation for output logs
output_handler = PPLInferencerOutputHandler()
output_handler = CLPInferencerOutputHandler()
ice = []
@ -88,6 +88,20 @@ class CLPInferencer(BaseInferencer):
if output_json_filename is None:
output_json_filename = self.output_json_filename
# CLP cannot infer with log probability for api models
# unless model provided such options which needs specific
# implementation, open an issue if you encounter the case.
if self.model.is_api:
# Write empty file in case always rerun for this model
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
err_msg = 'API model is not supported for conditional log '\
'probability inference and skip this exp.'
output_handler.results_dict = {'error': err_msg}
output_handler.write_to_json(output_json_filepath,
output_json_filename)
raise ValueError(err_msg)
# 2. Get results of retrieval process
if self.fix_id_list:
ice_idx_list = retriever.retrieve(self.fix_id_list)
@ -117,7 +131,7 @@ class CLPInferencer(BaseInferencer):
choice_ids = [self.model.tokenizer.encode(c) for c in choices]
if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer': # noqa
choice_ids = [c[2:] for c in choice_ids]
else:
elif hasattr(self.model.tokenizer, 'add_bos_token'):
if self.model.tokenizer.add_bos_token:
choice_ids = [c[1:] for c in choice_ids]
if self.model.tokenizer.add_eos_token:
@ -135,6 +149,7 @@ class CLPInferencer(BaseInferencer):
ice[idx],
ice_template=ice_template,
prompt_template=prompt_template)
prompt = self.model.parse_template(prompt, mode='ppl')
if self.max_seq_len is not None:
prompt_token_num = get_token_len(prompt)
# add one because additional token will be added in the end