[Feature] Contamination analysis for MMLU, Hellaswag, and ARC_c (#699)

* Contamination analysis for ARC_c, mmlu, and Hellaswag

* update `eval_contamination.py`

* update `contamination.py` summarizer

* fix `eval_contamination.py`

* add mmlu groups for contamination analysis
liyucheng09 2024-01-08 07:51:48 +00:00 committed by GitHub
parent ba1b684fec
commit 0b2863039e
8 changed files with 476 additions and 67 deletions

View File: configs/datasets/ARC_c/ARC_c_clean_ppl.py

@@ -0,0 +1,54 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccContaminationEvaluator
from opencompass.datasets import ARCDatasetClean as ARCDataset
ARC_c_reader_cfg = dict(
input_columns=['question', 'textA', 'textB', 'textC', 'textD'],
output_column='answerKey')
ARC_c_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template={
"A":
dict(
round=[
dict(role="HUMAN", prompt="Question: {question}\nAnswer: "),
dict(role="BOT", prompt="{textA}")
], ),
"B":
dict(
round=[
dict(role="HUMAN", prompt="Question: {question}\nAnswer: "),
dict(role="BOT", prompt="{textB}")
], ),
"C":
dict(
round=[
dict(role="HUMAN", prompt="Question: {question}\nAnswer: "),
dict(role="BOT", prompt="{textC}")
], ),
"D":
dict(
round=[
dict(role="HUMAN", prompt="Question: {question}\nAnswer: "),
dict(role="BOT", prompt="{textD}")
], ),
}),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer))
ARC_c_eval_cfg = dict(evaluator=dict(type=AccContaminationEvaluator),
analyze_contamination=True)
ARC_c_datasets = [
dict(
type=ARCDataset,
abbr='ARC-c-test',
path='./data/ARC/ARC-c/ARC-Challenge-Test.jsonl',
reader_cfg=ARC_c_reader_cfg,
infer_cfg=ARC_c_infer_cfg,
eval_cfg=ARC_c_eval_cfg)
]
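For intuition, here is a simplified sketch of what `PPLInferencer` does with the four templates above: each option text is scored as a continuation of the question prompt, and the option whose tokens get the lowest language-model loss becomes the prediction. (The real inferencer scores the full templated sequence; the model name and helper below are illustrative stand-ins, not part of this commit.)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")          # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

def option_loss(question: str, option_text: str) -> float:
    # Mirror the "A"/"B"/"C"/"D" templates: question as prompt, option as the
    # scored continuation.
    prompt = f"Question: {question}\nAnswer: "
    enc = tok(prompt + option_text, return_tensors="pt")
    prompt_len = tok(prompt, return_tensors="pt").input_ids.shape[1]
    labels = enc.input_ids.clone()
    labels[:, :prompt_len] = -100                    # score only the option tokens
    with torch.no_grad():
        return model(**enc, labels=labels).loss.item()

question = "Which object is most likely to build up static charge when rubbed?"
options = {"A": "a dry cloth", "B": "a wet cloth",
           "C": "a metal rod", "D": "a damp sponge"}
print(min(options, key=lambda k: option_loss(question, options[k])))
```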

View File: configs/datasets/hellaswag/hellaswag_clean_ppl.py

@@ -0,0 +1,35 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccContaminationEvaluator
from opencompass.datasets import hellaswagDatasetClean as hellaswagDataset
hellaswag_reader_cfg = dict(
input_columns=['ctx', 'A', 'B', 'C', 'D'],
output_column='label')
hellaswag_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template={
i: dict(round=[
dict(role="HUMAN", prompt="{ctx}"),
dict(role="BOT", prompt=f"{{{chr(ord('A') + i)}}}"),
])
for i in range(4)
}),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer))
hellaswag_eval_cfg = dict(evaluator=dict(type=AccContaminationEvaluator),
analyze_contamination=True)
hellaswag_datasets = [
dict(
abbr='hellaswag',
type=hellaswagDataset,
path='./data/hellaswag/hellaswag.jsonl',
reader_cfg=hellaswag_reader_cfg,
infer_cfg=hellaswag_infer_cfg,
eval_cfg=hellaswag_eval_cfg)
]
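The nested braces in the comprehension's BOT prompt are easy to misread; a quick standalone check confirms that each of the four templates scores one of the `A`–`D` columns:

```python
for i in range(4):
    print(i, f"{{{chr(ord('A') + i)}}}")
# 0 {A}
# 1 {B}
# 2 {C}
# 3 {D}
```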

View File: configs/datasets/mmlu/mmlu_clean_ppl.py

@@ -0,0 +1,114 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import FixKRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccContaminationEvaluator
from opencompass.datasets import MMLUDatasetClean as MMLUDataset
# None of the MMLU datasets on HuggingFace is parsed correctly, so we use our own dataset reader
# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar
mmlu_reader_cfg = dict(
input_columns=["input", "A", "B", "C", "D"],
output_column="target",
train_split='dev')
mmlu_all_sets = [
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_physics",
"electrical_engineering",
"astronomy",
"anatomy",
"abstract_algebra",
"machine_learning",
"clinical_knowledge",
"global_facts",
"management",
"nutrition",
"marketing",
"professional_accounting",
"high_school_geography",
"international_law",
"moral_scenarios",
"computer_security",
"high_school_microeconomics",
"professional_law",
"medical_genetics",
"professional_psychology",
"jurisprudence",
"world_religions",
"philosophy",
"virology",
"high_school_chemistry",
"public_relations",
"high_school_macroeconomics",
"human_sexuality",
"elementary_mathematics",
"high_school_physics",
"high_school_computer_science",
"high_school_european_history",
"business_ethics",
"moral_disputes",
"high_school_statistics",
"miscellaneous",
"formal_logic",
"high_school_government_and_politics",
"prehistory",
"security_studies",
"high_school_biology",
"logical_fallacies",
"high_school_world_history",
"professional_medicine",
"high_school_mathematics",
"college_medicine",
"high_school_us_history",
"sociology",
"econometrics",
"high_school_psychology",
"human_aging",
"us_foreign_policy",
"conceptual_physics",
]
mmlu_datasets = []
for _name in mmlu_all_sets:
_hint = f'The following are multiple choice questions (with answers) about {_name.replace("_", " ")}.\n\n'
mmlu_infer_cfg = dict(
ice_template=dict(
type=PromptTemplate,
template={
opt:
f"{{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: {opt}\n"
for opt in ["A", "B", "C", "D"]
},
),
prompt_template=dict(
type=PromptTemplate,
template={
opt:
f"{_hint}</E>{{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: {opt}"
for opt in ["A", "B", "C", "D"]
},
ice_token="</E>",
),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=PPLInferencer),
)
mmlu_eval_cfg = dict(evaluator=dict(type=AccContaminationEvaluator),
analyze_contamination=True)
mmlu_datasets.append(
dict(
abbr=f"lukaemon_mmlu_{_name}",
type=MMLUDataset,
path="./data/mmlu/",
name=_name,
reader_cfg=mmlu_reader_cfg,
infer_cfg=mmlu_infer_cfg,
eval_cfg=mmlu_eval_cfg,
))
del _name, _hint
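To make the `ice_token` mechanics concrete: `FixKRetriever` picks the five fixed dev-split examples, each is rendered with the `ice_template` (answer included), and their concatenation replaces `</E>` in the `prompt_template`. A sketch of that assembly, with the per-option templates simplified to single strings and made-up rows standing in for real data:

```python
hint = ("The following are multiple choice questions (with answers) "
        "about anatomy.\n\n")
ice_template = "{input}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer: {target}\n"
prompt_template = hint + "</E>{input}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:"

# Toy stand-ins for the five fixed dev examples (fix_id_list=[0, 1, 2, 3, 4]).
dev_rows = [{"input": f"Dev question {i}?", "A": "a", "B": "b", "C": "c",
             "D": "d", "target": "A"} for i in range(5)]
ice = "".join(ice_template.format(**row) for row in dev_rows)

test_row = {"input": "Which nerve innervates the diaphragm?",
            "A": "Phrenic", "B": "Vagus", "C": "Ulnar", "D": "Sciatic"}
print(prompt_template.replace("</E>", ice).format(**test_row))
```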

View File: configs/eval_contamination.py

@@ -2,10 +2,13 @@ from mmengine.config import read_base
with read_base():
from .datasets.ceval.ceval_clean_ppl import ceval_datasets
from .datasets.mmlu.mmlu_clean_ppl import mmlu_datasets
from .datasets.hellaswag.hellaswag_clean_ppl import hellaswag_datasets
from .datasets.ARC_c.ARC_c_clean_ppl import ARC_c_datasets
from .models.yi.hf_yi_6b import models as hf_yi_6b_model
from .models.qwen.hf_qwen_7b import models as hf_qwen_7b_model
from .models.hf_llama.hf_llama2_7b import models as hf_llama2_7b_model
from .summarizers.contamination import summarizer
datasets = [*ceval_datasets, *mmlu_datasets, *hellaswag_datasets, *ARC_c_datasets]
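With all four datasets wired in, the analysis runs through OpenCompass's standard entry point, e.g. `python run.py configs/eval_contamination.py`.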

View File: configs/summarizers/contamination.py

@@ -59,6 +59,69 @@ ceval_category_weights = {
'physician': {'accuracy - clean': 24, 'accuracy - input contaminated': 1, 'accuracy - input-and-label contaminated': 24, 'accuracy - not labeled': 0},
}
mmlu_category_weights = {
"business_ethics": {"accuracy - clean": 44, "accuracy - input contaminated": 16, "accuracy - input-and-label contaminated": 38, "accuracy - not labeled": 1},
"security_studies": {"accuracy - clean": 188, "accuracy - input contaminated": 9, "accuracy - input-and-label contaminated": 47, "accuracy - not labeled": 0},
"high_school_us_history": {"accuracy - clean": 42, "accuracy - input contaminated": 0, "accuracy - input-and-label contaminated": 0, "accuracy - not labeled": 161},
"moral_disputes": {"accuracy - clean": 105, "accuracy - input contaminated": 13, "accuracy - input-and-label contaminated": 168, "accuracy - not labeled": 59},
"philosophy": {"accuracy - clean": 81, "accuracy - input contaminated": 11, "accuracy - input-and-label contaminated": 187, "accuracy - not labeled": 31},
"public_relations": {"accuracy - clean": 75, "accuracy - input contaminated": 8, "accuracy - input-and-label contaminated": 26, "accuracy - not labeled": 0},
"high_school_microeconomics": {"accuracy - clean": 82, "accuracy - input contaminated": 9, "accuracy - input-and-label contaminated": 146, "accuracy - not labeled": 0},
"human_sexuality": {"accuracy - clean": 108, "accuracy - input contaminated": 3, "accuracy - input-and-label contaminated": 15, "accuracy - not labeled": 4},
"professional_accounting": {"accuracy - clean": 88, "accuracy - input contaminated": 40, "accuracy - input-and-label contaminated": 152, "accuracy - not labeled": 1},
"high_school_government_and_politics": {"accuracy - clean": 104, "accuracy - input contaminated": 6, "accuracy - input-and-label contaminated": 82, "accuracy - not labeled": 0},
"sociology": {"accuracy - clean": 105, "accuracy - input contaminated": 4, "accuracy - input-and-label contaminated": 91, "accuracy - not labeled": 0},
"conceptual_physics": {"accuracy - clean": 79, "accuracy - input contaminated": 8, "accuracy - input-and-label contaminated": 147, "accuracy - not labeled": 0},
"human_aging": {"accuracy - clean": 208, "accuracy - input contaminated": 1, "accuracy - input-and-label contaminated": 13, "accuracy - not labeled": 0},
"high_school_psychology": {"accuracy - clean": 108, "accuracy - input contaminated": 26, "accuracy - input-and-label contaminated": 162, "accuracy - not labeled": 248},
"jurisprudence": {"accuracy - clean": 59, "accuracy - input contaminated": 5, "accuracy - input-and-label contaminated": 43, "accuracy - not labeled": 0},
"moral_scenarios": {"accuracy - clean": 320, "accuracy - input contaminated": 0, "accuracy - input-and-label contaminated": 0, "accuracy - not labeled": 574},
"college_medicine": {"accuracy - clean": 107, "accuracy - input contaminated": 16, "accuracy - input-and-label contaminated": 44, "accuracy - not labeled": 5},
"high_school_world_history": {"accuracy - clean": 61, "accuracy - input contaminated": 2, "accuracy - input-and-label contaminated": 0, "accuracy - not labeled": 173},
"virology": {"accuracy - clean": 104, "accuracy - input contaminated": 3, "accuracy - input-and-label contaminated": 58, "accuracy - not labeled": 0},
"high_school_statistics": {"accuracy - clean": 96, "accuracy - input contaminated": 43, "accuracy - input-and-label contaminated": 76, "accuracy - not labeled": 0},
"nutrition": {"accuracy - clean": 172, "accuracy - input contaminated": 11, "accuracy - input-and-label contaminated": 98, "accuracy - not labeled": 24},
"abstract_algebra": {"accuracy - clean": 84, "accuracy - input contaminated": 8, "accuracy - input-and-label contaminated": 7, "accuracy - not labeled": 0},
"high_school_geography": {"accuracy - clean": 91, "accuracy - input contaminated": 1, "accuracy - input-and-label contaminated": 105, "accuracy - not labeled": 0},
"econometrics": {"accuracy - clean": 62, "accuracy - input contaminated": 13, "accuracy - input-and-label contaminated": 38, "accuracy - not labeled": 0},
"marketing": {"accuracy - clean": 115, "accuracy - input contaminated": 15, "accuracy - input-and-label contaminated": 101, "accuracy - not labeled": 2},
"high_school_chemistry": {"accuracy - clean": 108, "accuracy - input contaminated": 25, "accuracy - input-and-label contaminated": 69, "accuracy - not labeled": 0},
"prehistory": {"accuracy - clean": 154, "accuracy - input contaminated": 5, "accuracy - input-and-label contaminated": 107, "accuracy - not labeled": 57},
"college_physics": {"accuracy - clean": 25, "accuracy - input contaminated": 20, "accuracy - input-and-label contaminated": 57, "accuracy - not labeled": 0},
"management": {"accuracy - clean": 35, "accuracy - input contaminated": 5, "accuracy - input-and-label contaminated": 62, "accuracy - not labeled": 0},
"college_biology": {"accuracy - clean": 91, "accuracy - input contaminated": 12, "accuracy - input-and-label contaminated": 40, "accuracy - not labeled": 0},
"high_school_biology": {"accuracy - clean": 128, "accuracy - input contaminated": 17, "accuracy - input-and-label contaminated": 135, "accuracy - not labeled": 29},
"high_school_physics": {"accuracy - clean": 42, "accuracy - input contaminated": 28, "accuracy - input-and-label contaminated": 80, "accuracy - not labeled": 0},
"logical_fallacies": {"accuracy - clean": 133, "accuracy - input contaminated": 5, "accuracy - input-and-label contaminated": 24, "accuracy - not labeled": 0},
"medical_genetics": {"accuracy - clean": 49, "accuracy - input contaminated": 6, "accuracy - input-and-label contaminated": 43, "accuracy - not labeled": 1},
"machine_learning": {"accuracy - clean": 71, "accuracy - input contaminated": 8, "accuracy - input-and-label contaminated": 32, "accuracy - not labeled": 0},
"professional_law": {"accuracy - clean": 401, "accuracy - input contaminated": 8, "accuracy - input-and-label contaminated": 5, "accuracy - not labeled": 1119},
"professional_psychology": {"accuracy - clean": 265, "accuracy - input contaminated": 9, "accuracy - input-and-label contaminated": 27, "accuracy - not labeled": 310},
"global_facts": {"accuracy - clean": 89, "accuracy - input contaminated": 5, "accuracy - input-and-label contaminated": 5, "accuracy - not labeled": 0},
"us_foreign_policy": {"accuracy - clean": 71, "accuracy - input contaminated": 3, "accuracy - input-and-label contaminated": 25, "accuracy - not labeled": 0},
"international_law": {"accuracy - clean": 73, "accuracy - input contaminated": 1, "accuracy - input-and-label contaminated": 46, "accuracy - not labeled": 0},
"clinical_knowledge": {"accuracy - clean": 172, "accuracy - input contaminated": 6, "accuracy - input-and-label contaminated": 86, "accuracy - not labeled": 0},
"high_school_mathematics": {"accuracy - clean": 178, "accuracy - input contaminated": 59, "accuracy - input-and-label contaminated": 32, "accuracy - not labeled": 0},
"high_school_computer_science": {"accuracy - clean": 62, "accuracy - input contaminated": 7, "accuracy - input-and-label contaminated": 28, "accuracy - not labeled": 2},
"college_computer_science": {"accuracy - clean": 68, "accuracy - input contaminated": 15, "accuracy - input-and-label contaminated": 15, "accuracy - not labeled": 1},
"electrical_engineering": {"accuracy - clean": 75, "accuracy - input contaminated": 8, "accuracy - input-and-label contaminated": 61, "accuracy - not labeled": 0},
"college_mathematics": {"accuracy - clean": 61, "accuracy - input contaminated": 13, "accuracy - input-and-label contaminated": 26, "accuracy - not labeled": 0},
"computer_security": {"accuracy - clean": 55, "accuracy - input contaminated": 8, "accuracy - input-and-label contaminated": 36, "accuracy - not labeled": 0},
"high_school_macroeconomics": {"accuracy - clean": 102, "accuracy - input contaminated": 14, "accuracy - input-and-label contaminated": 173, "accuracy - not labeled": 100},
"astronomy": {"accuracy - clean": 112, "accuracy - input contaminated": 4, "accuracy - input-and-label contaminated": 35, "accuracy - not labeled": 0},
"college_chemistry": {"accuracy - clean": 46, "accuracy - input contaminated": 19, "accuracy - input-and-label contaminated": 34, "accuracy - not labeled": 0},
"high_school_european_history": {"accuracy - clean": 41, "accuracy - input contaminated": 0, "accuracy - input-and-label contaminated": 0, "accuracy - not labeled": 123},
"miscellaneous": {"accuracy - clean": 256, "accuracy - input contaminated": 9, "accuracy - input-and-label contaminated": 40, "accuracy - not labeled": 477},
"formal_logic": {"accuracy - clean": 92, "accuracy - input contaminated": 12, "accuracy - input-and-label contaminated": 21, "accuracy - not labeled": 0},
"elementary_mathematics": {"accuracy - clean": 155, "accuracy - input contaminated": 31, "accuracy - input-and-label contaminated": 103, "accuracy - not labeled": 88},
"world_religions": {"accuracy - clean": 130, "accuracy - input contaminated": 4, "accuracy - input-and-label contaminated": 36, "accuracy - not labeled": 0},
"professional_medicine": {"accuracy - clean": 191, "accuracy - input contaminated": 43, "accuracy - input-and-label contaminated": 1, "accuracy - not labeled": 36},
"anatomy": {"accuracy - clean": 52, "accuracy - input contaminated": 6, "accuracy - input-and-label contaminated": 76, "accuracy - not labeled": 0},
}
ARC_weights = {'accuracy - clean': 836, 'accuracy - input contaminated': 53, 'accuracy - input-and-label contaminated': 283, 'accuracy - not labeled': 0}
hellaswag_weights = {'accuracy - clean': 5169, 'accuracy - input contaminated': 37, 'accuracy - input-and-label contaminated': 673, 'accuracy - not labeled': 4163}
ceval_stem = ['computer_network', 'operating_system', 'computer_architecture', 'college_programming', 'college_physics', 'college_chemistry', 'advanced_mathematics', 'probability_and_statistics', 'discrete_mathematics', 'electrical_engineer', 'metrology_engineer', 'high_school_mathematics', 'high_school_physics', 'high_school_chemistry', 'high_school_biology', 'middle_school_mathematics', 'middle_school_biology', 'middle_school_physics', 'middle_school_chemistry', 'veterinary_medicine']
ceval_social_science = ['college_economics', 'business_administration', 'marxism', 'mao_zedong_thought', 'education_science', 'teacher_qualification', 'high_school_politics', 'high_school_geography', 'middle_school_politics', 'middle_school_geography']
@@ -67,7 +130,13 @@ ceval_other = ['civil_servant', 'sports_science', 'plant_protection', 'basic_med
ceval_hard = ['advanced_mathematics', 'discrete_mathematics', 'probability_and_statistics', 'college_chemistry', 'college_physics', 'high_school_mathematics', 'high_school_chemistry', 'high_school_physics']
ceval_all = ceval_stem + ceval_social_science + ceval_humanities + ceval_other
_mmlu_humanities = ['formal_logic', 'high_school_european_history', 'high_school_us_history', 'high_school_world_history', 'international_law', 'jurisprudence', 'logical_fallacies', 'moral_disputes', 'moral_scenarios', 'philosophy', 'prehistory', 'professional_law', 'world_religions']
_mmlu_stem = ['abstract_algebra', 'anatomy', 'astronomy', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_physics', 'computer_security', 'conceptual_physics', 'electrical_engineering', 'elementary_mathematics', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_mathematics', 'high_school_physics', 'high_school_statistics', 'machine_learning']
_mmlu_social_science = ['econometrics', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_microeconomics', 'high_school_psychology', 'human_sexuality', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy']
_mmlu_other = ['business_ethics', 'clinical_knowledge', 'college_medicine', 'global_facts', 'human_aging', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'nutrition', 'professional_accounting', 'professional_medicine', 'virology']
_mmlu_all = _mmlu_humanities + _mmlu_stem + _mmlu_social_science + _mmlu_other
ceval_name_and_subsets = [
('ceval', ceval_all),
('ceval-stem', ceval_stem),
('ceval-social-science', ceval_social_science),
@@ -76,12 +145,20 @@ name_and_subsets = [
('ceval-hard', ceval_hard)
]
mmlu_name_and_subsets = [
('mmlu', _mmlu_all),
('mmlu-humanities', _mmlu_humanities),
('mmlu-stem', _mmlu_stem),
('mmlu-social-science', _mmlu_social_science),
('mmlu-other', _mmlu_other)
]
summary_groups = []
for metric_name in ['accuracy - clean', 'accuracy - input contaminated', 'accuracy - input-and-label contaminated']:
for dataset_abbr, subsets in ceval_name_and_subsets:
weights = {f'ceval-{i}': ceval_category_weights[i][metric_name] for i in subsets}
subsets = [[f'ceval-{i}', metric_name] for i in subsets]
summary_groups.append(
{
'name': dataset_abbr,
'subsets': subsets,
@@ -90,68 +167,39 @@ for metric_name in ['accuracy - clean', 'accuracy - input contaminated', 'accura
}
)
for dataset_abbr, subsets in mmlu_name_and_subsets:
weights = {f'lukaemon_mmlu_{i}': mmlu_category_weights[i][metric_name] for i in subsets}
subsets = [[f'lukaemon_mmlu_{i}', metric_name] for i in subsets]
summary_groups.append(
{
'name': dataset_abbr,
'subsets': subsets,
'metric': metric_name,
'weights': weights,
}
)
summary_groups.append(
{
'name': 'hellaswag',
'subsets': [['hellaswag', metric_name]],
'metric': metric_name,
'weights': {'hellaswag': hellaswag_weights[metric_name]}
}
)
summary_groups.append(
{
'name': 'ARC-c-test',
'subsets': [['ARC-c-test', metric_name]],
'metric': metric_name,
'weights': {'ARC-c-test': ARC_weights[metric_name]}
}
)
summarizer = dict(
type=CircularSummarizer,
metric_types=['accuracy - clean', 'accuracy - input contaminated', 'accuracy - input-and-label contaminated'],
dataset_abbrs = ['ceval', 'ceval-stem', 'ceval-social-science', 'ceval-humanities', 'ceval-other', 'ceval-hard', 'mmlu', 'mmlu-humanities', 'mmlu-stem', 'mmlu-social-science', 'mmlu-other', 'hellaswag', 'ARC-c-test'],
summary_groups=summary_groups,
)
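Each group above yields a question-count-weighted mean rather than a plain average over subsets, so heavily annotated subsets such as professional_law dominate the MMLU aggregate. A minimal sketch of that computation, with made-up subset scores (the weights are real values from `mmlu_category_weights` above):

```python
def weighted_accuracy(scores: dict, weights: dict) -> float:
    """scores: subset abbr -> accuracy for one metric (e.g. 'accuracy - clean');
    weights: subset abbr -> number of questions in that contamination bucket."""
    total = sum(weights[k] for k in scores)
    return sum(scores[k] * weights[k] for k in scores) / total

scores = {'lukaemon_mmlu_anatomy': 55.0, 'lukaemon_mmlu_virology': 48.0}
weights = {'lukaemon_mmlu_anatomy': 52, 'lukaemon_mmlu_virology': 104}
print(weighted_accuracy(scores, weights))  # 50.33..., not the plain mean 51.5
```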

View File: opencompass/datasets/arc.py

@@ -1,4 +1,5 @@
import json
import os.path as osp
from datasets import Dataset
@@ -30,3 +31,54 @@ class ARCDataset(BaseDataset):
'textD': question['choices'][3]['text'],
})
return Dataset.from_list(rows)
class ARCDatasetClean(BaseDataset):
    # load the contamination annotations of ARC from
    # https://github.com/liyucheng09/Contamination_Detector
    @staticmethod
    def load_contamination_annotations(path, split='test'):
        import requests
        assert split == 'test', 'We only have test set annotations for ARC'
annotation_cache_path = osp.join(
path, f'ARC_c_{split}_contamination_annotations.json')
if osp.exists(annotation_cache_path):
with open(annotation_cache_path, 'r') as f:
annotations = json.load(f)
return annotations
link_of_annotations = 'https://github.com/liyucheng09/Contamination_Detector/releases/download/v0.1.1rc/ARC_annotations.json' # noqa
annotations = json.loads(requests.get(link_of_annotations).text)
with open(annotation_cache_path, 'w') as f:
json.dump(annotations, f)
return annotations
@staticmethod
def load(path: str):
annotations = ARCDatasetClean.load_contamination_annotations(
osp.dirname(path), 'test')
with open(path, 'r', errors='ignore') as in_f:
rows = []
for line in in_f:
item = json.loads(line.strip())
id_ = item['id']
question = item['question']
if id_ in annotations:
is_clean = annotations[id_][0]
else:
is_clean = 'not labeled'
if len(question['choices']) != 4:
continue
labels = [c['label'] for c in question['choices']]
answerKey = 'ABCD'[labels.index(item['answerKey'])]
rows.append({
'question': question['stem'],
'answerKey': answerKey,
'textA': question['choices'][0]['text'],
'textB': question['choices'][1]['text'],
'textC': question['choices'][2]['text'],
'textD': question['choices'][3]['text'],
'is_clean': is_clean,
})
return Dataset.from_list(rows)
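A quick sanity check of the loader and the cached annotation download is to count the contamination buckets, which should line up with `ARC_weights` in the summarizer. A sketch, assuming the data path from the config above and network access on the first run:

```python
from collections import Counter

from opencompass.datasets import ARCDatasetClean

ds = ARCDatasetClean.load('./data/ARC/ARC-c/ARC-Challenge-Test.jsonl')
print(Counter(ds['is_clean']))  # bucket names as used by AccContaminationEvaluator
```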

View File: opencompass/datasets/hellaswag.py

@@ -1,4 +1,5 @@
import json
import os.path as osp
from datasets import Dataset
@@ -68,3 +69,50 @@ class hellaswagDataset_V3(BaseDataset):
})
dataset = Dataset.from_list(dataset)
return dataset
class hellaswagDatasetClean(BaseDataset):
    # load the contamination annotations of HellaSwag from
    # https://github.com/liyucheng09/Contamination_Detector
    @staticmethod
    def load_contamination_annotations(path, split='val'):
        import requests
        assert split == 'val', 'We only use the val set of HellaSwag'
annotation_cache_path = osp.join(
path, f'hellaswag_{split}_contamination_annotations.json')
if osp.exists(annotation_cache_path):
with open(annotation_cache_path, 'r') as f:
annotations = json.load(f)
return annotations
link_of_annotations = 'https://github.com/liyucheng09/Contamination_Detector/releases/download/v0.1.1rc2/hellaswag_annotations_with_line_index.json' # noqa
annotations = json.loads(requests.get(link_of_annotations).text)
with open(annotation_cache_path, 'w') as f:
json.dump(annotations, f)
return annotations
@staticmethod
def load(path):
dataset = []
annotations = hellaswagDatasetClean.load_contamination_annotations(
osp.dirname(path))
with open(path, 'r', encoding='utf-8') as f:
            for row_index, line in enumerate(f):
                data = json.loads(line)
                row_id = str(row_index)
                if row_id in annotations:
                    is_clean = annotations[row_id][0]
else:
is_clean = 'not labeled'
dataset.append({
'ctx': data['query'].split(': ', 2)[-1],
'A': data['choices'][0],
'B': data['choices'][1],
'C': data['choices'][2],
'D': data['choices'][3],
'label': data['gold'],
'is_clean': is_clean,
})
dataset = Dataset.from_list(dataset)
return dataset
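The `ctx` field above strips a prefix from the stored query string. Assuming the query has the form '<activity label>: <context>' (an assumption about the JSONL layout, not stated in the diff), `split(': ', 2)[-1]` keeps only the context:

```python
query = "Roof shingle removal: A man is sitting on a roof. He"
print(query.split(': ', 2)[-1])  # -> 'A man is sitting on a roof. He'
```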

View File: opencompass/datasets/mmlu.py

@@ -1,4 +1,5 @@
import csv
import json
import os.path as osp
from datasets import Dataset, DatasetDict
@@ -31,3 +32,57 @@ class MMLUDataset(BaseDataset):
})
dataset[split] = Dataset.from_list(raw_data)
return dataset
class MMLUDatasetClean(BaseDataset):
    # load the contamination annotations of MMLU from
    # https://github.com/liyucheng09/Contamination_Detector
    @staticmethod
    def load_contamination_annotations(path, split='test'):
        import requests
        assert split == 'test', 'We only use the test set for MMLU'
annotation_cache_path = osp.join(
path, split, f'MMLU_{split}_contamination_annotations.json')
if osp.exists(annotation_cache_path):
with open(annotation_cache_path, 'r') as f:
annotations = json.load(f)
return annotations
link_of_annotations = 'https://github.com/liyucheng09/Contamination_Detector/releases/download/v0.1.1rc2/mmlu_annotations.json' # noqa
annotations = json.loads(requests.get(link_of_annotations).text)
with open(annotation_cache_path, 'w') as f:
json.dump(annotations, f)
return annotations
@staticmethod
def load(path: str, name: str):
dataset = DatasetDict()
for split in ['dev', 'test']:
raw_data = []
filename = osp.join(path, split, f'{name}_{split}.csv')
if split == 'test':
annotations = MMLUDatasetClean.load_contamination_annotations(
path, split)
with open(filename, encoding='utf-8') as f:
reader = csv.reader(f)
for row_index, row in enumerate(reader):
assert len(row) == 6
item = {
'input': row[0],
'A': row[1],
'B': row[2],
'C': row[3],
'D': row[4],
'target': row[5],
}
if split == 'test':
row_id = f'{name} {row_index}'
if row_id in annotations:
is_clean = annotations[row_id][0]
else:
is_clean = 'not labeled'
item['is_clean'] = is_clean
raw_data.append(item)
dataset[split] = Dataset.from_list(raw_data)
return dataset
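As with ARC, the per-subject buckets can be checked against `mmlu_category_weights`; a short sketch assuming the `./data/mmlu/` layout from the config:

```python
from collections import Counter

from opencompass.datasets import MMLUDatasetClean

ds = MMLUDatasetClean.load('./data/mmlu/', 'anatomy')
print(Counter(ds['test']['is_clean']))  # dev rows carry no 'is_clean' field
```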