Mirror of https://github.com/open-compass/opencompass.git
Synced 2025-05-30 16:03:24 +08:00

[Feat] add safety to collections (#185)

* [Feat] add safety to collections
* minor fix

This commit is contained in:
parent f4c70ba6c3
commit 5a9539f375
CivilComments default dataset config, switched from the ppl variant to the new clp (conditional-log-probability) variant:

@@ -1,4 +1,4 @@
 from mmengine.config import read_base

 with read_base():
-    from .civilcomments_ppl_6a2561 import civilcomments_datasets  # noqa: F401, F403
+    from .civilcomments_clp_a3c5fd import civilcomments_datasets  # noqa: F401, F403
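The ppl-to-clp switch changes how choices are scored: a PPL inferencer ranks candidate answers by the perplexity of the whole prompt-plus-answer sequence, while the CLP inferencer introduced later in this commit compares only the conditional log probability of the answer tokens given the prompt. A rough sketch of the difference (random tensors stand in for a real language model, and the usual one-position shift between logits and target tokens is elided):

    import torch

    torch.manual_seed(0)
    vocab, prompt_len, ans_len = 100, 8, 3

    # Stand-ins for a language model's outputs on prompt + answer.
    logits = torch.randn(prompt_len + ans_len, vocab)
    tokens = torch.randint(vocab, (prompt_len + ans_len,))

    # Log probability of each actual token at its position.
    token_lps = torch.log_softmax(logits, dim=-1)[
        torch.arange(prompt_len + ans_len), tokens]

    # PPL scoring: perplexity of the whole sequence (lower is better).
    ppl = torch.exp(-token_lps.mean())

    # CLP scoring: conditional log probability of only the answer
    # tokens, given the prompt (higher is better).
    clp = token_lps[-ans_len:].sum()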
A dataset collection config, with the four safety datasets appended:

@@ -53,5 +53,9 @@ with read_base():
     from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
     from ..flores.flores_gen_806ede import flores_datasets
     from ..crowspairs.crowspairs_ppl_e811e1 import crowspairs_datasets
+    from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
+    from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
+    from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
+    from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets

 datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
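The closing line is the collection configs' aggregation idiom: every module-level variable ending in `_datasets` is a list, and `sum(..., [])` concatenates them all. A self-contained sketch of the same pattern:

    # Two per-dataset lists, as the imports above would create.
    foo_datasets = [dict(abbr='foo')]
    bar_datasets = [dict(abbr='bar'), dict(abbr='baz')]

    # Scan the module namespace and flatten all *_datasets lists into one.
    datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

    print([d['abbr'] for d in datasets])  # ['foo', 'bar', 'baz']

Because the module namespace preserves definition order (Python 3.7+), the merged list keeps the import order above.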
The same addition in a second collection config, which uses the gen variant of crowspairs:

@@ -53,5 +53,9 @@ with read_base():
     from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
     from ..flores.flores_gen_806ede import flores_datasets
     from ..crowspairs.crowspairs_gen_21f7cb import crowspairs_datasets
+    from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
+    from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
+    from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
+    from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets

 datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
JigsawMultilingual default config, likewise switched to the clp variant:

@@ -1,4 +1,4 @@
 from mmengine.config import read_base

 with read_base():
-    from .jigsawmultilingual_ppl_fe50d8 import jigsawmultilingual_datasets  # noqa: F401, F403
+    from .jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets  # noqa: F401, F403
RealToxicPrompts default config, pointed at a different gen variant:

@@ -1,4 +1,4 @@
 from mmengine.config import read_base

 with read_base():
-    from .realtoxicprompts_gen_ac723c import realtoxicprompts_datasets  # noqa: F401, F403
+    from .realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets  # noqa: F401, F403
TruthfulQA inference config, with the prompt template collapsed onto one line:

@@ -11,9 +11,7 @@ truthfulqa_reader_cfg = dict(

 # TODO: allow empty output-column
 truthfulqa_infer_cfg = dict(
-    prompt_template=dict(
-        type=PromptTemplate,
-        template='{question}'),
+    prompt_template=dict(type=PromptTemplate, template='{question}'),
     retriever=dict(type=ZeroRetriever),
     inferencer=dict(type=GenInferencer))
One TruthfulQA dataset config gains an explicit abbr, matching the 'truthful_qa' entry added to the summarizer below:

@@ -31,6 +29,7 @@ truthfulqa_eval_cfg = dict(

 truthfulqa_datasets = [
     dict(
+        abbr='truthful_qa',
         type=TruthfulQADataset,
         path='truthful_qa',
         name='generation',
The same abbr is added in the second TruthfulQA config:

@@ -31,6 +31,7 @@ truthfulqa_eval_cfg = dict(

 truthfulqa_datasets = [
     dict(
+        abbr='truthful_qa',
         type=TruthfulQADataset,
         path='truthful_qa',
         name='generation',
Summarizer config, importing the jigsaw_multilingual summary groups and fixing the keyword-argument spacing per PEP 8:

@@ -10,7 +10,7 @@ with read_base():
     from .groups.jigsaw_multilingual import jigsaw_multilingual_summary_groups

 summarizer = dict(
-    dataset_abbrs = [
+    dataset_abbrs=[
         '--------- 考试 Exam ---------',  # category
         # 'Mixed',  # subcategory
         "ceval",
The safety category and its five dataset abbrs are appended, and the long summary_groups line is wrapped:

@@ -87,8 +87,17 @@ summarizer = dict(
         '--------- 安全 Safety ---------',  # category
+        # '偏见',  # subcategory
+        'crows_pairs',
+        # '有毒性(判别)',  # subcategory
+        'civil_comments',
+        # '有毒性(判别)多语言',  # subcategory
+        'jigsaw_multilingual',
+        # '有毒性(生成)',  # subcategory
+        'real-toxicity-prompts',
+        # '真实性/有用性',  # subcategory
+        'truthful_qa',
     ],
-    summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []),
+    summary_groups=sum(
+        [v for k, v in locals().items() if k.endswith("_summary_groups")], []),
     prompt_db=dict(
         database_path='configs/datasets/log.json',
         config_dir='configs/datasets',
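For non-Chinese readers, the commented subcategory labels read, in order: bias (偏见), toxicity/discriminative (有毒性(判别)), toxicity/discriminative/multilingual, toxicity/generative (有毒性(生成)), and truthfulness/helpfulness (真实性/有用性). `summary_groups` is gathered with the same `locals()` scan used for `datasets` above; each imported `*_summary_groups` variable is a list of group dicts telling the summarizer how to average per-subset scores into one reported row. A hedged sketch of the expected shape (field names follow the opencompass summary-group convention; the subset names here are invented for illustration):

    # Hypothetical group entry; treat the exact schema as an assumption.
    jigsaw_multilingual_summary_groups = [
        dict(
            name='jigsaw_multilingual',          # abbr the summarizer reports
            subsets=['jigsaw_multilingual_es',   # per-language subsets whose
                     'jigsaw_multilingual_fr'],  # scores are averaged together
        ),
    ]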
The base inferencer module gains a dedicated output handler for CLP, alongside the existing PPL one:

@@ -147,6 +147,23 @@ class PPLInferencerOutputHandler:
         self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
         self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl


+class CLPInferencerOutputHandler:
+    results_dict = {}
+
+    def __init__(self) -> None:
+        self.results_dict = {}
+
+    def write_to_json(self, save_dir: str, filename: str):
+        """Dump the result to a json file."""
+        dump_results_dict(self.results_dict, Path(save_dir) / filename)
+
+    def save_ice(self, ice):
+        for idx, example in enumerate(ice):
+            if str(idx) not in self.results_dict.keys():
+                self.results_dict[str(idx)] = {}
+            self.results_dict[str(idx)]['in-context examples'] = example
+
+    def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
+        if str(idx) not in self.results_dict.keys():
+            self.results_dict[str(idx)] = {}
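The hunk cuts off inside save_prompt_and_condprob, but the visible API is enough for a usage sketch (all values below are placeholders, and the class is assumed importable with the module's own dump_results_dict and Path imports):

    handler = CLPInferencerOutputHandler()

    # One in-context-example string per test index.
    handler.save_ice(['Q: ... A: ...', 'Q: ... A: ...'])

    # Record the prompt and per-choice conditional probabilities for index 0.
    handler.save_prompt_and_condprob(
        input='Is this comment toxic?',
        prompt='Q: ... A: ... Is this comment toxic?',
        cond_prob=[0.91, 0.09],
        idx=0,
        choices=['yes', 'no'])

    # Dump everything under outputs/predictions/civilcomments.json.
    handler.write_to_json('outputs/predictions', 'civilcomments.json')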
The CLP inferencer now imports the new handler:

@@ -13,7 +13,7 @@ from opencompass.registry import ICL_INFERENCERS
 from ..icl_prompt_template import PromptTemplate
 from ..icl_retriever import BaseRetriever
 from ..utils import get_logger
-from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
+from .icl_base_inferencer import BaseInferencer, CLPInferencerOutputHandler

 logger = get_logger(__name__)
...and instantiates it in place of the PPL handler:

@@ -79,7 +79,7 @@ class CLPInferencer(BaseInferencer):
                   output_json_filename: Optional[str] = None,
                   normalizing_str: Optional[str] = None) -> List:
         # 1. Preparation for output logs
-        output_handler = PPLInferencerOutputHandler()
+        output_handler = CLPInferencerOutputHandler()

         ice = []
API models are rejected early, since they cannot serve conditional log probabilities; note that err_msg is assigned before the main-process guard so the raise works on every rank:

@@ -88,6 +88,20 @@ class CLPInferencer(BaseInferencer):
         if output_json_filename is None:
             output_json_filename = self.output_json_filename

+        # CLP cannot infer with log probability for api models
+        # unless the model provides such options, which needs a specific
+        # implementation; open an issue if you encounter the case.
+        if self.model.is_api:
+            err_msg = ('API model is not supported for conditional log '
+                       'probability inference; skipping this experiment.')
+            # Write the error out so later runs skip this model instead
+            # of rerunning it every time.
+            if self.is_main_process:
+                os.makedirs(output_json_filepath, exist_ok=True)
+                output_handler.results_dict = {'error': err_msg}
+                output_handler.write_to_json(output_json_filepath,
+                                             output_json_filename)
+            raise ValueError(err_msg)
+
         # 2. Get results of retrieval process
         if self.fix_id_list:
             ice_idx_list = retriever.retrieve(self.fix_id_list)
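The guard exists because CLP needs the full next-token distribution, which local models expose as logits but typical completion APIs do not return. A minimal sketch of the scoring step that requires those logits (the function and its names are mine, not the OpenCompass API):

    import torch

    def pick_choice(next_token_logits: torch.Tensor,
                    choice_ids: list[int]) -> int:
        """Return the index of the choice whose token is most probable
        at the position following the prompt.

        next_token_logits: full-vocabulary logits, shape [vocab_size].
        choice_ids: one token id per candidate choice.
        """
        log_probs = torch.log_softmax(next_token_logits, dim=-1)
        return max(range(len(choice_ids)),
                   key=lambda i: log_probs[choice_ids[i]].item())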
Tokenizers without an add_bos_token attribute no longer hit an AttributeError in the unconditional else branch:

@@ -117,7 +131,7 @@ class CLPInferencer(BaseInferencer):
             choice_ids = [self.model.tokenizer.encode(c) for c in choices]
             if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer':  # noqa
                 choice_ids = [c[2:] for c in choice_ids]
-            else:
+            elif hasattr(self.model.tokenizer, 'add_bos_token'):
                 if self.model.tokenizer.add_bos_token:
                     choice_ids = [c[1:] for c in choice_ids]
                 if self.model.tokenizer.add_eos_token:
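The slicing compensates for tokenizers that wrap encode() output in special tokens: ChatGLM's tokenizer evidently prepends two (hence the c[2:]), while Hugging Face tokenizers with add_bos_token/add_eos_token add one at either end. A hedged illustration with the transformers library (the checkpoint name is only an example):

    from transformers import AutoTokenizer

    # Illustrative checkpoint; any tokenizer with add_bos_token behaves similarly.
    tok = AutoTokenizer.from_pretrained('huggyllama/llama-7b')

    ids = tok.encode('yes')          # e.g. [bos_id, yes_id] with add_bos_token=True
    if getattr(tok, 'add_bos_token', False):
        ids = ids[1:]                # strip BOS: keep only the choice's own tokens
    if getattr(tok, 'add_eos_token', False):
        ids = ids[:-1]               # strip EOS symmetrically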
Prompts are now passed through the model's template parser before the length check:

@@ -135,6 +149,7 @@ class CLPInferencer(BaseInferencer):
                     ice[idx],
                     ice_template=ice_template,
                     prompt_template=prompt_template)
+                prompt = self.model.parse_template(prompt, mode='ppl')
                 if self.max_seq_len is not None:
                     prompt_token_num = get_token_len(prompt)
                     # add one because additional token will be added in the end
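The trailing comment explains the + 1 that follows in the truncation check: one slot must stay free for the choice token appended at scoring time. Restated as a standalone predicate (the function name is mine):

    from typing import Optional

    def prompt_fits(prompt_token_num: int, max_seq_len: Optional[int]) -> bool:
        """Mirror of the loop's length check: reserve one position for the
        single choice token that will be appended to the prompt."""
        return max_seq_len is None or prompt_token_num + 1 <= max_seq_len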