[Feat] Implementation to support PromptBench (#239)

* [Feat] support adv_glue dataset for adversarial robustness

* reorg files

* minor fix

* minor fix

* support prompt bench demo

* minor fix

* minor fix

* minor fix

* minor fix

* minor fix

* minor fix

* minor fix

* minor fix
Hubert 2023-09-15 15:06:53 +08:00 committed by GitHub
parent de8a154795
commit a11cb45c83
17 changed files with 966 additions and 0 deletions


@@ -0,0 +1,57 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import BM25Retriever
from opencompass.openicl.icl_inferencer import AttackInferencer
from opencompass.openicl.icl_evaluator import BleuEvaluator
from opencompass.datasets import IWSLT2017Dataset
from opencompass.utils.text_postprocessors import general_cn_postprocess
iwslt2017_reader_cfg = dict(
input_columns='en', output_column='de', train_split='validation')
original_prompt_list = [
"Translate the provided sentence from English to German while maintaining the original meaning and context:",
"Convert the following sentence from its original English language to the target language German:",
"Given the sentence below, perform a machine translation from English to German:",
"Translate the subsequent sentence from its source language English into the desired language German:",
"Accurately translate the sentence from English to German, ensuring the meaning remains intact:",
"Please perform a translation of the given sentence, converting it from English to German:",
"Translate the following text from the source language English to the target language German:",
"Using machine translation, convert the given sentence from English into the German language:",
"Translate the subsequent text passage from its original English language to the German language:",
"Perform a machine translation for the provided sentence, changing it from English to German:",
]
iwslt2017_infer_cfg = dict(
ice_template=dict(
type=PromptTemplate,
template=dict(
begin='</E>',
round=[
dict(role='HUMAN', prompt='{adv_prompt}\n{en}'),
dict(role='BOT', prompt='{de}'),
]
),
ice_token='</E>'),
retriever=dict(type=BM25Retriever, ice_num=1),
inferencer=dict(
type=AttackInferencer,
original_prompt_list=original_prompt_list,
adv_key='adv_prompt',
metric_key='score'))
iwslt2017_eval_cfg = dict(
evaluator=dict(type=BleuEvaluator),
pred_role='BOT',
pred_postprocessor=dict(type=general_cn_postprocess),
dataset_postprocessor=dict(type=general_cn_postprocess))
iwslt2017_datasets = [
dict(
abbr='iwslt',
type=IWSLT2017Dataset,
path='iwslt2017',
name='iwslt2017-en-de',
reader_cfg=iwslt2017_reader_cfg,
infer_cfg=iwslt2017_infer_cfg,
eval_cfg=iwslt2017_eval_cfg)
]


@@ -0,0 +1,44 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import AttackInferencer
from opencompass.datasets import MATHDataset, MATHEvaluator, math_postprocess
math_reader_cfg = dict(input_columns=['problem'], output_column='solution')
original_prompt_list = [
"Solve the following math question about",
"Determine the solution to this mathematical problem related to",
"Calculate the answer to the following math query about",
"Find the solution for this mathematical challenge with",
"Compute the result of this math task concerning",
"Resolve the following mathematical question associated with",
"Work out the answer to this math problem featuring",
"Figure out the solution for the following mathematical task with",
"Obtain the result for this math question regarding",
"Evaluate the following mathematical problem that includes",
]
math_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role="HUMAN",
prompt="{adv_prompt} {problem}:"),
]),
),
retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=AttackInferencer, original_prompt_list=original_prompt_list, max_out_len=512, adv_key='adv_prompt'))
math_eval_cfg = dict(
evaluator=dict(type=MATHEvaluator), pred_postprocessor=dict(type=math_postprocess))
math_datasets = [
dict(
type=MATHDataset,
abbr='math',
path='./data/math/math.json',
reader_cfg=math_reader_cfg,
infer_cfg=math_infer_cfg,
eval_cfg=math_eval_cfg)
]


@@ -0,0 +1,48 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import AttackInferencer
from opencompass.datasets import SQuAD20Dataset, SQuAD20Evaluator
squad20_reader_cfg = dict(
input_columns=['context', 'question'],
output_column='answers')
original_prompt_list = [
"Based on the given context, provide the best possible answer. If there's no answer available in the context, respond with 'unanswerable'.",
"Identify the most relevant answer from the context. If it's not possible to find an answer, respond with 'unanswerable'.",
"Find the correct answer in the context provided. If an answer cannot be found, please respond with 'unanswerable'.",
"Please extract the most appropriate answer from the context. If an answer is not present, indicate 'unanswerable'.",
"Using the context, determine the most suitable answer. If the context doesn't contain the answer, respond with 'unanswerable'.",
"Locate the most accurate answer within the context. If the context doesn't provide an answer, respond with 'unanswerable'.",
"Please derive the most fitting answer from the context. If there isn't an answer in the context, respond with 'unanswerable'.",
"Discover the best answer based on the context. If the context doesn't include an answer, respond with 'unanswerable'.",
"From the context, provide the most precise answer. If the answer is not in the context, respond with 'unanswerable'.",
"Search the context for the most relevant answer. If the answer cannot be found, respond with 'unanswerable'.",
]
squad20_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{adv_prompt} {context}'),
dict(role='BOT', prompt='Answer:'),
], )),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=AttackInferencer, max_out_len=50,
original_prompt_list=original_prompt_list,
adv_key='adv_prompt',
metric_key='score'))
squad20_eval_cfg = dict(
evaluator=dict(type=SQuAD20Evaluator), pred_role='BOT')
squad20_datasets = [
dict(
type=SQuAD20Dataset,
abbr='squad_v2',
path='./data/SQuAD2.0/dev-v2.0.json',
reader_cfg=squad20_reader_cfg,
infer_cfg=squad20_infer_cfg,
eval_cfg=squad20_eval_cfg)
]


@@ -0,0 +1,61 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import AttackInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import wnliDataset
from opencompass.utils.text_postprocessors import first_option_postprocess
wnli_reader_cfg = dict(
input_columns=['sentence1', 'sentence2'],
output_column='label_option',
train_split='validation',
test_split='validation')
original_prompt_list = [
'Are the following two sentences entailment or not_entailment? Answer me with "A. entailment" or "B. not_entailment", just one word. ',
"Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'.",
"Examine the pair of sentences and determine if they exhibit entailment or not_entailment. Answer with either 'A. entailment' or 'B. not_entailment'.",
"Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'.",
"Analyze the two provided sentences and decide if their relationship is 'A. entailment' or 'B. not_entailment'.",
"Identify whether the given pair of sentences demonstrates entailment or not_entailment. Answer with 'A. entailment' or 'B. not_entailment'.",
"Review the two statements and categorize their relationship as either 'A. entailment' or 'B. not_entailment'.",
"Please classify the relationship between the provided sentences as 'A. entailment' or 'B. not_entailment'.",
"Indicate if the connection between the following sentences is 'A. entailment' or 'B. not_entailment'.",
"Determine if the given pair of sentences displays entailment or not_entailment. Respond with 'A. entailment' or 'B. not_entailment'.",
"Considering the two sentences, identify if their relationship is 'A. entailment' or 'B. not_entailment'.",
]
wnli_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role="HUMAN",
prompt="""{adv_prompt}
Sentence 1: {sentence1}
Sentence 2: {sentence2}
Answer:"""),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(
type=AttackInferencer,
original_prompt_list=original_prompt_list,
adv_key='adv_prompt'))
wnli_eval_cfg = dict(
evaluator=dict(type=AccEvaluator),
pred_role="BOT",
pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
)
wnli_datasets = [
dict(
abbr='wnli',
type=wnliDataset,
path='glue',
name='wnli',
reader_cfg=wnli_reader_cfg,
infer_cfg=wnli_infer_cfg,
eval_cfg=wnli_eval_cfg)
]

configs/eval_attack.py

@@ -0,0 +1,27 @@
from mmengine.config import read_base
from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLAttackTask
with read_base():
# choose a list of datasets
from .datasets.promptbench.promptbench_wnli_gen_50662f import wnli_datasets
from .models.hf_vicuna_7b import models
datasets = wnli_datasets
# Please run the whole dataset at a time, i.e. use `NaivePartitioner` only
# Please use `OpenICLAttackTask` if you want to perform an attack experiment
infer = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=LocalRunner,
max_num_workers=8,
task=dict(type=OpenICLAttackTask)),
)
attack = dict(
attack='textfooler',
query_budget=100,
prompt_topk=1,
)
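# To launch this attack experiment (inference stage only), run:
#   python run.py configs/eval_attack.py --mode infer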


@@ -0,0 +1,108 @@
# Prompt Attack
We support prompt attacks following the idea of [PromptBench](https://github.com/microsoft/promptbench). The main purpose is to evaluate the robustness of prompt instructions: when the instruction prompt is attacked or modified, how much of the original task performance is retained.
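Concretely, every instruction prompt is scored on the task before and after the attack, and the accuracy drop between the two runs is the robustness signal reported at the end of this guide. A minimal sketch of that bookkeeping (plain Python; the variable names are only illustrative):

```python
# Robustness bookkeeping for a single instruction prompt.
# The example values follow the wnli output shown at the end of this guide.
original_acc = 0.5915   # task accuracy with the original instruction
attacked_acc = 0.4085   # task accuracy after the adversarial rewrite
dropped_acc = original_acc - attacked_acc  # larger drop = less robust prompt

print(f'Original acc: {original_acc:.2%}, '
      f'attacked acc: {attacked_acc:.2%}, '
      f'dropped acc: {dropped_acc:.2%}')
```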
## Set up environment
Some extra components are required for the prompt attack experiment, so we need to set up the environment first:
```shell
git clone https://github.com/microsoft/promptbench.git
pip install textattack==0.3.8
export PYTHONPATH=$PYTHONPATH:promptbench/
```
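If the setup succeeded, the PromptBench components that `OpenICLAttackTask` imports at runtime should be importable. An optional sanity check, assuming the `promptbench/` clone above is on `PYTHONPATH`:

```python
# Optional sanity check: these are the modules the attack task imports at runtime.
import textattack                                    # installed via pip above
from config import LABEL_SET                         # PromptBench's top-level config module
from prompt_attack.attack import create_attack
from prompt_attack.goal_function import PromptGoalFunction

print('PromptBench attack components found')
```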
## How to attack
### Add a dataset config
We will use the GLUE-wnli dataset as an example; for most configuration settings, refer to [config.md](../user_guides/config.md).
First, we need a basic dataset config. You can find existing config files in `configs`, or add your own following [new-dataset](./new_dataset.md).
Take the following `infer_cfg` as an example. We need to define the prompt template, where `adv_prompt` is the placeholder for the instruction to be attacked, and `sentence1` and `sentence2` are the input columns of this dataset. The attack will only modify the `adv_prompt` field.
Then, we use `AttackInferencer` with `original_prompt_list` and `adv_key` to tell the inferencer where to attack and which text to attack.
For more details, refer to the `configs/datasets/promptbench/promptbench_wnli_gen_50662f.py` config file.
```python
original_prompt_list = [
'Are the following two sentences entailment or not_entailment? Answer me with "A. entailment" or "B. not_entailment", just one word. ',
"Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'.",
...,
]
wnli_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role="HUMAN",
prompt="""{adv_prompt}
Sentence 1: {sentence1}
Sentence 2: {sentence2}
Answer:"""),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(
type=AttackInferencer,
original_prompt_list=original_prompt_list,
adv_key='adv_prompt'))
```
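For reference, the attack loop scores each candidate instruction by running the whole dataset through the inferencer. A rough sketch of that call, using the `AttackInferencer.predict` interface added in this PR (the prompt string is simply one entry from `original_prompt_list`):

```python
# Score one candidate instruction by running the full dataset once.
# `predict` returns the dataset-level metric (`metric_key`, accuracy by default),
# rescaled to the 0-1 range when the evaluator reports percentages.
acc = inferencer.predict(
    "Assess the connection between the following sentences and classify it "
    "as 'A. entailment' or 'B. not_entailment'.")
print(f'acc: {acc * 100:.2f}%')
```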
### Add an eval config
We should use `OpenICLAttackTask` here for the attack task. `NaivePartitioner` should also be used, because the attack experiment runs the whole dataset repeatedly, nearly hundreds of times, to search for the best attack; for convenience we do not want to split the dataset.
```note
Please choose a small dataset (fewer than 1000 examples) for the attack; because of the repeated search mentioned above, the time cost would otherwise be enormous.
```
There are several other options in the `attack` config:
- `attack`: attack type; available options include `textfooler`, `textbugger`, `deepwordbug`, `bertattack`, `checklist`, `stresstest`;
- `query_budget`: upper bound on the number of queries, i.e., the total number of times the dataset is run;
- `prompt_topk`: number of top-k prompts to attack. In most cases the original prompt list contains more than 10 prompts, and attacking the whole list is time-consuming.
```python
# Please run the whole dataset at a time, i.e. use `NaivePartitioner` only
# Please use `OpenICLAttackTask` if you want to perform an attack experiment
infer = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=SlurmRunner,
max_num_workers=8,
task=dict(type=OpenICLAttackTask),
retry=0),
)
attack = dict(
attack='textfooler',
query_budget=100,
prompt_topk=2,
)
```
### Run the experiment
Please use `--mode infer` when running the attack experiment, and make sure the `PYTHONPATH` environment variable is set.
```shell
python run.py configs/eval_attack.py --mode infer
```
All results will be saved in the `attack` folder.
The output contains the accuracy of each original prompt and, for the top-k prompts, the attacked prompt together with the accuracy drop, for instance:
```
Prompt: Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'., acc: 59.15%
Prompt: Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'., acc: 57.75%
Prompt: Analyze the two provided sentences and decide if their relationship is 'A. entailment' or 'B. not_entailment'., acc: 56.34%
Prompt: Identify whether the given pair of sentences demonstrates entailment or not_entailment. Answer with 'A. entailment' or 'B. not_entailment'., acc: 54.93%
...
Original prompt: Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'.
Attacked prompt: b"Assess the attach between the following sentences and sorted it as 'A. entailment' or 'B. not_entailment'."
Original acc: 59.15%, attacked acc: 40.85%, dropped acc: 18.31%
```


@@ -60,6 +60,7 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
advanced_guides/new_model.md
advanced_guides/evaluation_turbomind.md
advanced_guides/code_eval_service.md
advanced_guides/prompt_attack.md
.. _Tools:
.. toctree::


@@ -0,0 +1,108 @@
# Prompt Attack
OpenCompass supports prompt attacks following the idea of [PromptBench](https://github.com/microsoft/promptbench). The main purpose is to evaluate the robustness of prompt instructions: when the instruction prompt is attacked or modified, we want the task to perform as close to the original as possible.
## Set up environment
The prompt attack depends on components from `PromptBench`, so the environment needs to be set up first:
```shell
git clone https://github.com/microsoft/promptbench.git
pip install textattack==0.3.8
export PYTHONPATH=$PYTHONPATH:promptbench/
```
## How to attack
### Add a dataset config
We will use the GLUE-wnli dataset as an example; for most configuration settings, refer to [config.md](../user_guides/config.md).
First, we need a basic dataset config. You can find existing config files in `configs`, or add your own following [new-dataset](./new_dataset.md).
Take the following `infer_cfg` as an example. We need to define the prompt template, where `adv_prompt` is the placeholder for the instruction to be attacked, and `sentence1` and `sentence2` are the input columns of this dataset. The attack will only modify the `adv_prompt` field.
Then, we use `AttackInferencer` with `original_prompt_list` and `adv_key` to tell the inferencer where to attack and which text to attack.
For more details, refer to the `configs/datasets/promptbench/promptbench_wnli_gen_50662f.py` config file.
```python
original_prompt_list = [
'Are the following two sentences entailment or not_entailment? Answer me with "A. entailment" or "B. not_entailment", just one word. ',
"Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'.",
...,
]
wnli_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role="HUMAN",
prompt="""{adv_prompt}
Sentence 1: {sentence1}
Sentence 2: {sentence2}
Answer:"""),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(
type=AttackInferencer,
original_prompt_list=original_prompt_list,
adv_key='adv_prompt'))
```
### Add an eval config
We should use `OpenICLAttackTask` here for the attack task. `NaivePartitioner` should also be used, because the attack experiment runs the whole dataset repeatedly, nearly hundreds of times, to search for the best attack; for convenience we do not want to split the dataset.
```note
Because of the repeated search mentioned above, please choose a small dataset (fewer than 1000 examples) for the attack; otherwise the time cost would be enormous.
```
There are several other options in the `attack` config:
- `attack`: attack type; available options include `textfooler`, `textbugger`, `deepwordbug`, `bertattack`, `checklist`, `stresstest`;
- `query_budget`: upper bound on the number of queries, i.e., the total number of times the dataset is run;
- `prompt_topk`: number of top-k prompts to attack. In most cases the original prompt list contains more than 10 prompts, and attacking the whole list is time-consuming.
```python
# Please run the whole dataset at a time, i.e. use `NaivePartitioner` only
# Please use `OpenICLAttackTask` if you want to perform an attack experiment
infer = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=SlurmRunner,
max_num_workers=8,
task=dict(type=OpenICLAttackTask),
retry=0),
)
attack = dict(
attack='textfooler',
query_budget=100,
prompt_topk=2,
)
```
### Run the experiment
Please use the `--mode infer` option when running the attack experiment, and make sure `PYTHONPATH` is set.
```shell
python run.py configs/eval_attack.py --mode infer
```
All results will be saved in the `attack` folder.
The output contains the accuracy of each original prompt and, for the top-k prompts, the attacked prompt together with the accuracy drop, for instance:
```
Prompt: Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'., acc: 59.15%
Prompt: Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'., acc: 57.75%
Prompt: Analyze the two provided sentences and decide if their relationship is 'A. entailment' or 'B. not_entailment'., acc: 56.34%
Prompt: Identify whether the given pair of sentences demonstrates entailment or not_entailment. Answer with 'A. entailment' or 'B. not_entailment'., acc: 54.93%
...
Original prompt: Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'.
Attacked prompt: b"Assess the attach between the following sentences and sorted it as 'A. entailment' or 'B. not_entailment'."
Original acc: 59.15%, attacked acc: 40.85%, dropped acc: 18.31%
```


@@ -60,6 +60,7 @@ OpenCompass 上手路线
advanced_guides/new_model.md
advanced_guides/evaluation_turbomind.md
advanced_guides/code_eval_service.md
advanced_guides/prompt_attack.md
.. _工具:
.. toctree::


@@ -70,6 +70,7 @@ from .tydiqa import * # noqa: F401, F403
from .wic import * # noqa: F401, F403
from .winograd import * # noqa: F401, F403
from .winogrande import * # noqa: F401, F403
from .wnli import wnliDataset # noqa: F401, F403
from .wsc import * # noqa: F401, F403
from .xcopa import * # noqa: F401, F403
from .xiezhi import XiezhiDataset, XiezhiRetriever # noqa: F401, F403


@@ -0,0 +1,26 @@
from datasets import load_dataset
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class wnliDataset(BaseDataset):
@staticmethod
def load(**kwargs):
dataset = load_dataset(**kwargs)
# dataset = dataset['validation']
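        # Map the GLUE-wnli integer labels onto the answer options used in the
        # prompts: 1 (entailment) -> 'A', 0 (not_entailment) -> 'B';
        # -1 (unlabeled test examples) is passed through unchanged.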
gt_dict = {
1: 'A',
0: 'B',
-1: -1,
}
def preprocess(example):
example['label_option'] = gt_dict[example['label']]
return example
return dataset.map(preprocess)


@@ -1,3 +1,4 @@
from .icl_attack_inferencer import AttackInferencer # noqa
from .icl_base_inferencer import BaseInferencer # noqa
from .icl_clp_inferencer import CLPInferencer # noqa
from .icl_gen_inferencer import GenInferencer # noqa


@@ -0,0 +1,210 @@
"""Direct Generation Inferencer."""
import os
import os.path as osp
from typing import List, Optional
import mmengine
import torch
from tqdm import tqdm
from opencompass.models.base import BaseModel
from opencompass.registry import (ICL_EVALUATORS, ICL_INFERENCERS,
TEXT_POSTPROCESSORS)
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler
logger = get_logger(__name__)
@ICL_INFERENCERS.register_module()
class AttackInferencer(BaseInferencer):
"""Generation Inferencer class to directly evaluate by generation.
Attributes:
model (:obj:`BaseModelWrapper`, optional): The module to inference.
max_out_len (:obj:`int`, optional): Maximum number of tokenized words
of the output.
adv_key (:obj:`str`): Prompt key in template to be attacked.
metric_key (:obj:`str`): Metric key to be returned and compared.
Defaults to `accuracy`.
max_seq_len (:obj:`int`, optional): Maximum number of tokenized words
allowed by the LM.
batch_size (:obj:`int`, optional): Batch size for the
:obj:`DataLoader`.
output_json_filepath (:obj:`str`, optional): File path for output
`JSON` file.
output_json_filename (:obj:`str`, optional): File name for output
`JSON` file.
gen_field_replace_token (:obj:`str`, optional): Used to replace the
generation field token when generating prompts.
save_every (:obj:`int`, optional): Save intermediate results every
`save_every` epochs.
generation_kwargs (:obj:`Dict`, optional): Parameters for the
:obj:`model.generate()` method.
"""
def __init__(
self,
model: BaseModel,
max_out_len: int,
adv_key: str,
metric_key: str = 'accuracy',
max_seq_len: Optional[int] = None,
batch_size: Optional[int] = 1,
gen_field_replace_token: Optional[str] = '',
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None,
fix_id_list: Optional[List[int]] = None,
            dataset_cfg: Optional[dict] = None,
**kwargs) -> None:
super().__init__(
model=model,
max_seq_len=max_seq_len,
batch_size=batch_size,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
**kwargs,
)
self.adv_key = adv_key
self.metric_key = metric_key
self.dataset_cfg = dataset_cfg
self.eval_cfg = dataset_cfg['eval_cfg']
self.output_column = dataset_cfg['reader_cfg']['output_column']
self.gen_field_replace_token = gen_field_replace_token
self.max_out_len = max_out_len
self.fix_id_list = fix_id_list
if self.model.is_api and save_every is None:
save_every = 1
self.save_every = save_every
    def predict(self, adv_prompt) -> float:
# 1. Preparation for output logs
output_handler = GenInferencerOutputHandler()
# if output_json_filepath is None:
output_json_filepath = self.output_json_filepath
# if output_json_filename is None:
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
if 'Fix' in self.retriever.__class__.__name__:
ice_idx_list = self.retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = self.retriever.retrieve()
# 3. Generate prompts for testing input
prompt_list, label_list = self.get_generation_prompt_list_from_retriever_indices( # noqa
ice_idx_list, {self.adv_key: adv_prompt},
self.retriever,
self.gen_field_replace_token,
max_seq_len=self.max_seq_len,
ice_template=self.ice_template,
prompt_template=self.prompt_template)
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
tmp_json_filepath = os.path.join(output_json_filepath,
'tmp_' + output_json_filename)
if osp.exists(tmp_json_filepath):
# TODO: move resume to output handler
tmp_result_dict = mmengine.load(tmp_json_filepath)
output_handler.results_dict = tmp_result_dict
index = len(tmp_result_dict)
# 4. Wrap prompts with Dataloader
dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
# 5. Inference for prompts in each batch
logger.info('Starting inference process...')
for entry in tqdm(dataloader, disable=not self.is_main_process):
# 5-1. Inference with local model
with torch.no_grad():
parsed_entries = self.model.parse_template(entry, mode='gen')
results = self.model.generate_from_template(
entry, max_out_len=self.max_out_len)
generated = results
# 5-3. Save current output
for prompt, prediction in zip(parsed_entries, generated):
output_handler.save_results(prompt, prediction, index)
index = index + 1
# 5-4. Save intermediate results
if (self.save_every is not None and index % self.save_every == 0
and self.is_main_process):
output_handler.write_to_json(output_json_filepath,
'tmp_' + output_json_filename)
# 6. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
output_json_filename)
if osp.exists(tmp_json_filepath):
os.remove(tmp_json_filepath)
pred_strs = [
sample['prediction']
for sample in output_handler.results_dict.values()
]
if 'pred_postprocessor' in self.eval_cfg:
kwargs = self.eval_cfg['pred_postprocessor'].copy()
proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type'))
pred_strs = [proc(s, **kwargs) for s in pred_strs]
icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
result = icl_evaluator.score(predictions=pred_strs,
references=label_list)
score = result.get(self.metric_key)
# try to shrink score to range 0-1
return score / 100 if score > 1 else score
def get_generation_prompt_list_from_retriever_indices(
self,
ice_idx_list: List[List[int]],
extra_prompt: dict,
retriever: BaseRetriever,
gen_field_replace_token: str,
max_seq_len: Optional[int] = None,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None):
prompt_list = []
label_list = []
for idx, ice_idx in enumerate(ice_idx_list):
ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
prompt = retriever.generate_prompt_for_adv_generate_task(
idx,
ice,
extra_prompt,
gen_field_replace_token=gen_field_replace_token,
ice_template=ice_template,
prompt_template=prompt_template)
label = retriever.test_ds[idx][self.output_column]
label_list.append(label)
if max_seq_len is not None:
prompt_token_num = self.model.get_token_len_from_template(
prompt, mode='gen')
while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
ice_idx = ice_idx[:-1]
ice = retriever.generate_ice(ice_idx,
ice_template=ice_template)
prompt = retriever.generate_prompt_for_adv_generate_task(
idx,
ice,
extra_prompt,
gen_field_replace_token=gen_field_replace_token,
ice_template=ice_template,
prompt_template=prompt_template)
prompt_token_num = self.model.get_token_len_from_template(
prompt, mode='gen')
prompt_list.append(prompt)
return prompt_list, label_list


@@ -206,3 +206,66 @@ class BaseRetriever:
else:
raise NotImplementedError(
'Leaving prompt as empty is not supported')
def generate_prompt_for_adv_generate_task(
self,
idx,
ice,
extra_prompt=dict(),
gen_field_replace_token='',
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None):
"""Generate the prompt for one test example in generative evaluation
with `prompt_template`. If `prompt_template` is not provided, the
`ice_template` will be used to generate the prompt. The token
represented by `gen_field_replace_token` will not be replaced by the
        generated text, or it would leak the answer.
Args:
idx (`int`): The index of the test example.
ice (`str`): The in-context example for the test example.
gen_field_replace_token (`str`): The token of the answer in the
prompt. Defaults to ''.
ice_template (`Optional[PromptTemplate]`): The template for
in-context example. Defaults to None.
prompt_template (`Optional[PromptTemplate]`): The template for
prompt. Defaults to None.
"""
if prompt_template is not None and ice_template is not None:
if prompt_template.ice_token is not None:
return prompt_template.generate_item(
{
**self.test_ds[idx],
**extra_prompt
},
output_field=self.dataset_reader.output_column,
output_field_replace_token=gen_field_replace_token,
ice_field_replace_token=ice)
else:
raise NotImplementedError(
'ice_token of prompt_template is not provided')
elif ice_template is not None and prompt_template is None:
if ice_template.ice_token is not None:
return ice_template.generate_item(
{
**self.test_ds[idx],
**extra_prompt
},
output_field=self.dataset_reader.output_column,
output_field_replace_token=gen_field_replace_token,
ice_field_replace_token=ice)
else:
raise NotImplementedError(
'ice_token of ice_template is not provided')
elif ice_template is None and prompt_template is not None:
return prompt_template.generate_item(
{
**self.test_ds[idx],
**extra_prompt
},
output_field=self.dataset_reader.output_column,
output_field_replace_token=gen_field_replace_token,
ice_field_replace_token=ice)
else:
raise NotImplementedError(
'Leaving prompt as empty is not supported')


@@ -1,3 +1,4 @@
from .mm_infer import * # noqa: F401, F403
from .openicl_attack import * # noqa: F401, F403
from .openicl_eval import * # noqa: F401, F403
from .openicl_infer import * # noqa: F401, F403


@@ -0,0 +1,204 @@
import argparse
import os.path as osp
import random
import time
from typing import Any
from mmengine.config import Config, ConfigDict
from mmengine.utils import mkdir_or_exist
from opencompass.registry import (ICL_INFERENCERS, ICL_PROMPT_TEMPLATES,
ICL_RETRIEVERS, TASKS)
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
get_infer_output_path, get_logger,
task_abbr_from_cfg)
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
class OpenICLAttackTask(BaseTask):
"""OpenICL Inference Task.
This task is used to run the inference process.
"""
name_prefix = 'OpenICLAttack'
log_subdir = 'logs/attack'
output_subdir = 'attack'
def __init__(self, cfg: ConfigDict):
super().__init__(cfg)
run_cfg = self.model_cfgs[0].get('run_cfg', {})
self.num_gpus = run_cfg.get('num_gpus', 0)
self.num_procs = run_cfg.get('num_procs', 1)
self.logger = get_logger()
def get_command(self, cfg_path, template):
"""Get the command template for the task.
Args:
cfg_path (str): The path to the config file of the task.
            template (str): The template which has '{task_cmd}' to format
the command.
"""
script_path = __file__
if self.num_gpus > 0:
port = random.randint(12000, 32000)
command = (f'torchrun --master_port={port} '
f'--nproc_per_node {self.num_procs} '
f'{script_path} {cfg_path}')
else:
command = f'python {script_path} {cfg_path}'
return template.format(task_cmd=command)
def prompt_selection(self, inferencer, prompts):
prompt_dict = {}
for prompt in prompts:
acc = inferencer.predict(prompt)
prompt_dict[prompt] = acc
self.logger.info('{:.2f}, {}\n'.format(acc * 100, prompt))
sorted_prompts = sorted(prompt_dict.items(),
key=lambda x: x[1],
reverse=True)
return sorted_prompts
def run(self):
self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}')
for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
self.max_out_len = model_cfg.get('max_out_len', None)
self.batch_size = model_cfg.get('batch_size', None)
self.model = build_model_from_cfg(model_cfg)
for dataset_cfg in dataset_cfgs:
self.model_cfg = model_cfg
self.dataset_cfg = dataset_cfg
self.infer_cfg = self.dataset_cfg['infer_cfg']
self.dataset = build_dataset_from_cfg(self.dataset_cfg)
self.sub_cfg = {
'models': [self.model_cfg],
'datasets': [[self.dataset_cfg]],
}
out_path = get_infer_output_path(
self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'attack'))
if osp.exists(out_path):
continue
self._inference()
def _inference(self):
self.logger.info(
f'Start inferencing {task_abbr_from_cfg(self.sub_cfg)}')
assert hasattr(self.infer_cfg, 'ice_template') or hasattr(self.infer_cfg, 'prompt_template'), \
'Both ice_template and prompt_template cannot be None simultaneously.' # noqa: E501
ice_template = None
if hasattr(self.infer_cfg, 'ice_template'):
ice_template = ICL_PROMPT_TEMPLATES.build(
self.infer_cfg['ice_template'])
prompt_template = None
if hasattr(self.infer_cfg, 'prompt_template'):
prompt_template = ICL_PROMPT_TEMPLATES.build(
self.infer_cfg['prompt_template'])
retriever_cfg = self.infer_cfg['retriever'].copy()
retriever_cfg['dataset'] = self.dataset
retriever = ICL_RETRIEVERS.build(retriever_cfg)
        # set inferencer's default value according to model's config
inferencer_cfg = self.infer_cfg['inferencer']
inferencer_cfg['model'] = self.model
self._set_default_value(inferencer_cfg, 'max_out_len',
self.max_out_len)
self._set_default_value(inferencer_cfg, 'batch_size', self.batch_size)
inferencer_cfg['max_seq_len'] = self.model_cfg['max_seq_len']
inferencer_cfg['dataset_cfg'] = self.dataset_cfg
inferencer = ICL_INFERENCERS.build(inferencer_cfg)
out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'attack'))
out_dir, out_file = osp.split(out_path)
mkdir_or_exist(out_dir)
from config import LABEL_SET
from prompt_attack.attack import create_attack
from prompt_attack.goal_function import PromptGoalFunction
inferencer.retriever = retriever
inferencer.prompt_template = prompt_template
inferencer.ice_template = ice_template
inferencer.output_json_filepath = out_dir
inferencer.output_json_filename = out_file
goal_function = PromptGoalFunction(
inference=inferencer,
query_budget=self.cfg['attack'].query_budget,
logger=self.logger,
model_wrapper=None,
verbose='True')
if self.cfg['attack']['dataset'] not in LABEL_SET:
# set default
self.cfg['attack']['dataset'] = 'mmlu'
attack = create_attack(self.cfg['attack'], goal_function)
prompts = self.infer_cfg['inferencer']['original_prompt_list']
sorted_prompts = self.prompt_selection(inferencer, prompts)
if True:
# if args.prompt_selection:
for prompt, acc in sorted_prompts:
self.logger.info('Prompt: {}, acc: {:.2f}%\n'.format(
prompt, acc * 100))
with open(out_dir + 'attacklog.txt', 'a+') as f:
f.write('Prompt: {}, acc: {:.2f}%\n'.format(
prompt, acc * 100))
for init_prompt, init_acc in sorted_prompts[:self.cfg['attack'].
prompt_topk]:
if init_acc > 0:
init_acc, attacked_prompt, attacked_acc, dropped_acc = attack.attack( # noqa
init_prompt)
self.logger.info('Original prompt: {}'.format(init_prompt))
self.logger.info('Attacked prompt: {}'.format(
attacked_prompt.encode('utf-8')))
self.logger.info(
'Original acc: {:.2f}%, attacked acc: {:.2f}%, dropped acc: {:.2f}%' # noqa
.format(init_acc * 100, attacked_acc * 100,
dropped_acc * 100))
with open(out_dir + 'attacklog.txt', 'a+') as f:
f.write('Original prompt: {}\n'.format(init_prompt))
f.write('Attacked prompt: {}\n'.format(
attacked_prompt.encode('utf-8')))
f.write(
'Original acc: {:.2f}%, attacked acc: {:.2f}%, dropped acc: {:.2f}%\n\n' # noqa
.format(init_acc * 100, attacked_acc * 100,
dropped_acc * 100))
else:
with open(out_dir + 'attacklog.txt', 'a+') as f:
f.write('Init acc is 0, skip this prompt\n')
f.write('Original prompt: {}\n'.format(init_prompt))
f.write('Original acc: {:.2f}% \n\n'.format(init_acc *
100))
def _set_default_value(self, cfg: ConfigDict, key: str, value: Any):
if key not in cfg:
assert value, (f'{key} must be specified!')
cfg[key] = value
def parse_args():
parser = argparse.ArgumentParser(description='Model Inferencer')
parser.add_argument('config', help='Config file path')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
cfg = Config.fromfile(args.config)
start_time = time.time()
inferencer = OpenICLAttackTask(cfg)
inferencer.run()
end_time = time.time()
get_logger().info(f'time elapsed: {end_time - start_time:.2f}s')

run.py

@@ -266,6 +266,11 @@ def main():
if args.dry_run:
return
runner = RUNNERS.build(cfg.infer.runner)
# Add extra attack config if exists
if hasattr(cfg, 'attack'):
for task in tasks:
cfg.attack.dataset = task.datasets[0][0].abbr
task.attack = cfg.attack
runner(tasks)
# evaluate