mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feat] implementation for support promptbench (#239)
* [Feat] support adv_glue dataset for adversarial robustness * reorg files * minor fix * minor fix * support prompt bench demo * minor fix * minor fix * minor fix * minor fix * minor fix * minor fix * minor fix * minor fix
This commit is contained in:
parent
de8a154795
commit
a11cb45c83
@ -0,0 +1,57 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import BM25Retriever
|
||||
from opencompass.openicl.icl_inferencer import AttackInferencer
|
||||
from opencompass.openicl.icl_evaluator import BleuEvaluator
|
||||
from opencompass.datasets import IWSLT2017Dataset
|
||||
from opencompass.utils.text_postprocessors import general_cn_postprocess
|
||||
|
||||
iwslt2017_reader_cfg = dict(
|
||||
input_columns='en', output_column='de', train_split='validation')
|
||||
|
||||
original_prompt_list = [
|
||||
"Translate the provided sentence from English to German while maintaining the original meaning and context:",
|
||||
"Convert the following sentence from its original English language to the target language German:",
|
||||
"Given the sentence below, perform a machine translation from English to German:",
|
||||
"Translate the subsequent sentence from its source language English into the desired language German:",
|
||||
"Accurately translate the sentence from English to German, ensuring the meaning remains intact:",
|
||||
"Please perform a translation of the given sentence, converting it from English to German:",
|
||||
"Translate the following text from the source language English to the target language German:",
|
||||
"Using machine translation, convert the given sentence from English into the German language:",
|
||||
"Translate the subsequent text passage from its original English language to the German language:",
|
||||
"Perform a machine translation for the provided sentence, changing it from English to German:",
|
||||
]
|
||||
|
||||
iwslt2017_infer_cfg = dict(
|
||||
ice_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
begin='</E>',
|
||||
round=[
|
||||
dict(role='HUMAN', prompt='{adv_prompt}\n{en}'),
|
||||
dict(role='BOT', prompt='{de}'),
|
||||
]
|
||||
),
|
||||
ice_token='</E>'),
|
||||
retriever=dict(type=BM25Retriever, ice_num=1),
|
||||
inferencer=dict(
|
||||
type=AttackInferencer,
|
||||
original_prompt_list=original_prompt_list,
|
||||
adv_key='adv_prompt',
|
||||
metric_key='score'))
|
||||
|
||||
iwslt2017_eval_cfg = dict(
|
||||
evaluator=dict(type=BleuEvaluator),
|
||||
pred_role='BOT',
|
||||
pred_postprocessor=dict(type=general_cn_postprocess),
|
||||
dataset_postprocessor=dict(type=general_cn_postprocess))
|
||||
|
||||
iwslt2017_datasets = [
|
||||
dict(
|
||||
abbr='iwslt',
|
||||
type=IWSLT2017Dataset,
|
||||
path='iwslt2017',
|
||||
name='iwslt2017-en-de',
|
||||
reader_cfg=iwslt2017_reader_cfg,
|
||||
infer_cfg=iwslt2017_infer_cfg,
|
||||
eval_cfg=iwslt2017_eval_cfg)
|
||||
]
|
44
configs/datasets/promptbench/promptbench_math_gen_abf776.py
Normal file
44
configs/datasets/promptbench/promptbench_math_gen_abf776.py
Normal file
@ -0,0 +1,44 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import AttackInferencer
|
||||
from opencompass.datasets import MATHDataset, MATHEvaluator, math_postprocess
|
||||
|
||||
math_reader_cfg = dict(input_columns=['problem'], output_column='solution')
|
||||
|
||||
original_prompt_list = [
|
||||
"Solve the following math question about",
|
||||
"Determine the solution to this mathematical problem related to",
|
||||
"Calculate the answer to the following math query about",
|
||||
"Find the solution for this mathematical challenge with",
|
||||
"Compute the result of this math task concerning",
|
||||
"Resolve the following mathematical question associated with",
|
||||
"Work out the answer to this math problem featuring",
|
||||
"Figure out the solution for the following mathematical task with",
|
||||
"Obtain the result for this math question regarding",
|
||||
"Evaluate the following mathematical problem that includes",
|
||||
]
|
||||
|
||||
math_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt="{adv_prompt} {problem}:"),
|
||||
]),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=AttackInferencer, original_prompt_list=original_prompt_list,max_out_len=512, adv_key='adv_prompt'))
|
||||
|
||||
math_eval_cfg = dict(
|
||||
evaluator=dict(type=MATHEvaluator), pred_postprocessor=dict(type=math_postprocess))
|
||||
|
||||
math_datasets = [
|
||||
dict(
|
||||
type=MATHDataset,
|
||||
abbr='math',
|
||||
path='./data/math/math.json',
|
||||
reader_cfg=math_reader_cfg,
|
||||
infer_cfg=math_infer_cfg,
|
||||
eval_cfg=math_eval_cfg)
|
||||
]
|
@ -0,0 +1,48 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import AttackInferencer
|
||||
from opencompass.datasets import SQuAD20Dataset, SQuAD20Evaluator
|
||||
|
||||
squad20_reader_cfg = dict(
|
||||
input_columns=['context', 'question'],
|
||||
output_column='answers')
|
||||
|
||||
original_prompt_list = [
|
||||
"Based on the given context, provide the best possible answer. If there's no answer available in the context, respond with 'unanswerable'.",
|
||||
"Identify the most relevant answer from the context. If it's not possible to find an answer, respond with 'unanswerable'.",
|
||||
"Find the correct answer in the context provided. If an answer cannot be found, please respond with 'unanswerable'.",
|
||||
"Please extract the most appropriate answer from the context. If an answer is not present, indicate 'unanswerable'.",
|
||||
"Using the context, determine the most suitable answer. If the context doesn't contain the answer, respond with 'unanswerable'.",
|
||||
"Locate the most accurate answer within the context. If the context doesn't provide an answer, respond with 'unanswerable'.",
|
||||
"Please derive the most fitting answer from the context. If there isn't an answer in the context, respond with 'unanswerable'.",
|
||||
"Discover the best answer based on the context. If the context doesn't include an answer, respond with 'unanswerable'.",
|
||||
"From the context, provide the most precise answer. If the answer is not in the context, respond with 'unanswerable'.",
|
||||
"Search the context for the most relevant answer. If the answer cannot be found, respond with 'unanswerable'.",
|
||||
]
|
||||
|
||||
squad20_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(role='HUMAN', prompt='{adv_prompt} {context}'),
|
||||
dict(role='BOT', prompt='Answer:'),
|
||||
], )),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=AttackInferencer, max_out_len=50,
|
||||
original_prompt_list=original_prompt_list,
|
||||
adv_key='adv_prompt',
|
||||
metric_key='score'))
|
||||
|
||||
squad20_eval_cfg = dict(
|
||||
evaluator=dict(type=SQuAD20Evaluator), pred_role='BOT')
|
||||
|
||||
squad20_datasets = [
|
||||
dict(
|
||||
type=SQuAD20Dataset,
|
||||
abbr='squad_v2',
|
||||
path='./data/SQuAD2.0/dev-v2.0.json',
|
||||
reader_cfg=squad20_reader_cfg,
|
||||
infer_cfg=squad20_infer_cfg,
|
||||
eval_cfg=squad20_eval_cfg)
|
||||
]
|
61
configs/datasets/promptbench/promptbench_wnli_gen_50662f.py
Normal file
61
configs/datasets/promptbench/promptbench_wnli_gen_50662f.py
Normal file
@ -0,0 +1,61 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import AttackInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import wnliDataset
|
||||
from opencompass.utils.text_postprocessors import first_option_postprocess
|
||||
|
||||
wnli_reader_cfg = dict(
|
||||
input_columns=['sentence1', 'sentence2'],
|
||||
output_column='label_option',
|
||||
train_split='validation',
|
||||
test_split='validation')
|
||||
|
||||
original_prompt_list = [
|
||||
'Are the following two sentences entailment or not_entailment? Answer me with "A. entailment" or "B. not_entailment", just one word. ',
|
||||
"Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'.",
|
||||
"Examine the pair of sentences and determine if they exhibit entailment or not_entailment. Answer with either 'A. entailment' or 'B. not_entailment'.",
|
||||
"Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'.",
|
||||
"Analyze the two provided sentences and decide if their relationship is 'A. entailment' or 'B. not_entailment'.",
|
||||
"Identify whether the given pair of sentences demonstrates entailment or not_entailment. Answer with 'A. entailment' or 'B. not_entailment'.",
|
||||
"Review the two statements and categorize their relationship as either 'A. entailment' or 'B. not_entailment'.",
|
||||
"Please classify the relationship between the provided sentences as 'A. entailment' or 'B. not_entailment'.",
|
||||
"Indicate if the connection between the following sentences is 'A. entailment' or 'B. not_entailment'.",
|
||||
"Determine if the given pair of sentences displays entailment or not_entailment. Respond with 'A. entailment' or 'B. not_entailment'.",
|
||||
"Considering the two sentences, identify if their relationship is 'A. entailment' or 'B. not_entailment'.",
|
||||
]
|
||||
|
||||
wnli_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt="""{adv_prompt}
|
||||
Sentence 1: {sentence1}
|
||||
Sentence 2: {sentence2}
|
||||
Answer:"""),
|
||||
]),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(
|
||||
type=AttackInferencer,
|
||||
original_prompt_list=original_prompt_list,
|
||||
adv_key='adv_prompt'))
|
||||
|
||||
wnli_eval_cfg = dict(
|
||||
evaluator=dict(type=AccEvaluator),
|
||||
pred_role="BOT",
|
||||
pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
|
||||
)
|
||||
|
||||
wnli_datasets = [
|
||||
dict(
|
||||
abbr='wnli',
|
||||
type=wnliDataset,
|
||||
path='glue',
|
||||
name='wnli',
|
||||
reader_cfg=wnli_reader_cfg,
|
||||
infer_cfg=wnli_infer_cfg,
|
||||
eval_cfg=wnli_eval_cfg)
|
||||
]
|
27
configs/eval_attack.py
Normal file
27
configs/eval_attack.py
Normal file
@ -0,0 +1,27 @@
|
||||
from mmengine.config import read_base
|
||||
from opencompass.partitioners import NaivePartitioner
|
||||
from opencompass.runners import LocalRunner
|
||||
from opencompass.tasks import OpenICLAttackTask
|
||||
|
||||
with read_base():
|
||||
# choose a list of datasets
|
||||
from .datasets.promptbench.promptbench_wnli_gen_50662f import wnli_datasets
|
||||
from .models.hf_vicuna_7b import models
|
||||
|
||||
datasets = wnli_datasets
|
||||
|
||||
# Please run whole dataset at a time, aka use `NaivePartitioner` only
|
||||
# Please use `OpenICLAttackTask` if want to perform attack experiment
|
||||
infer = dict(
|
||||
partitioner=dict(type=NaivePartitioner),
|
||||
runner=dict(
|
||||
type=LocalRunner,
|
||||
max_num_workers=8,
|
||||
task=dict(type=OpenICLAttackTask)),
|
||||
)
|
||||
|
||||
attack = dict(
|
||||
attack='textfooler',
|
||||
query_budget=100,
|
||||
prompt_topk=1,
|
||||
)
|
108
docs/en/advanced_guides/prompt_attack.md
Normal file
108
docs/en/advanced_guides/prompt_attack.md
Normal file
@ -0,0 +1,108 @@
|
||||
# Prompt Attack
|
||||
|
||||
We support prompt attack following the idea of [PromptBench](https://github.com/microsoft/promptbench). The main purpose here is to evaluate the robustness of prompt instruction, which means when attack/modify the prompt to instruct the task, how well can this task perform as the original task.
|
||||
|
||||
## Set up environment
|
||||
|
||||
Some components are necessary to prompt attack experiment, therefore we need to set up environments.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/microsoft/promptbench.git
|
||||
pip install textattack==0.3.8
|
||||
export PYTHONPATH=$PYTHONPATH:promptbench/
|
||||
```
|
||||
|
||||
## How to attack
|
||||
|
||||
### Add a dataset config
|
||||
|
||||
We will use GLUE-wnli dataset as example, most configuration settings can refer to [config.md](../user_guides/config.md) for help.
|
||||
|
||||
First we need support the basic dataset config, you can find the existing config files in `configs` or support your own config according to [new-dataset](./new_dataset.md)
|
||||
|
||||
Take the following `infer_cfg` as example, we need to define the prompt template. `adv_prompt` is the basic prompt placeholder to be attacked in the experiment. `sentence1` and `sentence2` are the input columns of this dataset. The attack will only modify the `adv_prompt` here.
|
||||
|
||||
Then, we should use `AttackInferencer` with `original_prompt_list` and `adv_key` to tell the inferencer where to attack and what text to be attacked.
|
||||
|
||||
More details can refer to `configs/datasets/promptbench/promptbench_wnli_gen_50662f.py` config file.
|
||||
|
||||
```python
|
||||
original_prompt_list = [
|
||||
'Are the following two sentences entailment or not_entailment? Answer me with "A. entailment" or "B. not_entailment", just one word. ',
|
||||
"Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'.",
|
||||
...,
|
||||
]
|
||||
|
||||
wnli_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt="""{adv_prompt}
|
||||
Sentence 1: {sentence1}
|
||||
Sentence 2: {sentence2}
|
||||
Answer:"""),
|
||||
]),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(
|
||||
type=AttackInferencer,
|
||||
original_prompt_list=original_prompt_list,
|
||||
adv_key='adv_prompt'))
|
||||
```
|
||||
|
||||
### Add a eval config
|
||||
|
||||
We should use `OpenICLAttackTask` here for attack task. Also `NaivePartitioner` should be used because the attack experiment will run the whole dataset repeatedly for nearly hurdurds times to search the best attack, we do not want to split the dataset for convenience.
|
||||
|
||||
```note
|
||||
Please choose a small dataset(example < 1000) for attack, due to the aforementioned repeated search, otherwise the time cost is enumerous.
|
||||
```
|
||||
|
||||
There are several other options in `attack` config:
|
||||
|
||||
- `attack`: attack type, available options includes `textfooler`, `textbugger`, `deepwordbug`, `bertattack`, `checklist`, `stresstest`;
|
||||
- `query_budget`: upper boundary of queries, which means the total numbers of running the dataset;
|
||||
- `prompt_topk`: number of topk prompt to be attacked. In most case, the original prompt list is great than 10, running the whole set is time consuming.
|
||||
|
||||
```python
|
||||
# Please run whole dataset at a time, aka use `NaivePartitioner` only
|
||||
# Please use `OpenICLAttackTask` if want to perform attack experiment
|
||||
infer = dict(
|
||||
partitioner=dict(type=NaivePartitioner),
|
||||
runner=dict(
|
||||
type=SlurmRunner,
|
||||
max_num_workers=8,
|
||||
task=dict(type=OpenICLAttackTask),
|
||||
retry=0),
|
||||
)
|
||||
|
||||
attack = dict(
|
||||
attack='textfooler',
|
||||
query_budget=100,
|
||||
prompt_topk=2,
|
||||
)
|
||||
```
|
||||
|
||||
### Run the experiment
|
||||
|
||||
Please use `--mode infer` when run the attack experiment, and set `PYTHONPATH` env.
|
||||
|
||||
```shell
|
||||
python run.py configs/eval_attack.py --mode infer
|
||||
```
|
||||
|
||||
All the results will be saved in `attack` folder.
|
||||
The content includes the original prompt accuracy and the attacked prompt with dropped accuracy of `topk` prompt, for instance:
|
||||
|
||||
```
|
||||
Prompt: Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'., acc: 59.15%
|
||||
Prompt: Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'., acc: 57.75%
|
||||
Prompt: Analyze the two provided sentences and decide if their relationship is 'A. entailment' or 'B. not_entailment'., acc: 56.34%
|
||||
Prompt: Identify whether the given pair of sentences demonstrates entailment or not_entailment. Answer with 'A. entailment' or 'B. not_entailment'., acc: 54.93%
|
||||
...
|
||||
Original prompt: Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'.
|
||||
Attacked prompt: b"Assess the attach between the following sentences and sorted it as 'A. entailment' or 'B. not_entailment'."
|
||||
Original acc: 59.15%, attacked acc: 40.85%, dropped acc: 18.31%
|
||||
```
|
@ -60,6 +60,7 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
|
||||
advanced_guides/new_model.md
|
||||
advanced_guides/evaluation_turbomind.md
|
||||
advanced_guides/code_eval_service.md
|
||||
advanced_guides/prompt_attack.md
|
||||
|
||||
.. _Tools:
|
||||
.. toctree::
|
||||
|
108
docs/zh_cn/advanced_guides/prompt_attack.md
Normal file
108
docs/zh_cn/advanced_guides/prompt_attack.md
Normal file
@ -0,0 +1,108 @@
|
||||
# 提示词攻击
|
||||
|
||||
OpenCompass 支持[PromptBench](https://github.com/microsoft/promptbench)的提示词攻击。其主要想法是评估提示指令的鲁棒性,也就是说,当攻击或修改提示以指导任务时,希望该任务能尽可能表现的像像原始任务一样好。
|
||||
|
||||
## 环境安装
|
||||
|
||||
提示词攻击需要依赖 `PromptBench` 中的组件,所以需要先配置好环境。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/microsoft/promptbench.git
|
||||
pip install textattack==0.3.8
|
||||
export PYTHONPATH=$PYTHONPATH:promptbench/
|
||||
```
|
||||
|
||||
## 如何攻击
|
||||
|
||||
### 增加数据集配置文件
|
||||
|
||||
我们将使用GLUE-wnli数据集作为示例,大部分配置设置可以参考[config.md](../user_guides/config.md)获取帮助。
|
||||
|
||||
首先,我们需要支持基本的数据集配置,你可以在`configs`中找到现有的配置文件,或者根据[new-dataset](./new_dataset.md)支持你自己的配置。
|
||||
|
||||
以下面的`infer_cfg`为例,我们需要定义提示模板。`adv_prompt`是实验中要被攻击的基本提示占位符。`sentence1`和`sentence2`是此数据集的输入。攻击只会修改`adv_prompt`字段。
|
||||
|
||||
然后,我们应该使用`AttackInferencer`与`original_prompt_list`和`adv_key`告诉推理器在哪里攻击和攻击什么文本。
|
||||
|
||||
更多详细信息可以参考`configs/datasets/promptbench/promptbench_wnli_gen_50662f.py`配置文件。
|
||||
|
||||
```python
|
||||
original_prompt_list = [
|
||||
'Are the following two sentences entailment or not_entailment? Answer me with "A. entailment" or "B. not_entailment", just one word. ',
|
||||
"Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'.",
|
||||
...,
|
||||
]
|
||||
|
||||
wnli_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt="""{adv_prompt}
|
||||
Sentence 1: {sentence1}
|
||||
Sentence 2: {sentence2}
|
||||
Answer:"""),
|
||||
]),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(
|
||||
type=AttackInferencer,
|
||||
original_prompt_list=original_prompt_list,
|
||||
adv_key='adv_prompt'))
|
||||
```
|
||||
|
||||
### Add a eval config
|
||||
|
||||
我们应该在此处使用 `OpenICLAttackTask` 来进行攻击任务。还应该使用 `NaivePartitioner`,因为攻击实验将重复运行整个数据集近百次以搜索最佳攻击,为方便起见我们不希望拆分数据集。
|
||||
|
||||
```note
|
||||
由于上述提到的重复搜索,请选择小型数据集(样本少于1000)进行攻击,否则时间成本将非常大。
|
||||
```
|
||||
|
||||
在 `attack` 配置中还有其他几个选项:
|
||||
|
||||
- `attack`:攻击类型,可用选项包括`textfooler`, `textbugger`, `deepwordbug`, `bertattack`, `checklist`, `stresstest`;
|
||||
- `query_budget`:查询次数的上界,即运行数据集的总次数;
|
||||
- `prompt_topk`:要攻击的前k个提示的数量。在大多数情况下,原始提示列表大于10,运行整个集合是耗时的。
|
||||
|
||||
```python
|
||||
# Please run whole dataset at a time, aka use `NaivePartitioner` only
|
||||
# Please use `OpenICLAttackTask` if want to perform attack experiment
|
||||
infer = dict(
|
||||
partitioner=dict(type=NaivePartitioner),
|
||||
runner=dict(
|
||||
type=SlurmRunner,
|
||||
max_num_workers=8,
|
||||
task=dict(type=OpenICLAttackTask),
|
||||
retry=0),
|
||||
)
|
||||
|
||||
attack = dict(
|
||||
attack='textfooler',
|
||||
query_budget=100,
|
||||
prompt_topk=2,
|
||||
)
|
||||
```
|
||||
|
||||
### 运行试验
|
||||
|
||||
请当运行攻击实验的时候请使用 `--mode infer` 选项,并需要指定`PYTHONPATH`。
|
||||
|
||||
```shell
|
||||
python run.py configs/eval_attack.py --mode infer
|
||||
```
|
||||
|
||||
所有结果都将保存在名为“attack”的文件夹中。
|
||||
内容包括原始提示的准确性和受到攻击的提示的准确性,以及前k个提示下降的准确性,例如:
|
||||
|
||||
```
|
||||
Prompt: Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'., acc: 59.15%
|
||||
Prompt: Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'., acc: 57.75%
|
||||
Prompt: Analyze the two provided sentences and decide if their relationship is 'A. entailment' or 'B. not_entailment'., acc: 56.34%
|
||||
Prompt: Identify whether the given pair of sentences demonstrates entailment or not_entailment. Answer with 'A. entailment' or 'B. not_entailment'., acc: 54.93%
|
||||
...
|
||||
Original prompt: Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'.
|
||||
Attacked prompt: b"Assess the attach between the following sentences and sorted it as 'A. entailment' or 'B. not_entailment'."
|
||||
Original acc: 59.15%, attacked acc: 40.85%, dropped acc: 18.31%
|
||||
```
|
@ -60,6 +60,7 @@ OpenCompass 上手路线
|
||||
advanced_guides/new_model.md
|
||||
advanced_guides/evaluation_turbomind.md
|
||||
advanced_guides/code_eval_service.md
|
||||
advanced_guides/prompt_attack.md
|
||||
|
||||
.. _工具:
|
||||
.. toctree::
|
||||
|
@ -70,6 +70,7 @@ from .tydiqa import * # noqa: F401, F403
|
||||
from .wic import * # noqa: F401, F403
|
||||
from .winograd import * # noqa: F401, F403
|
||||
from .winogrande import * # noqa: F401, F403
|
||||
from .wnli import wnliDataset # noqa: F401, F403
|
||||
from .wsc import * # noqa: F401, F403
|
||||
from .xcopa import * # noqa: F401, F403
|
||||
from .xiezhi import XiezhiDataset, XiezhiRetriever # noqa: F401, F403
|
||||
|
26
opencompass/datasets/wnli.py
Normal file
26
opencompass/datasets/wnli.py
Normal file
@ -0,0 +1,26 @@
|
||||
from datasets import load_dataset
|
||||
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class wnliDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(**kwargs):
|
||||
|
||||
dataset = load_dataset(**kwargs)
|
||||
# dataset = dataset['validation']
|
||||
gt_dict = {
|
||||
1: 'A',
|
||||
0: 'B',
|
||||
-1: -1,
|
||||
}
|
||||
|
||||
def preprocess(example):
|
||||
example['label_option'] = gt_dict[example['label']]
|
||||
return example
|
||||
|
||||
return dataset.map(preprocess)
|
@ -1,3 +1,4 @@
|
||||
from .icl_attack_inferencer import AttackInferencer # noqa
|
||||
from .icl_base_inferencer import BaseInferencer # noqa
|
||||
from .icl_clp_inferencer import CLPInferencer # noqa
|
||||
from .icl_gen_inferencer import GenInferencer # noqa
|
||||
|
210
opencompass/openicl/icl_inferencer/icl_attack_inferencer.py
Normal file
210
opencompass/openicl/icl_inferencer/icl_attack_inferencer.py
Normal file
@ -0,0 +1,210 @@
|
||||
"""Direct Generation Inferencer."""
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
from typing import List, Optional
|
||||
|
||||
import mmengine
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from opencompass.models.base import BaseModel
|
||||
from opencompass.registry import (ICL_EVALUATORS, ICL_INFERENCERS,
|
||||
TEXT_POSTPROCESSORS)
|
||||
|
||||
from ..icl_prompt_template import PromptTemplate
|
||||
from ..icl_retriever import BaseRetriever
|
||||
from ..utils.logging import get_logger
|
||||
from .icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ICL_INFERENCERS.register_module()
|
||||
class AttackInferencer(BaseInferencer):
|
||||
"""Generation Inferencer class to directly evaluate by generation.
|
||||
|
||||
Attributes:
|
||||
model (:obj:`BaseModelWrapper`, optional): The module to inference.
|
||||
max_out_len (:obj:`int`, optional): Maximum number of tokenized words
|
||||
of the output.
|
||||
adv_key (:obj:`str`): Prompt key in template to be attacked.
|
||||
metric_key (:obj:`str`): Metric key to be returned and compared.
|
||||
Defaults to `accuracy`.
|
||||
max_seq_len (:obj:`int`, optional): Maximum number of tokenized words
|
||||
allowed by the LM.
|
||||
batch_size (:obj:`int`, optional): Batch size for the
|
||||
:obj:`DataLoader`.
|
||||
output_json_filepath (:obj:`str`, optional): File path for output
|
||||
`JSON` file.
|
||||
output_json_filename (:obj:`str`, optional): File name for output
|
||||
`JSON` file.
|
||||
gen_field_replace_token (:obj:`str`, optional): Used to replace the
|
||||
generation field token when generating prompts.
|
||||
save_every (:obj:`int`, optional): Save intermediate results every
|
||||
`save_every` epochs.
|
||||
generation_kwargs (:obj:`Dict`, optional): Parameters for the
|
||||
:obj:`model.generate()` method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: BaseModel,
|
||||
max_out_len: int,
|
||||
adv_key: str,
|
||||
metric_key: str = 'accuracy',
|
||||
max_seq_len: Optional[int] = None,
|
||||
batch_size: Optional[int] = 1,
|
||||
gen_field_replace_token: Optional[str] = '',
|
||||
output_json_filepath: Optional[str] = './icl_inference_output',
|
||||
output_json_filename: Optional[str] = 'predictions',
|
||||
save_every: Optional[int] = None,
|
||||
fix_id_list: Optional[List[int]] = None,
|
||||
dataset_cfg: Optional[List[int]] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
model=model,
|
||||
max_seq_len=max_seq_len,
|
||||
batch_size=batch_size,
|
||||
output_json_filename=output_json_filename,
|
||||
output_json_filepath=output_json_filepath,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.adv_key = adv_key
|
||||
self.metric_key = metric_key
|
||||
self.dataset_cfg = dataset_cfg
|
||||
self.eval_cfg = dataset_cfg['eval_cfg']
|
||||
self.output_column = dataset_cfg['reader_cfg']['output_column']
|
||||
self.gen_field_replace_token = gen_field_replace_token
|
||||
self.max_out_len = max_out_len
|
||||
self.fix_id_list = fix_id_list
|
||||
|
||||
if self.model.is_api and save_every is None:
|
||||
save_every = 1
|
||||
self.save_every = save_every
|
||||
|
||||
def predict(self, adv_prompt) -> List:
|
||||
# 1. Preparation for output logs
|
||||
output_handler = GenInferencerOutputHandler()
|
||||
|
||||
# if output_json_filepath is None:
|
||||
output_json_filepath = self.output_json_filepath
|
||||
# if output_json_filename is None:
|
||||
output_json_filename = self.output_json_filename
|
||||
|
||||
# 2. Get results of retrieval process
|
||||
if 'Fix' in self.retriever.__class__.__name__:
|
||||
ice_idx_list = self.retriever.retrieve(self.fix_id_list)
|
||||
else:
|
||||
ice_idx_list = self.retriever.retrieve()
|
||||
|
||||
# 3. Generate prompts for testing input
|
||||
prompt_list, label_list = self.get_generation_prompt_list_from_retriever_indices( # noqa
|
||||
ice_idx_list, {self.adv_key: adv_prompt},
|
||||
self.retriever,
|
||||
self.gen_field_replace_token,
|
||||
max_seq_len=self.max_seq_len,
|
||||
ice_template=self.ice_template,
|
||||
prompt_template=self.prompt_template)
|
||||
|
||||
# Create tmp json file for saving intermediate results and future
|
||||
# resuming
|
||||
index = 0
|
||||
tmp_json_filepath = os.path.join(output_json_filepath,
|
||||
'tmp_' + output_json_filename)
|
||||
if osp.exists(tmp_json_filepath):
|
||||
# TODO: move resume to output handler
|
||||
tmp_result_dict = mmengine.load(tmp_json_filepath)
|
||||
output_handler.results_dict = tmp_result_dict
|
||||
index = len(tmp_result_dict)
|
||||
|
||||
# 4. Wrap prompts with Dataloader
|
||||
dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
|
||||
|
||||
# 5. Inference for prompts in each batch
|
||||
logger.info('Starting inference process...')
|
||||
for entry in tqdm(dataloader, disable=not self.is_main_process):
|
||||
# 5-1. Inference with local model
|
||||
with torch.no_grad():
|
||||
parsed_entries = self.model.parse_template(entry, mode='gen')
|
||||
results = self.model.generate_from_template(
|
||||
entry, max_out_len=self.max_out_len)
|
||||
generated = results
|
||||
|
||||
# 5-3. Save current output
|
||||
for prompt, prediction in zip(parsed_entries, generated):
|
||||
output_handler.save_results(prompt, prediction, index)
|
||||
index = index + 1
|
||||
|
||||
# 5-4. Save intermediate results
|
||||
if (self.save_every is not None and index % self.save_every == 0
|
||||
and self.is_main_process):
|
||||
output_handler.write_to_json(output_json_filepath,
|
||||
'tmp_' + output_json_filename)
|
||||
|
||||
# 6. Output
|
||||
if self.is_main_process:
|
||||
os.makedirs(output_json_filepath, exist_ok=True)
|
||||
output_handler.write_to_json(output_json_filepath,
|
||||
output_json_filename)
|
||||
if osp.exists(tmp_json_filepath):
|
||||
os.remove(tmp_json_filepath)
|
||||
|
||||
pred_strs = [
|
||||
sample['prediction']
|
||||
for sample in output_handler.results_dict.values()
|
||||
]
|
||||
|
||||
if 'pred_postprocessor' in self.eval_cfg:
|
||||
kwargs = self.eval_cfg['pred_postprocessor'].copy()
|
||||
proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type'))
|
||||
pred_strs = [proc(s, **kwargs) for s in pred_strs]
|
||||
|
||||
icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
|
||||
result = icl_evaluator.score(predictions=pred_strs,
|
||||
references=label_list)
|
||||
score = result.get(self.metric_key)
|
||||
# try to shrink score to range 0-1
|
||||
return score / 100 if score > 1 else score
|
||||
|
||||
def get_generation_prompt_list_from_retriever_indices(
|
||||
self,
|
||||
ice_idx_list: List[List[int]],
|
||||
extra_prompt: dict,
|
||||
retriever: BaseRetriever,
|
||||
gen_field_replace_token: str,
|
||||
max_seq_len: Optional[int] = None,
|
||||
ice_template: Optional[PromptTemplate] = None,
|
||||
prompt_template: Optional[PromptTemplate] = None):
|
||||
prompt_list = []
|
||||
label_list = []
|
||||
for idx, ice_idx in enumerate(ice_idx_list):
|
||||
ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
|
||||
prompt = retriever.generate_prompt_for_adv_generate_task(
|
||||
idx,
|
||||
ice,
|
||||
extra_prompt,
|
||||
gen_field_replace_token=gen_field_replace_token,
|
||||
ice_template=ice_template,
|
||||
prompt_template=prompt_template)
|
||||
label = retriever.test_ds[idx][self.output_column]
|
||||
label_list.append(label)
|
||||
if max_seq_len is not None:
|
||||
prompt_token_num = self.model.get_token_len_from_template(
|
||||
prompt, mode='gen')
|
||||
while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
|
||||
ice_idx = ice_idx[:-1]
|
||||
ice = retriever.generate_ice(ice_idx,
|
||||
ice_template=ice_template)
|
||||
prompt = retriever.generate_prompt_for_adv_generate_task(
|
||||
idx,
|
||||
ice,
|
||||
extra_prompt,
|
||||
gen_field_replace_token=gen_field_replace_token,
|
||||
ice_template=ice_template,
|
||||
prompt_template=prompt_template)
|
||||
prompt_token_num = self.model.get_token_len_from_template(
|
||||
prompt, mode='gen')
|
||||
prompt_list.append(prompt)
|
||||
return prompt_list, label_list
|
@ -206,3 +206,66 @@ class BaseRetriever:
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Leaving prompt as empty is not supported')
|
||||
|
||||
def generate_prompt_for_adv_generate_task(
|
||||
self,
|
||||
idx,
|
||||
ice,
|
||||
extra_prompt=dict(),
|
||||
gen_field_replace_token='',
|
||||
ice_template: Optional[PromptTemplate] = None,
|
||||
prompt_template: Optional[PromptTemplate] = None):
|
||||
"""Generate the prompt for one test example in generative evaluation
|
||||
with `prompt_template`. If `prompt_template` is not provided, the
|
||||
`ice_template` will be used to generate the prompt. The token
|
||||
represented by `gen_field_replace_token` will not be replaced by the
|
||||
generated text, or it will leaks the answer.
|
||||
|
||||
Args:
|
||||
idx (`int`): The index of the test example.
|
||||
ice (`str`): The in-context example for the test example.
|
||||
gen_field_replace_token (`str`): The token of the answer in the
|
||||
prompt. Defaults to ''.
|
||||
ice_template (`Optional[PromptTemplate]`): The template for
|
||||
in-context example. Defaults to None.
|
||||
prompt_template (`Optional[PromptTemplate]`): The template for
|
||||
prompt. Defaults to None.
|
||||
"""
|
||||
if prompt_template is not None and ice_template is not None:
|
||||
if prompt_template.ice_token is not None:
|
||||
return prompt_template.generate_item(
|
||||
{
|
||||
**self.test_ds[idx],
|
||||
**extra_prompt
|
||||
},
|
||||
output_field=self.dataset_reader.output_column,
|
||||
output_field_replace_token=gen_field_replace_token,
|
||||
ice_field_replace_token=ice)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'ice_token of prompt_template is not provided')
|
||||
elif ice_template is not None and prompt_template is None:
|
||||
if ice_template.ice_token is not None:
|
||||
return ice_template.generate_item(
|
||||
{
|
||||
**self.test_ds[idx],
|
||||
**extra_prompt
|
||||
},
|
||||
output_field=self.dataset_reader.output_column,
|
||||
output_field_replace_token=gen_field_replace_token,
|
||||
ice_field_replace_token=ice)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'ice_token of ice_template is not provided')
|
||||
elif ice_template is None and prompt_template is not None:
|
||||
return prompt_template.generate_item(
|
||||
{
|
||||
**self.test_ds[idx],
|
||||
**extra_prompt
|
||||
},
|
||||
output_field=self.dataset_reader.output_column,
|
||||
output_field_replace_token=gen_field_replace_token,
|
||||
ice_field_replace_token=ice)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Leaving prompt as empty is not supported')
|
||||
|
@ -1,3 +1,4 @@
|
||||
from .mm_infer import * # noqa: F401, F403
|
||||
from .openicl_attack import * # noqa: F401, F403
|
||||
from .openicl_eval import * # noqa: F401, F403
|
||||
from .openicl_infer import * # noqa: F401, F403
|
||||
|
204
opencompass/tasks/openicl_attack.py
Normal file
204
opencompass/tasks/openicl_attack.py
Normal file
@ -0,0 +1,204 @@
|
||||
import argparse
|
||||
import os.path as osp
|
||||
import random
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
|
||||
from opencompass.registry import (ICL_INFERENCERS, ICL_PROMPT_TEMPLATES,
|
||||
ICL_RETRIEVERS, TASKS)
|
||||
from opencompass.tasks.base import BaseTask
|
||||
from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
|
||||
get_infer_output_path, get_logger,
|
||||
task_abbr_from_cfg)
|
||||
|
||||
|
||||
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
|
||||
class OpenICLAttackTask(BaseTask):
|
||||
"""OpenICL Inference Task.
|
||||
|
||||
This task is used to run the inference process.
|
||||
"""
|
||||
|
||||
name_prefix = 'OpenICLAttack'
|
||||
log_subdir = 'logs/attack'
|
||||
output_subdir = 'attack'
|
||||
|
||||
def __init__(self, cfg: ConfigDict):
|
||||
super().__init__(cfg)
|
||||
run_cfg = self.model_cfgs[0].get('run_cfg', {})
|
||||
self.num_gpus = run_cfg.get('num_gpus', 0)
|
||||
self.num_procs = run_cfg.get('num_procs', 1)
|
||||
self.logger = get_logger()
|
||||
|
||||
def get_command(self, cfg_path, template):
|
||||
"""Get the command template for the task.
|
||||
|
||||
Args:
|
||||
cfg_path (str): The path to the config file of the task.
|
||||
template (str): The template which have '{task_cmd}' to format
|
||||
the command.
|
||||
"""
|
||||
script_path = __file__
|
||||
if self.num_gpus > 0:
|
||||
port = random.randint(12000, 32000)
|
||||
command = (f'torchrun --master_port={port} '
|
||||
f'--nproc_per_node {self.num_procs} '
|
||||
f'{script_path} {cfg_path}')
|
||||
else:
|
||||
command = f'python {script_path} {cfg_path}'
|
||||
|
||||
return template.format(task_cmd=command)
|
||||
|
||||
def prompt_selection(self, inferencer, prompts):
|
||||
prompt_dict = {}
|
||||
|
||||
for prompt in prompts:
|
||||
acc = inferencer.predict(prompt)
|
||||
prompt_dict[prompt] = acc
|
||||
self.logger.info('{:.2f}, {}\n'.format(acc * 100, prompt))
|
||||
|
||||
sorted_prompts = sorted(prompt_dict.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True)
|
||||
return sorted_prompts
|
||||
|
||||
def run(self):
|
||||
self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}')
|
||||
for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
|
||||
self.max_out_len = model_cfg.get('max_out_len', None)
|
||||
self.batch_size = model_cfg.get('batch_size', None)
|
||||
self.model = build_model_from_cfg(model_cfg)
|
||||
|
||||
for dataset_cfg in dataset_cfgs:
|
||||
self.model_cfg = model_cfg
|
||||
self.dataset_cfg = dataset_cfg
|
||||
self.infer_cfg = self.dataset_cfg['infer_cfg']
|
||||
self.dataset = build_dataset_from_cfg(self.dataset_cfg)
|
||||
self.sub_cfg = {
|
||||
'models': [self.model_cfg],
|
||||
'datasets': [[self.dataset_cfg]],
|
||||
}
|
||||
out_path = get_infer_output_path(
|
||||
self.model_cfg, self.dataset_cfg,
|
||||
osp.join(self.work_dir, 'attack'))
|
||||
if osp.exists(out_path):
|
||||
continue
|
||||
self._inference()
|
||||
|
||||
def _inference(self):
|
||||
self.logger.info(
|
||||
f'Start inferencing {task_abbr_from_cfg(self.sub_cfg)}')
|
||||
|
||||
assert hasattr(self.infer_cfg, 'ice_template') or hasattr(self.infer_cfg, 'prompt_template'), \
|
||||
'Both ice_template and prompt_template cannot be None simultaneously.' # noqa: E501
|
||||
ice_template = None
|
||||
if hasattr(self.infer_cfg, 'ice_template'):
|
||||
ice_template = ICL_PROMPT_TEMPLATES.build(
|
||||
self.infer_cfg['ice_template'])
|
||||
|
||||
prompt_template = None
|
||||
if hasattr(self.infer_cfg, 'prompt_template'):
|
||||
prompt_template = ICL_PROMPT_TEMPLATES.build(
|
||||
self.infer_cfg['prompt_template'])
|
||||
|
||||
retriever_cfg = self.infer_cfg['retriever'].copy()
|
||||
retriever_cfg['dataset'] = self.dataset
|
||||
retriever = ICL_RETRIEVERS.build(retriever_cfg)
|
||||
|
||||
# set inferencer's default value according to model's config'
|
||||
inferencer_cfg = self.infer_cfg['inferencer']
|
||||
inferencer_cfg['model'] = self.model
|
||||
self._set_default_value(inferencer_cfg, 'max_out_len',
|
||||
self.max_out_len)
|
||||
self._set_default_value(inferencer_cfg, 'batch_size', self.batch_size)
|
||||
inferencer_cfg['max_seq_len'] = self.model_cfg['max_seq_len']
|
||||
inferencer_cfg['dataset_cfg'] = self.dataset_cfg
|
||||
inferencer = ICL_INFERENCERS.build(inferencer_cfg)
|
||||
|
||||
out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
|
||||
osp.join(self.work_dir, 'attack'))
|
||||
out_dir, out_file = osp.split(out_path)
|
||||
mkdir_or_exist(out_dir)
|
||||
|
||||
from config import LABEL_SET
|
||||
from prompt_attack.attack import create_attack
|
||||
from prompt_attack.goal_function import PromptGoalFunction
|
||||
|
||||
inferencer.retriever = retriever
|
||||
inferencer.prompt_template = prompt_template
|
||||
inferencer.ice_template = ice_template
|
||||
inferencer.output_json_filepath = out_dir
|
||||
inferencer.output_json_filename = out_file
|
||||
goal_function = PromptGoalFunction(
|
||||
inference=inferencer,
|
||||
query_budget=self.cfg['attack'].query_budget,
|
||||
logger=self.logger,
|
||||
model_wrapper=None,
|
||||
verbose='True')
|
||||
if self.cfg['attack']['dataset'] not in LABEL_SET:
|
||||
# set default
|
||||
self.cfg['attack']['dataset'] = 'mmlu'
|
||||
attack = create_attack(self.cfg['attack'], goal_function)
|
||||
|
||||
prompts = self.infer_cfg['inferencer']['original_prompt_list']
|
||||
sorted_prompts = self.prompt_selection(inferencer, prompts)
|
||||
if True:
|
||||
# if args.prompt_selection:
|
||||
for prompt, acc in sorted_prompts:
|
||||
self.logger.info('Prompt: {}, acc: {:.2f}%\n'.format(
|
||||
prompt, acc * 100))
|
||||
with open(out_dir + 'attacklog.txt', 'a+') as f:
|
||||
f.write('Prompt: {}, acc: {:.2f}%\n'.format(
|
||||
prompt, acc * 100))
|
||||
|
||||
for init_prompt, init_acc in sorted_prompts[:self.cfg['attack'].
|
||||
prompt_topk]:
|
||||
if init_acc > 0:
|
||||
init_acc, attacked_prompt, attacked_acc, dropped_acc = attack.attack( # noqa
|
||||
init_prompt)
|
||||
self.logger.info('Original prompt: {}'.format(init_prompt))
|
||||
self.logger.info('Attacked prompt: {}'.format(
|
||||
attacked_prompt.encode('utf-8')))
|
||||
self.logger.info(
|
||||
'Original acc: {:.2f}%, attacked acc: {:.2f}%, dropped acc: {:.2f}%' # noqa
|
||||
.format(init_acc * 100, attacked_acc * 100,
|
||||
dropped_acc * 100))
|
||||
with open(out_dir + 'attacklog.txt', 'a+') as f:
|
||||
f.write('Original prompt: {}\n'.format(init_prompt))
|
||||
f.write('Attacked prompt: {}\n'.format(
|
||||
attacked_prompt.encode('utf-8')))
|
||||
f.write(
|
||||
'Original acc: {:.2f}%, attacked acc: {:.2f}%, dropped acc: {:.2f}%\n\n' # noqa
|
||||
.format(init_acc * 100, attacked_acc * 100,
|
||||
dropped_acc * 100))
|
||||
else:
|
||||
with open(out_dir + 'attacklog.txt', 'a+') as f:
|
||||
f.write('Init acc is 0, skip this prompt\n')
|
||||
f.write('Original prompt: {}\n'.format(init_prompt))
|
||||
f.write('Original acc: {:.2f}% \n\n'.format(init_acc *
|
||||
100))
|
||||
|
||||
def _set_default_value(self, cfg: ConfigDict, key: str, value: Any):
|
||||
if key not in cfg:
|
||||
assert value, (f'{key} must be specified!')
|
||||
cfg[key] = value
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Model Inferencer')
|
||||
parser.add_argument('config', help='Config file path')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
cfg = Config.fromfile(args.config)
|
||||
start_time = time.time()
|
||||
inferencer = OpenICLAttackTask(cfg)
|
||||
inferencer.run()
|
||||
end_time = time.time()
|
||||
get_logger().info(f'time elapsed: {end_time - start_time:.2f}s')
|
Loading…
Reference in New Issue
Block a user