[Feature] Add Data Contamination Analysis (#639)

* add contamination analysis to ceval

* fix bugs

* add contamination docs

* to pass CI check

* update

---------

Co-authored-by: zhangyifan1 <zhangyifan1@pjlab.org.cn>
Co-authored-by: Leymore <zfz-960727@163.com>
liyucheng09 2023-12-08 10:00:11 +08:00 committed by GitHub
parent 3a354bd1da
commit 05bbce8b08
10 changed files with 513 additions and 14 deletions


@@ -0,0 +1,107 @@
from typing import List
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import FixKRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccContaminationEvaluator
from opencompass.datasets import CEvalDatasetClean as CEvalDataset
ceval_subject_mapping = {
'computer_network': ['Computer Network', '计算机网络', 'STEM'],
'operating_system': ['Operating System', '操作系统', 'STEM'],
'computer_architecture': ['Computer Architecture', '计算机组成', 'STEM'],
'college_programming': ['College Programming', '大学编程', 'STEM'],
'college_physics': ['College Physics', '大学物理', 'STEM'],
'college_chemistry': ['College Chemistry', '大学化学', 'STEM'],
'advanced_mathematics': ['Advanced Mathematics', '高等数学', 'STEM'],
'probability_and_statistics': ['Probability and Statistics', '概率统计', 'STEM'],
'discrete_mathematics': ['Discrete Mathematics', '离散数学', 'STEM'],
'electrical_engineer': ['Electrical Engineer', '注册电气工程师', 'STEM'],
'metrology_engineer': ['Metrology Engineer', '注册计量师', 'STEM'],
'high_school_mathematics': ['High School Mathematics', '高中数学', 'STEM'],
'high_school_physics': ['High School Physics', '高中物理', 'STEM'],
'high_school_chemistry': ['High School Chemistry', '高中化学', 'STEM'],
'high_school_biology': ['High School Biology', '高中生物', 'STEM'],
'middle_school_mathematics': ['Middle School Mathematics', '初中数学', 'STEM'],
'middle_school_biology': ['Middle School Biology', '初中生物', 'STEM'],
'middle_school_physics': ['Middle School Physics', '初中物理', 'STEM'],
'middle_school_chemistry': ['Middle School Chemistry', '初中化学', 'STEM'],
'veterinary_medicine': ['Veterinary Medicine', '兽医学', 'STEM'],
'college_economics': ['College Economics', '大学经济学', 'Social Science'],
'business_administration': ['Business Administration', '工商管理', 'Social Science'],
'marxism': ['Marxism', '马克思主义基本原理', 'Social Science'],
'mao_zedong_thought': ['Mao Zedong Thought', '毛泽东思想和中国特色社会主义理论体系概论', 'Social Science'],
'education_science': ['Education Science', '教育学', 'Social Science'],
'teacher_qualification': ['Teacher Qualification', '教师资格', 'Social Science'],
'high_school_politics': ['High School Politics', '高中政治', 'Social Science'],
'high_school_geography': ['High School Geography', '高中地理', 'Social Science'],
'middle_school_politics': ['Middle School Politics', '初中政治', 'Social Science'],
'middle_school_geography': ['Middle School Geography', '初中地理', 'Social Science'],
'modern_chinese_history': ['Modern Chinese History', '近代史纲要', 'Humanities'],
'ideological_and_moral_cultivation': ['Ideological and Moral Cultivation', '思想道德修养与法律基础', 'Humanities'],
'logic': ['Logic', '逻辑学', 'Humanities'],
'law': ['Law', '法学', 'Humanities'],
'chinese_language_and_literature': ['Chinese Language and Literature', '中国语言文学', 'Humanities'],
'art_studies': ['Art Studies', '艺术学', 'Humanities'],
'professional_tour_guide': ['Professional Tour Guide', '导游资格', 'Humanities'],
'legal_professional': ['Legal Professional', '法律职业资格', 'Humanities'],
'high_school_chinese': ['High School Chinese', '高中语文', 'Humanities'],
'high_school_history': ['High School History', '高中历史', 'Humanities'],
'middle_school_history': ['Middle School History', '初中历史', 'Humanities'],
'civil_servant': ['Civil Servant', '公务员', 'Other'],
'sports_science': ['Sports Science', '体育学', 'Other'],
'plant_protection': ['Plant Protection', '植物保护', 'Other'],
'basic_medicine': ['Basic Medicine', '基础医学', 'Other'],
'clinical_medicine': ['Clinical Medicine', '临床医学', 'Other'],
'urban_and_rural_planner': ['Urban and Rural Planner', '注册城乡规划师', 'Other'],
'accountant': ['Accountant', '注册会计师', 'Other'],
'fire_engineer': ['Fire Engineer', '注册消防工程师', 'Other'],
'environmental_impact_assessment_engineer': ['Environmental Impact Assessment Engineer', '环境影响评价工程师', 'Other'],
'tax_accountant': ['Tax Accountant', '税务师', 'Other'],
'physician': ['Physician', '医师资格', 'Other'],
}
ceval_all_sets = list(ceval_subject_mapping.keys())
ceval_datasets = []
for _split in ["val"]:
for _name in ceval_all_sets:
_ch_name = ceval_subject_mapping[_name][1]
ceval_infer_cfg = dict(
ice_template=dict(
type=PromptTemplate,
template={
answer: dict(
begin="</E>",
round=[
dict(
role="HUMAN",
prompt=
f"以下是中国关于{_ch_name}考试的单项选择题,请选出其中的正确答案。\n{{question}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\n答案: "
),
dict(role="BOT", prompt=answer),
])
for answer in ["A", "B", "C", "D"]
},
ice_token="</E>",
),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=PPLInferencer),
)
ceval_eval_cfg = dict(evaluator=dict(type=AccContaminationEvaluator), analyze_contamination=True)
ceval_datasets.append(
dict(
type=CEvalDataset,
path="./data/ceval/formal_ceval",
name=_name,
abbr="ceval-" + _name if _split == "val" else "ceval-test-" + _name,
reader_cfg=dict(
input_columns=["question", "A", "B", "C", "D"],
output_column="answer",
train_split="dev",
test_split=_split),
infer_cfg=ceval_infer_cfg,
eval_cfg=ceval_eval_cfg,
))
del _split, _name, _ch_name


@@ -0,0 +1,12 @@
from mmengine.config import read_base

with read_base():
    from .datasets.ceval.ceval_clean_ppl import ceval_datasets
    from .models.yi.hf_yi_6b import models as hf_yi_6b_model
    from .models.qwen.hf_qwen_7b import models as hf_qwen_7b_model
    from .models.hf_llama.hf_llama2_7b import models as hf_llama2_7b_model
    from .summarizers.contamination import ceval_summarizer as summarizer

datasets = [*ceval_datasets]
models = [*hf_yi_6b_model, *hf_qwen_7b_model, *hf_llama2_7b_model]


@@ -0,0 +1,157 @@
from mmengine.config import read_base
from opencompass.summarizers import CircularSummarizer

with read_base():
    from .groups.ceval import ceval_summary_groups
ceval_category_weights = {
'computer_network': {'accuracy - clean': 11, 'accuracy - input contaminated': 2, 'accuracy - input-and-label contaminated': 6, 'accuracy - not labeled': 0},
'operating_system': {'accuracy - clean': 14, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 5, 'accuracy - not labeled': 0},
'computer_architecture': {'accuracy - clean': 7, 'accuracy - input contaminated': 2, 'accuracy - input-and-label contaminated': 12, 'accuracy - not labeled': 0},
'college_programming': {'accuracy - clean': 22, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 14, 'accuracy - not labeled': 0},
'college_physics': {'accuracy - clean': 6, 'accuracy - input contaminated': 4, 'accuracy - input-and-label contaminated': 9, 'accuracy - not labeled': 0},
'college_chemistry': {'accuracy - clean': 21, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 2, 'accuracy - not labeled': 0},
'advanced_mathematics': {'accuracy - clean': 19, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 0, 'accuracy - not labeled': 0},
'probability_and_statistics': {'accuracy - clean': 18, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 0, 'accuracy - not labeled': 0},
'discrete_mathematics': {'accuracy - clean': 14, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 1, 'accuracy - not labeled': 0},
'electrical_engineer': {'accuracy - clean': 18, 'accuracy - input contaminated': 4, 'accuracy - input-and-label contaminated': 15, 'accuracy - not labeled': 0},
'metrology_engineer': {'accuracy - clean': 8, 'accuracy - input contaminated': 2, 'accuracy - input-and-label contaminated': 14, 'accuracy - not labeled': 0},
'high_school_mathematics': {'accuracy - clean': 18, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 0, 'accuracy - not labeled': 0},
'high_school_physics': {'accuracy - clean': 12, 'accuracy - input contaminated': 2, 'accuracy - input-and-label contaminated': 5, 'accuracy - not labeled': 0},
'high_school_chemistry': {'accuracy - clean': 16, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 3, 'accuracy - not labeled': 0},
'high_school_biology': {'accuracy - clean': 9, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 10, 'accuracy - not labeled': 0},
'middle_school_mathematics': {'accuracy - clean': 15, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 3, 'accuracy - not labeled': 0},
'middle_school_biology': {'accuracy - clean': 10, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 11, 'accuracy - not labeled': 0},
'middle_school_physics': {'accuracy - clean': 7, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 11, 'accuracy - not labeled': 0},
'middle_school_chemistry': {'accuracy - clean': 12, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 8, 'accuracy - not labeled': 0},
'veterinary_medicine': {'accuracy - clean': 13, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 10, 'accuracy - not labeled': 0},
'college_economics': {'accuracy - clean': 19, 'accuracy - input contaminated': 4, 'accuracy - input-and-label contaminated': 32, 'accuracy - not labeled': 0},
'business_administration': {'accuracy - clean': 13, 'accuracy - input contaminated': 2, 'accuracy - input-and-label contaminated': 18, 'accuracy - not labeled': 0},
'marxism': {'accuracy - clean': 10, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 8, 'accuracy - not labeled': 0},
'mao_zedong_thought': {'accuracy - clean': 6, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 18, 'accuracy - not labeled': 0},
'education_science': {'accuracy - clean': 11, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 17, 'accuracy - not labeled': 0},
'teacher_qualification': {'accuracy - clean': 18, 'accuracy - input contaminated': 2, 'accuracy - input-and-label contaminated': 23, 'accuracy - not labeled': 1},
'high_school_politics': {'accuracy - clean': 14, 'accuracy - input contaminated': 2, 'accuracy - input-and-label contaminated': 3, 'accuracy - not labeled': 0},
'high_school_geography': {'accuracy - clean': 11, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 8, 'accuracy - not labeled': 0},
'middle_school_politics': {'accuracy - clean': 20, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 1, 'accuracy - not labeled': 0},
'middle_school_geography': {'accuracy - clean': 3, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 8, 'accuracy - not labeled': 0},
'modern_chinese_history': {'accuracy - clean': 8, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 15, 'accuracy - not labeled': 0},
'ideological_and_moral_cultivation': {'accuracy - clean': 5, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 14, 'accuracy - not labeled': 0},
'logic': {'accuracy - clean': 15, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 7, 'accuracy - not labeled': 0},
'law': {'accuracy - clean': 15, 'accuracy - input contaminated': 3, 'accuracy - input-and-label contaminated': 6, 'accuracy - not labeled': 0},
'chinese_language_and_literature': {'accuracy - clean': 13, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 9, 'accuracy - not labeled': 0},
'art_studies': {'accuracy - clean': 14, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 19, 'accuracy - not labeled': 0},
'professional_tour_guide': {'accuracy - clean': 10, 'accuracy - input contaminated': 2, 'accuracy - input-and-label contaminated': 17, 'accuracy - not labeled': 0},
'legal_professional': {'accuracy - clean': 14, 'accuracy - input contaminated': 2, 'accuracy - input-and-label contaminated': 7, 'accuracy - not labeled': 0},
'high_school_chinese': {'accuracy - clean': 12, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 4, 'accuracy - not labeled': 3},
'high_school_history': {'accuracy - clean': 12, 'accuracy - input contaminated': 3, 'accuracy - input-and-label contaminated': 5, 'accuracy - not labeled': 0},
'middle_school_history': {'accuracy - clean': 11, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 9, 'accuracy - not labeled': 1},
'civil_servant': {'accuracy - clean': 19, 'accuracy - input contaminated': 5, 'accuracy - input-and-label contaminated': 17, 'accuracy - not labeled': 6},
'sports_science': {'accuracy - clean': 8, 'accuracy - input contaminated': 2, 'accuracy - input-and-label contaminated': 9, 'accuracy - not labeled': 0},
'plant_protection': {'accuracy - clean': 12, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 9, 'accuracy - not labeled': 0},
'basic_medicine': {'accuracy - clean': 9, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 10, 'accuracy - not labeled': 0},
'clinical_medicine': {'accuracy - clean': 14, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 7, 'accuracy - not labeled': 0},
'urban_and_rural_planner': {'accuracy - clean': 28, 'accuracy - input contaminated': 3, 'accuracy - input-and-label contaminated': 15, 'accuracy - not labeled': 0},
'accountant': {'accuracy - clean': 17, 'accuracy - input contaminated': 7, 'accuracy - input-and-label contaminated': 25, 'accuracy - not labeled': 0},
'fire_engineer': {'accuracy - clean': 12, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 18, 'accuracy - not labeled': 0},
'environmental_impact_assessment_engineer': {'accuracy - clean': 21, 'accuracy - input contaminated': 2, 'accuracy - input-and-label contaminated': 8, 'accuracy - not labeled': 0},
'tax_accountant': {'accuracy - clean': 31, 'accuracy - input contaminated': 0, 'accuracy - input-and-label contaminated': 18, 'accuracy - not labeled': 0},
'physician': {'accuracy - clean': 24, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 24, 'accuracy - not labeled': 0},
}
ceval_stem = ['computer_network', 'operating_system', 'computer_architecture', 'college_programming', 'college_physics', 'college_chemistry', 'advanced_mathematics', 'probability_and_statistics', 'discrete_mathematics', 'electrical_engineer', 'metrology_engineer', 'high_school_mathematics', 'high_school_physics', 'high_school_chemistry', 'high_school_biology', 'middle_school_mathematics', 'middle_school_biology', 'middle_school_physics', 'middle_school_chemistry', 'veterinary_medicine']
ceval_social_science = ['college_economics', 'business_administration', 'marxism', 'mao_zedong_thought', 'education_science', 'teacher_qualification', 'high_school_politics', 'high_school_geography', 'middle_school_politics', 'middle_school_geography']
ceval_humanities = ['modern_chinese_history', 'ideological_and_moral_cultivation', 'logic', 'law', 'chinese_language_and_literature', 'art_studies', 'professional_tour_guide', 'legal_professional', 'high_school_chinese', 'high_school_history', 'middle_school_history']
ceval_other = ['civil_servant', 'sports_science', 'plant_protection', 'basic_medicine', 'clinical_medicine', 'urban_and_rural_planner', 'accountant', 'fire_engineer', 'environmental_impact_assessment_engineer', 'tax_accountant', 'physician']
ceval_hard = ['advanced_mathematics', 'discrete_mathematics', 'probability_and_statistics', 'college_chemistry', 'college_physics', 'high_school_mathematics', 'high_school_chemistry', 'high_school_physics']
ceval_all = ceval_stem + ceval_social_science + ceval_humanities + ceval_other
name_and_subsets = [
('ceval', ceval_all),
('ceval-stem', ceval_stem),
('ceval-social-science', ceval_social_science),
('ceval-humanities', ceval_humanities),
('ceval-other', ceval_other),
('ceval-hard', ceval_hard)
]
ceval_summary_groups = []
for metric_name in ['accuracy - clean', 'accuracy - input contaminated', 'accuracy - input-and-label contaminated']:
    for dataset_abbr, subsets in name_and_subsets:
        weights = {f'ceval-{i}': ceval_category_weights[i][metric_name] for i in subsets}
        subsets = [[f'ceval-{i}', metric_name] for i in subsets]
        ceval_summary_groups.append(
            {
                'name': dataset_abbr,
                'subsets': subsets,
                'metric': metric_name,
                'weights': weights,
            }
        )
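# Note (illustrative, not part of the upstream config): each weight above is the number
# of validation questions in that subject carrying the given contamination label, so the
# group score reported by the summarizer is a question-count-weighted mean, mirroring the
# weighted_average branch added to DefaultSummarizer in this commit, roughly:
#     numerator = sum(score[k] * weights[k] for k in weights if weights[k] != 0)
#     denominator = sum(weights.values())
#     group_score = numerator / denominator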
ceval_summarizer = dict(
type=CircularSummarizer,
metric_types=['accuracy - clean', 'accuracy - input contaminated', 'accuracy - input-and-label contaminated'],
dataset_abbrs = [
'ceval-computer_network',
'ceval-operating_system',
'ceval-computer_architecture',
'ceval-college_programming',
'ceval-college_physics',
'ceval-college_chemistry',
'ceval-advanced_mathematics',
'ceval-probability_and_statistics',
'ceval-discrete_mathematics',
'ceval-electrical_engineer',
'ceval-metrology_engineer',
'ceval-high_school_mathematics',
'ceval-high_school_physics',
'ceval-high_school_chemistry',
'ceval-high_school_biology',
'ceval-middle_school_mathematics',
'ceval-middle_school_biology',
'ceval-middle_school_physics',
'ceval-middle_school_chemistry',
'ceval-veterinary_medicine',
'ceval-college_economics',
'ceval-business_administration',
'ceval-marxism',
'ceval-mao_zedong_thought',
'ceval-education_science',
'ceval-teacher_qualification',
'ceval-high_school_politics',
'ceval-high_school_geography',
'ceval-middle_school_politics',
'ceval-middle_school_geography',
'ceval-modern_chinese_history',
'ceval-ideological_and_moral_cultivation',
'ceval-logic',
'ceval-law',
'ceval-chinese_language_and_literature',
'ceval-art_studies',
'ceval-professional_tour_guide',
'ceval-legal_professional',
'ceval-high_school_chinese',
'ceval-high_school_history',
'ceval-middle_school_history',
'ceval-civil_servant',
'ceval-sports_science',
'ceval-plant_protection',
'ceval-basic_medicine',
'ceval-clinical_medicine',
'ceval-urban_and_rural_planner',
'ceval-accountant',
'ceval-fire_engineer',
'ceval-environmental_impact_assessment_engineer',
'ceval-tax_accountant',
'ceval-physician',
'ceval-humanities',
'ceval-stem',
'ceval-social-science',
'ceval-other',
'ceval-hard',
'ceval',
],
summary_groups=ceval_summary_groups,
)


@@ -0,0 +1,56 @@
# Contamination Evaluation Guidance
**Data contamination**, i.e., the presence of downstream test data in the pre-training data of LLMs, may inflate the performance observed on many downstream tasks (e.g., summarization, natural language inference, text classification).

To assess how contamination affects evaluation results, we employ [Contamination Detector](https://github.com/liyucheng09/Contamination_Detector) to generate contamination labels for the test examples.
## Introduction to [Detection Tools](https://github.com/liyucheng09/Contamination_Detector)
Contamination Detector identifies and analyzes potential contamination by verifying whether test examples are present on the Internet, without requiring access to the LLM's training data, so even small teams and individuals can run a robust contamination analysis.
### Method
- Use the Bing Search API to check whether verbatim test examples appear online, which likely indicates their inclusion in Common Crawl.
- Specifically, verify whether the pages containing those verbatim test examples were indexed in the 2017-2020 Common Crawl, by searching only the URLs rather than the full contents.
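
A hedged sketch of this presence check is given below; the endpoint, the Common Crawl collection name, and the response handling are illustrative assumptions, not the detector's exact implementation.

```python
import requests


def search_bing(query: str, api_key: str) -> list:
    """Return URLs of pages that Bing finds for a quoted, verbatim query."""
    resp = requests.get(
        'https://api.bing.microsoft.com/v7.0/search',
        headers={'Ocp-Apim-Subscription-Key': api_key},
        params={'q': f'"{query}"'},  # quote the query to favour verbatim matches
        timeout=30,
    )
    resp.raise_for_status()
    return [page['url'] for page in resp.json().get('webPages', {}).get('value', [])]


def indexed_in_common_crawl(url: str, collection: str = 'CC-MAIN-2019-35') -> bool:
    """Query the public CDX index for the URL only, without fetching page contents."""
    resp = requests.get(
        f'https://index.commoncrawl.org/{collection}-index',
        params={'url': url, 'output': 'json'},
        timeout=30,
    )
    return resp.status_code == 200 and bool(resp.text.strip())
```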
#### Construct queries
For example:

**Question**: The flaw in Andersons ACT theory was that some considered it \_\_\_\_.

**Choices**:

A: Only applicable to a motor system,

B: Untestable and thus, of uncertain scientific value,

C: Lacking in definition for its elements

D: Overly complex in explaining the operation of cognition,

**Answer**: B

**Query**: The flaw in Andersons ACT theory was that some considered it untestable and thus, of uncertain scientific value.
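
A minimal sketch of this query construction, assuming the blank is simply filled with the text of the gold choice (the detector's own preprocessing may differ):

```python
def build_query(question: str, choices: dict, answer: str) -> str:
    """Fill the blank with the gold choice to form a verbatim search query."""
    return question.replace('____', choices[answer].rstrip(','))


query = build_query(
    'The flaw in Andersons ACT theory was that some considered it ____.',
    {
        'A': 'only applicable to a motor system,',
        'B': 'untestable and thus, of uncertain scientific value,',
        'C': 'lacking in definition for its elements',
        'D': 'overly complex in explaining the operation of cognition,',
    },
    'B',
)
# -> 'The flaw in Andersons ACT theory was that some considered it
#     untestable and thus, of uncertain scientific value.'
```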
#### Improve Matching
To avoid potential false positives, the method is configured with two key settings:
- an order penalty (METEOR's gamma, set to 0.8) ensures that matches respect word order;
- matching is constrained to a window of up to 2x the query length, preventing partial or out-of-context matches.
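
A sketch of this matching step, assuming NLTK's METEOR implementation and plain whitespace tokenization (both are assumptions; the detector's own scoring code may differ):

```python
from nltk.translate.meteor_score import meteor_score  # requires the nltk 'wordnet' data


def best_window_match(query: str, page_text: str, gamma: float = 0.8) -> float:
    """Slide a window of at most 2x the query length over the page, keep the best METEOR score."""
    query_tokens = query.split()
    page_tokens = page_text.split()
    window = 2 * len(query_tokens)
    best = 0.0
    for start in range(0, max(1, len(page_tokens) - window + 1)):
        candidate = page_tokens[start:start + window]
        # gamma weights the fragmentation (word-order) penalty
        best = max(best, meteor_score([candidate], query_tokens, gamma=gamma))
    return best
```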
#### Contamination Type
- *input contamination*: only the question appears in the matched pages, but not the answer;
- *input-and-label contamination*: both the question and the answer appear in the matched pages.
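
For reference, the `CEvalDatasetClean` loader added in this commit reads the released annotations as a JSON dict keyed by `<subject>-<row index>` and takes the first element of each entry as the label; rows without an annotation fall back to `not labeled`. The entries below are illustrative values, not real annotations.

```python
annotations = {
    'computer_network-0': ['clean'],
    'computer_network-1': ['input contamination'],
    'computer_network-2': ['input-and-label contamination'],
}

row_id = 'computer_network-7'
label = annotations[row_id][0] if row_id in annotations else 'not labeled'
```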
## Data Preparation

To be completed.

## Evaluation Configuration

To be completed.
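
Until this section is filled in, a minimal configuration that mirrors the evaluation config added in this commit looks like the sketch below; the model import is only an example, and any HuggingFace model config shipped with OpenCompass can be substituted.

```python
from mmengine.config import read_base

with read_base():
    from .datasets.ceval.ceval_clean_ppl import ceval_datasets
    from .models.hf_llama.hf_llama2_7b import models as hf_llama2_7b_model
    from .summarizers.contamination import ceval_summarizer as summarizer

datasets = [*ceval_datasets]
models = [*hf_llama2_7b_model]
```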


@@ -68,6 +68,7 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
   advanced_guides/longeval.md
   advanced_guides/subjective_evaluation.md
   advanced_guides/circular_eval.md
   advanced_guides/contamination_eval.md

.. _Tools:

.. toctree::


@@ -0,0 +1,50 @@
# 污染评估指南
**数据污染**即下游任务的测试数据存在于大型语言模型LLMs的预训练数据中可能会夸大在许多下游任务例如摘要、自然语言推理、文本分类上观察到的LLM性能。
为了评估LLM在污染数据下的性能我们使用了[Contamination Detector](https://github.com/liyucheng09/Contamination_Detector)来生成污染标签。
## [检测工具](https://github.com/liyucheng09/Contamination_Detector)简介
污染检测器有助于在不需要访问LLM的训练数据的情况下基于互联网存在验证识别和分析此类潜在污染使得即使是小团队和个人也能进行强大的评估。
### 方法
- 使用必应搜索API检查逐字测试样例是否在线出现这可能表明其包含在Common Crawl中。
- 具体来说是通过仅搜索URL而不是完整内容来验证包含逐字测试样例的页面是否在2017-2020年的Common Crawl中被索引。
#### 构造查询
例如:
**问题**The flaw in Andersons ACT theory was that some considered it \_\_\_\_.
**选项**
A: Only applicable to a motor system,
B: Untestable and thus, of uncertain scientific value,
C: Lacking in definition for its elements
D: Overly complex in explaining the operation of cognition,
**答案**B
**查询**The flaw in Andersons ACT theory was that some considered it untestable and thus, of uncertain scientific value.
#### 提高匹配度
为避免可能的误报,该方法配置了两个关键设置:
- 用于METEOR的排序罚分gamma为0.8)确保匹配遵循序列;
- 匹配被限制在最多2倍查询长度的窗口内防止部分或脱离上下文的匹配。
#### 污染类型
- *input contamination*,其中只有问题出现在匹配页面中,但没有答案;
- *input-and-label contamination*,其中问题和答案都出现在匹配页面中。
## 数据准备
待完成
## 评估配置
待完成


@@ -68,6 +68,7 @@ OpenCompass 上手路线
   advanced_guides/longeval.md
   advanced_guides/subjective_evaluation.md
   advanced_guides/circular_eval.md
   advanced_guides/contamination_eval.md

.. _工具:

.. toctree::


@@ -1,4 +1,5 @@
import csv
import json
import os.path as osp
from datasets import Dataset, DatasetDict
@@ -26,3 +27,50 @@ class CEvalDataset(BaseDataset):
                    dataset.setdefault(split, []).append(item)
        dataset = {i: Dataset.from_list(dataset[i]) for i in dataset}
        return DatasetDict(dataset)


class CEvalDatasetClean(BaseDataset):

    # load the contamination annotations of CEval from
    # https://github.com/liyucheng09/Contamination_Detector
    @staticmethod
    def load_contamination_annotations(path, split='val'):
        import requests

        assert split == 'val', 'Now we only have annotations for val set'
        annotation_cache_path = osp.join(
            path, split, 'ceval_contamination_annotations.json')
        if osp.exists(annotation_cache_path):
            with open(annotation_cache_path, 'r') as f:
                annotations = json.load(f)
            return annotations
        link_of_annotations = 'https://github.com/liyucheng09/Contamination_Detector/releases/download/v0.1.1rc/ceval_annotations.json'  # noqa
        annotations = json.loads(requests.get(link_of_annotations).text)
        with open(annotation_cache_path, 'w') as f:
            json.dump(annotations, f)
        return annotations

    @staticmethod
    def load(path: str, name: str):
        dataset = {}
        for split in ['dev', 'val', 'test']:
            if split == 'val':
                annotations = CEvalDatasetClean.load_contamination_annotations(
                    path, split)
            filename = osp.join(path, split, f'{name}_{split}.csv')
            with open(filename, encoding='utf-8') as f:
                reader = csv.reader(f)
                header = next(reader)
                for row_index, row in enumerate(reader):
                    item = dict(zip(header, row))
                    item.setdefault('explanation', '')
                    item.setdefault('answer', '')
                    if split == 'val':
                        row_id = f'{name}-{row_index}'
                        if row_id in annotations:
                            item['is_clean'] = annotations[row_id][0]
                        else:
                            item['is_clean'] = 'not labeled'
                    dataset.setdefault(split, []).append(item)
        dataset = {i: Dataset.from_list(dataset[i]) for i in dataset}
        return DatasetDict(dataset)


@@ -4,6 +4,7 @@ from typing import List
import evaluate
import numpy as np
from datasets import Dataset
from opencompass.registry import ICL_EVALUATORS
@@ -132,6 +133,53 @@ class AccEvaluator(HuggingfaceEvaluator):
        return scores
@ICL_EVALUATORS.register_module()
class AccContaminationEvaluator(AccEvaluator):
    """Accuracy evaluator that reports separate accuracy scores for the clean,
    input-contaminated, and input-and-label-contaminated subsets."""

    def score(self, predictions: List, references: List,
              test_set: Dataset) -> dict:
        # group the predictions and references by their contamination status
        clean_predictions, clean_references = [], []
        input_contaminated_predictions, input_contaminated_references = [], []
        input_and_label_contaminated_predictions, \
            input_and_label_contaminated_references = [], []
        for pred, ref, is_clean in zip(predictions, references,
                                       test_set['is_clean']):
            if is_clean == 'clean':
                clean_predictions.append(pred)
                clean_references.append(ref)
            elif is_clean == 'input contamination':
                input_contaminated_predictions.append(pred)
                input_contaminated_references.append(ref)
            elif is_clean == 'input-and-label contamination':
                input_and_label_contaminated_predictions.append(pred)
                input_and_label_contaminated_references.append(ref)
        clean_results = super().score(clean_predictions, clean_references)
        input_contaminated_results = super().score(
            input_contaminated_predictions, input_contaminated_references)
        input_and_label_contaminated_results = super().score(
            input_and_label_contaminated_predictions,
            input_and_label_contaminated_references)

        # rename the result keys by appending ' - clean', ' - input contaminated',
        # or ' - input-and-label contaminated' to each metric name
        clean_results = {f'{k} - clean': v for k, v in clean_results.items()}
        input_contaminated_results = {
            f'{k} - input contaminated': v
            for k, v in input_contaminated_results.items()
        }
        input_and_label_contaminated_results = {
            f'{k} - input-and-label contaminated': v
            for k, v in input_and_label_contaminated_results.items()
        }
        return {
            **clean_results,
            **input_contaminated_results,
            **input_and_label_contaminated_results
        }
@ICL_EVALUATORS.register_module()
class RougeEvaluator(HuggingfaceEvaluator):
"""Rouge evaluator.


@@ -126,19 +126,37 @@ class DefaultSummarizer:
        summary_groups = self.summary_groups
        for sg in summary_groups:
            for model_abbr in self.model_abbrs:
-                available_count = sum(dataset_abbr in parsed_results[model_abbr] for dataset_abbr in sg['subsets'])
-                if available_count == 0:
+                available_metrics, missing_metrics = [], []
+                for i in sg['subsets']:
+                    if isinstance(i, (list, tuple)):
+                        if i[0] in parsed_results[model_abbr] and i[1] in parsed_results[model_abbr][i[0]]:
+                            available_metrics.append(i)
+                        else:
+                            missing_metrics.append(i)
+                    else:
+                        if i in parsed_results[model_abbr]:
+                            available_metrics.append(i)
+                        else:
+                            missing_metrics.append(i)
+                if len(available_metrics) == 0:
                    continue
-                if available_count != len(sg['subsets']):
-                    raw_results[model_abbr][sg['name']] = {'error': 'missing datasets: {}'.format(set(sg['subsets']) - set(parsed_results[model_abbr].keys()))}
+                if len(missing_metrics) != 0:
+                    raw_results[model_abbr][sg['name']] = {'error': 'missing metrics: {}'.format(missing_metrics)}
                    continue
-                if sg.get('std', False):
-                    default_metric = 'standard_deviation'
-                elif sg.get('weights', []):
-                    default_metric = 'weighted_average'
+                if 'metric' in sg:
+                    default_metric = sg['metric']
+                    need_smart_metric = False
                else:
-                    default_metric = 'naive_average'
+                    need_smart_metric = True
+                    if sg.get('std', False):
+                        default_metric = 'standard_deviation'
+                    elif sg.get('weights', []):
+                        default_metric = 'weighted_average'
+                    else:
+                        default_metric = 'naive_average'
                scores, eval_modes, group_metrics = {}, [], None
                if any(isinstance(dataset_abbr, (list, tuple)) for dataset_abbr in sg['subsets']) and \
                    any(isinstance(dataset_abbr, str) for dataset_abbr in sg['subsets']):
@@ -151,7 +169,7 @@ class DefaultSummarizer:
                        eval_modes.append(dataset_eval_mode.get(dataset_abbr, 'unknown'))
                else:
                    group_metrics = list(functools.reduce(lambda a, b: a & b, [set(dataset_metrics[dataset_abbr]) for dataset_abbr in sg['subsets']]))
-                    if len(group_metrics) > 1:
+                    if need_smart_metric and len(group_metrics) > 1:
                        for metric in group_metrics:
                            for dataset_abbr in sg['subsets']:
                                scores.setdefault(metric, {})[dataset_abbr] = parsed_results[model_abbr][dataset_abbr][metric]
@@ -163,15 +181,16 @@ class DefaultSummarizer:
                            scores.setdefault(default_metric, {})[dataset_abbr] = parsed_results[model_abbr][dataset_abbr][metric]
                        eval_modes.append(dataset_eval_mode.get(dataset_abbr, 'unknown'))

-                result = {}
+                result = parsed_results[model_abbr].get(sg['name'], {})
                for metric in scores:
                    if default_metric == 'standard_deviation':
                        avg = sum(scores[metric].values()) / len(scores[metric])
                        variance = sum((k - avg) ** 2 for k in scores[metric]) / len(scores[metric])
                        scores[metric] = result[metric] = math.sqrt(variance)
                    else:
-                        if default_metric == 'weighted_average':
-                            numerator = sum(scores[metric][k] * sg['weights'][k] for k in sg['weights'])
+                        if sg.get('weights', []):
+                            # check sg['weights'][k] != 0 in case of scores[metric][k] is NaN
+                            numerator = sum(scores[metric][k] * sg['weights'][k] for k in sg['weights'] if sg['weights'][k] != 0)
                            denominator = sum(sg['weights'].values())
                        else:
                            numerator = sum(scores[metric].values())
@@ -182,7 +201,7 @@ class DefaultSummarizer:
                # add to global results
                raw_results[model_abbr][sg['name']] = scores
-                parsed_results[model_abbr][sg['name']]= result
+                parsed_results[model_abbr][sg['name']] = result
                dataset_metrics[sg['name']] = group_metrics
                dataset_eval_mode[sg['name']] = eval_mode