Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

[Feat] add safety to collections (#185)

* [Feat] add safety to collections
* minor fix

Parent: f4c70ba6c3
Commit: 5a9539f375
@@ -1,4 +1,4 @@
 from mmengine.config import read_base
 
 with read_base():
-    from .civilcomments_ppl_6a2561 import civilcomments_datasets  # noqa: F401, F403
+    from .civilcomments_clp_a3c5fd import civilcomments_datasets  # noqa: F401, F403

@@ -53,5 +53,9 @@ with read_base():
     from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
     from ..flores.flores_gen_806ede import flores_datasets
     from ..crowspairs.crowspairs_ppl_e811e1 import crowspairs_datasets
+    from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
+    from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
+    from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
+    from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets
 
 datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

@@ -53,5 +53,9 @@ with read_base():
     from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
     from ..flores.flores_gen_806ede import flores_datasets
     from ..crowspairs.crowspairs_gen_21f7cb import crowspairs_datasets
+    from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
+    from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
+    from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
+    from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets
 
 datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

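Both collection files rely on the same aggregation idiom: every `*_datasets` list pulled in under `read_base()` lands in the module namespace, and the final line flattens them into one list. A minimal standalone sketch of the idiom (the two placeholder lists below are illustrative, not real configs):

foo_datasets = [dict(abbr='foo')]
bar_datasets = [dict(abbr='bar')]

# sum(..., []) concatenates every list whose name ends with '_datasets',
# so the newly imported safety datasets are picked up without touching
# this line.
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

assert [d['abbr'] for d in datasets] == ['foo', 'bar']
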
@@ -1,4 +1,4 @@
 from mmengine.config import read_base
 
 with read_base():
-    from .jigsawmultilingual_ppl_fe50d8 import jigsawmultilingual_datasets  # noqa: F401, F403
+    from .jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets  # noqa: F401, F403

@@ -1,4 +1,4 @@
 from mmengine.config import read_base
 
 with read_base():
-    from .realtoxicprompts_gen_ac723c import realtoxicprompts_datasets  # noqa: F401, F403
+    from .realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets  # noqa: F401, F403

@@ -11,9 +11,7 @@ truthfulqa_reader_cfg = dict(
 
 # TODO: allow empty output-column
 truthfulqa_infer_cfg = dict(
-    prompt_template=dict(
-        type=PromptTemplate,
-        template='{question}'),
+    prompt_template=dict(type=PromptTemplate, template='{question}'),
     retriever=dict(type=ZeroRetriever),
     inferencer=dict(type=GenInferencer))
 
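This is a formatting-only collapse of the nested dict; the resulting config is unchanged. A quick standalone check (with a stand-in for the real PromptTemplate class):

PromptTemplate = object  # stand-in; the real class comes from opencompass

before = dict(
    type=PromptTemplate,
    template='{question}')
after = dict(type=PromptTemplate, template='{question}')

assert before == after  # identical config, one line instead of three
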
@@ -31,6 +29,7 @@ truthfulqa_eval_cfg = dict(
 
 truthfulqa_datasets = [
     dict(
+        abbr='truthful_qa',
         type=TruthfulQADataset,
         path='truthful_qa',
         name='generation',

@@ -31,6 +31,7 @@ truthfulqa_eval_cfg = dict(
 
 truthfulqa_datasets = [
     dict(
+        abbr='truthful_qa',
        type=TruthfulQADataset,
         path='truthful_qa',
         name='generation',

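The explicit `abbr` matters because the summarizer (next hunks) selects report rows by dataset abbreviation; without it, OpenCompass derives a name from other config fields. A hypothetical helper illustrating that precedence (the fallback format here is an assumption, not the library's exact rule):

def get_abbr(cfg: dict) -> str:
    # Prefer an explicit 'abbr'; otherwise fall back to a derived name.
    return cfg.get('abbr') or f"{cfg['path']}_{cfg.get('name', '')}"

assert get_abbr(dict(abbr='truthful_qa', path='truthful_qa',
                     name='generation')) == 'truthful_qa'
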
@@ -10,15 +10,15 @@ with read_base():
     from .groups.jigsaw_multilingual import jigsaw_multilingual_summary_groups
 
 summarizer = dict(
-    dataset_abbrs = [
+    dataset_abbrs=[
         '--------- 考试 Exam ---------',  # category
         # 'Mixed',  # subcategory
         "ceval",
         'agieval',
         'mmlu',
         "GaokaoBench",
         'ARC-c',
         '--------- 语言 Language ---------',  # category
         # '字词释义',  # subcategory
         'WiC',
         'summedits',
@@ -33,14 +33,14 @@ summarizer = dict(
         'winogrande',
         # '翻译',  # subcategory
         'flores_100',
         '--------- 知识 Knowledge ---------',  # category
         # '知识问答',  # subcategory
         'BoolQ',
         'commonsense_qa',
         'nq',
         'triviaqa',
         # '多语种问答',  # subcategory
         '--------- 推理 Reasoning ---------',  # category
         # '文本蕴含',  # subcategory
         'cmnli',
         'ocnli',
@@ -67,7 +67,7 @@ summarizer = dict(
         'mbpp',
         # '综合推理',  # subcategory
         "bbh",
         '--------- 理解 Understanding ---------',  # category
         # '阅读理解',  # subcategory
         'C3',
         'CMRC_dev',
@@ -84,11 +84,20 @@ summarizer = dict(
         'eprstmt-dev',
         'lambada',
         'tnews-dev',
         '--------- 安全 Safety ---------',  # category
         # '偏见',  # subcategory
         'crows_pairs',
+        # '有毒性(判别)',  # subcategory
+        'civil_comments',
+        # '有毒性(判别)多语言',  # subcategory
+        'jigsaw_multilingual',
+        # '有毒性(生成)',  # subcategory
+        'real-toxicity-prompts',
+        # '真实性/有用性',  # subcategory
+        'truthful_qa',
     ],
-    summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []),
+    summary_groups=sum(
+        [v for k, v in locals().items() if k.endswith("_summary_groups")], []),
     prompt_db=dict(
         database_path='configs/datasets/log.json',
         config_dir='configs/datasets',

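The `summary_groups` line uses the same locals()-scan idiom as the dataset collections, here over `*_summary_groups` lists such as the newly imported `jigsaw_multilingual_summary_groups`. A minimal sketch, assuming the usual group format of one reported name plus the subset abbrs it aggregates (the subset names below are illustrative):

# Assumed group format: one dict per reported row.
jigsaw_multilingual_summary_groups = [
    dict(name='jigsaw_multilingual',
         subsets=['jigsaw_multilingual_de', 'jigsaw_multilingual_es']),
]

summary_groups = sum(
    [v for k, v in locals().items() if k.endswith('_summary_groups')], [])
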
@@ -147,6 +147,23 @@ class PPLInferencerOutputHandler:
         self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
         self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl
 
+
+class CLPInferencerOutputHandler:
+    results_dict = {}
+
+    def __init__(self) -> None:
+        self.results_dict = {}
+
+    def write_to_json(self, save_dir: str, filename: str):
+        """Dump the result to a json file."""
+        dump_results_dict(self.results_dict, Path(save_dir) / filename)
+
+    def save_ice(self, ice):
+        for idx, example in enumerate(ice):
+            if str(idx) not in self.results_dict.keys():
+                self.results_dict[str(idx)] = {}
+            self.results_dict[str(idx)]['in-context examples'] = example
+
     def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
         if str(idx) not in self.results_dict.keys():
             self.results_dict[str(idx)] = {}

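A short usage sketch of the new handler. The import path is inferred from the relative imports in the next hunk, and the prompt text and `cond_prob` values are made up for illustration:

import os

# Import path inferred from the diff (icl_base_inferencer module).
from opencompass.openicl.icl_inferencer.icl_base_inferencer import \
    CLPInferencerOutputHandler

handler = CLPInferencerOutputHandler()
handler.save_ice(['Q: Is water wet?\nA: yes'])  # one in-context example per row
# Signature from the hunk above: (input, prompt, cond_prob, idx, choices).
handler.save_prompt_and_condprob(
    input='Is this comment toxic?',
    prompt='Q: Is water wet?\nA: yes\nQ: Is this comment toxic?\nA:',
    cond_prob=[0.91, 0.09],
    idx=0,
    choices=['no', 'yes'])
os.makedirs('outputs', exist_ok=True)
handler.write_to_json('outputs', 'civil_comments.json')
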
@@ -13,7 +13,7 @@ from opencompass.registry import ICL_INFERENCERS
 from ..icl_prompt_template import PromptTemplate
 from ..icl_retriever import BaseRetriever
 from ..utils import get_logger
-from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
+from .icl_base_inferencer import BaseInferencer, CLPInferencerOutputHandler
 
 logger = get_logger(__name__)
 
@@ -79,7 +79,7 @@ class CLPInferencer(BaseInferencer):
                   output_json_filename: Optional[str] = None,
                   normalizing_str: Optional[str] = None) -> List:
         # 1. Preparation for output logs
-        output_handler = PPLInferencerOutputHandler()
+        output_handler = CLPInferencerOutputHandler()
 
         ice = []
 
@@ -88,6 +88,20 @@ class CLPInferencer(BaseInferencer):
         if output_json_filename is None:
             output_json_filename = self.output_json_filename
 
+        # CLP cannot infer with log probability for api models
+        # unless model provided such options which needs specific
+        # implementation, open an issue if you encounter the case.
+        if self.model.is_api:
+            # Write empty file in case always rerun for this model
+            if self.is_main_process:
+                os.makedirs(output_json_filepath, exist_ok=True)
+                err_msg = 'API model is not supported for conditional log '\
+                          'probability inference and skip this exp.'
+                output_handler.results_dict = {'error': err_msg}
+                output_handler.write_to_json(output_json_filepath,
+                                             output_json_filename)
+            raise ValueError(err_msg)
+
         # 2. Get results of retrieval process
         if self.fix_id_list:
             ice_idx_list = retriever.retrieve(self.fix_id_list)

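One caveat in this guard as reconstructed here: `err_msg` is assigned inside the `if self.is_main_process:` branch, yet `raise ValueError(err_msg)` runs on every process, so a non-main rank would hit a NameError instead of the intended ValueError. A sketch of the same block with the message hoisted out of the branch (an illustrative rewrite, not the committed code):

err_msg = ('API model is not supported for conditional log '
           'probability inference and skip this exp.')
if self.model.is_api:
    # Write an empty result file so the run is not endlessly retried.
    if self.is_main_process:
        os.makedirs(output_json_filepath, exist_ok=True)
        output_handler.results_dict = {'error': err_msg}
        output_handler.write_to_json(output_json_filepath,
                                     output_json_filename)
    raise ValueError(err_msg)  # err_msg is now defined on every rank
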
@@ -117,7 +131,7 @@ class CLPInferencer(BaseInferencer):
             choice_ids = [self.model.tokenizer.encode(c) for c in choices]
             if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer':  # noqa
                 choice_ids = [c[2:] for c in choice_ids]
-            else:
+            elif hasattr(self.model.tokenizer, 'add_bos_token'):
                 if self.model.tokenizer.add_bos_token:
                     choice_ids = [c[1:] for c in choice_ids]
                 if self.model.tokenizer.add_eos_token:

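The `elif hasattr(...)` fixes a crash for tokenizers that expose neither ChatGLM's interface nor `add_bos_token`: the old `else:` branch read those attributes unconditionally. The slicing itself strips the special tokens that `encode` prepends or appends, so only the answer tokens remain. A standalone sketch of the idea with a toy tokenizer (the BOS id of 1 is made up):

class ToyTokenizer:
    add_bos_token = True
    add_eos_token = False

    def encode(self, text):
        # Pretend id 1 is BOS and each character maps to its codepoint.
        return [1] + [ord(c) for c in text]

tok = ToyTokenizer()
choice_ids = [tok.encode(c) for c in ['no', 'yes']]
if hasattr(tok, 'add_bos_token'):
    if tok.add_bos_token:
        choice_ids = [c[1:] for c in choice_ids]  # drop the leading BOS
assert choice_ids[0][0] == ord('n')
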
@@ -135,6 +149,7 @@ class CLPInferencer(BaseInferencer):
                     ice[idx],
                     ice_template=ice_template,
                     prompt_template=prompt_template)
+                prompt = self.model.parse_template(prompt, mode='ppl')
                 if self.max_seq_len is not None:
                     prompt_token_num = get_token_len(prompt)
                     # add one because additional token will be added in the end

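For context on what the renamed `_clp_` configs compute: conditional log probability scores each answer choice by the model's next-token distribution at the end of the prompt, rather than by full-sequence perplexity as in the `_ppl_` configs. A minimal sketch of that scoring step, assuming we already have the logits for the last prompt position and a single token id per choice:

import math

def clp_scores(last_token_logits, choice_ids):
    # Softmax over the vocabulary, then read off each choice's probability.
    m = max(last_token_logits)
    exps = [math.exp(x - m) for x in last_token_logits]
    total = sum(exps)
    return [exps[i] / total for i in choice_ids]

# Toy vocabulary of 4 tokens; the choices map to token ids 2 ('no')
# and 3 ('yes').
probs = clp_scores([0.1, 0.2, 2.0, 1.0], [2, 3])
prediction = ['no', 'yes'][probs.index(max(probs))]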