import importlib

from ..evaluation.core_metrics import initialize_core_metric_evaluation_components


def initialize_error_identification_components(task, prompt_style):
    """
    Initialize error identification components.

    Args:
        task (str): The task for which error identification components are being initialized.
        prompt_style (str): The style of prompt for error identification.

    Returns:
        Module: The error identification module corresponding to the provided task and prompt style.
    """
    prompt_style_to_error_module_map = {
        "basic": "basic_adversarial",
        "basic-CN": "basic_adversarial",
        "adversarial-ignore": "basic_adversarial",
        "adversarial-ignore-CN": "basic_adversarial",
        "adversarial-doubt": "basic_adversarial",
        "adversarial-doubt-CN": "basic_adversarial",
        "zero-shot-IcL": "icl",
        "zero-shot-IcL-CN": "icl",
        "one-shot-IcL": "icl",
        "one-shot-IcL-CN": "icl",
        "three-shot-IcL": "icl",
        "three-shot-IcL-CN": "icl",
        "zero-shot-CoT": "cot",
        "zero-shot-CoT-CN": "cot",
        "manual-CoT": "cot",
        "manual-CoT-CN": "cot",
    }

    task_to_error_module_map = {
        # association/
        # correlation/
        "CORR-B_correlation_CN": "CLADDER",
        "CORR-B_correlation_EN": "CLADDER",
        # explaining_away_effect/
        "EAE-B_exp-away_CN": "CLADDER",
        "EAE-B_exp-away_EN": "CLADDER",
        # causal_discovery/
        # abstract_reasoning/
        "AR-B_CaLM-AR_CN": "AR-B_CaLM-AR",
        "AR-B_CaLM-AR_EN": "AR-B_CaLM-AR",
        # causal_attribution/
        "CA-B_FA_CN": "CA-B",
        "CA-B_FA_EN": "CA-B",
        "CA-B_FP_CN": "CA-B",
        "CA-B_FP_EN": "CA-B",
        # event_causality_identification/
        "ECI-B_CTB_CN": "ECI",
        "ECI-B_CTB_EN": "ECI",
        "ECI-B_ESC_CN": "ECI",
        "ECI-B_ESC_EN": "ECI",
        "ECI-B_MAVEN-ERE_CN": "ECI",
        "ECI-B_MAVEN-ERE_EN": "ECI",
        # pairwise_causal_discovery/
        "PCD-B_COPA_CN": "PCD-B",
        "PCD-B_COPA_EN": "PCD-B",
        "PCD-B_E-CARE_CN": "PCD-B",
        "PCD-B_E-CARE_EN": "PCD-B",
        "PCD-C_COPA_CN": "PCD-C",
        "PCD-C_COPA_EN": "PCD-C",
        "PCD-C_E-CARE_CN": "PCD-C",
        "PCD-C_E-CARE_EN": "PCD-C",
        # counterfactual/
        # actual_causality/
        "AC-B_causal_judgement_CN": "AC-B_causal_judgement",
        "AC-B_causal_judgement_EN": "AC-B_causal_judgement",
        # counterfactual_reasoning/
        "CR-B_det-counterfactual_CN": "CLADDER",
        "CR-B_det-counterfactual_EN": "CLADDER",
        "CR-C_CRASS_CN": "CR-C_CRASS",
        "CR-C_CRASS_EN": "CR-C_CRASS",
        # effect_of_the_treatment_on_the_treated/
        "ETT-B_ETT-natural_CN": "Natural",
        "ETT-B_ETT-natural_EN": "Natural",
        "ETT-P_ETT-basic_CN": "Probability",
        "ETT-P_ETT-basic_EN": "Probability",
        "ETT-P_ETT-hard_CN": "Probability",
        "ETT-P_ETT-hard_EN": "Probability",
        # natural_direct_effect/
        "NDE-B_NDE-natural_CN": "Natural",
        "NDE-B_NDE-natural_EN": "Natural",
        "NDE-P_NDE-basic_CN": "Probability",
        "NDE-P_NDE-basic_EN": "Probability",
        "NDE-P_NDE-hard_CN": "Probability",
        "NDE-P_NDE-hard_EN": "Probability",
        # natural_indirect_effect/
        "NIE-B_NIE-natural_CN": "Natural",
        "NIE-B_NIE-natural_EN": "Natural",
        "NIE-P_NIE-basic_CN": "Probability",
        "NIE-P_NIE-basic_EN": "Probability",
        "NIE-P_NIE-hard_CN": "Probability",
        "NIE-P_NIE-hard_EN": "Probability",
        # probability_of_necessity/
        "PN-P_PN-basic_CN": "Probability",
        "PN-P_PN-basic_EN": "Probability",
        "PN-P_PN-hard_CN": "Probability",
        "PN-P_PN-hard_EN": "Probability",
        # probability_of_sufficiency/
        "PS-P_PS-basic_CN": "Probability",
        "PS-P_PS-basic_EN": "Probability",
        "PS-P_PS-hard_CN": "Probability",
        "PS-P_PS-hard_EN": "Probability",
        # intervention/
        # average_treatment_effect/
        "ATE-B_ATE-natural_CN": "Natural",
        "ATE-B_ATE-natural_EN": "Natural",
        "ATE-P_ATE-basic_CN": "Probability",
        "ATE-P_ATE-basic_EN": "Probability",
        "ATE-P_ATE-hard_CN": "Probability",
        "ATE-P_ATE-hard_EN": "Probability",
        # backdoor_adjustment_set/
        "BAS-B_backadj_CN": "CLADDER",
        "BAS-B_backadj_EN": "CLADDER",
        "BAS-C_max-BAS_CN": "AS",
        "BAS-C_max-BAS_EN": "AS",
        "BAS-C_min-BAS_CN": "AS",
        "BAS-C_min-BAS_EN": "AS",
        "BAS-C_mix-BAS_CN": "AS",
        "BAS-C_mix-BAS_EN": "AS",
        # causal_effect_identification/
        "CEI-B_0.2-UC_CN": "CEI-B",
        "CEI-B_0.2-UC_EN": "CEI-B",
        "CEI-B_0.4-UC_CN": "CEI-B",
        "CEI-B_0.4-UC_EN": "CEI-B",
        "CEI-B_0.6-UC_CN": "CEI-B",
        "CEI-B_0.6-UC_EN": "CEI-B",
        "CEI-B_0.8-UC_CN": "CEI-B",
        "CEI-B_0.8-UC_EN": "CEI-B",
        # collider_bias/
        "CB-B_collider-bias_CN": "CLADDER",
        "CB-B_collider-bias_EN": "CLADDER",
        # controlled_direct_effect/
        "CDE-B_CDE-natural_CN": "Natural",
        "CDE-B_CDE-natural_EN": "Natural",
        "CDE-P_CDE-basic_CN": "Probability",
        "CDE-P_CDE-basic_EN": "Probability",
        "CDE-P_CDE-hard_CN": "Probability",
        "CDE-P_CDE-hard_EN": "Probability",
        # frontdoor_adjustment_set/
        "FAS-C_FAS_CN": "AS",
        "FAS-C_FAS_EN": "AS",
        # instrumental_variable/
        "IV-C_CaLM-IV_CN": "AS",
        "IV-C_CaLM-IV_EN": "AS",
    }

    error_task_module_name = task_to_error_module_map.get(task)
    error_prompt_module_name = prompt_style_to_error_module_map.get(prompt_style)

    if error_task_module_name and error_prompt_module_name:
        error_module = importlib.import_module(
            f"opencompass.datasets.calm.evaluation.error.{error_prompt_module_name}.{error_task_module_name}")
        return error_module
    else:
        raise NotImplementedError(
            f"No error identification module found for task {task} and prompt {prompt_style}.")

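# Usage sketch (illustrative values taken from the maps above):
#
#     module = initialize_error_identification_components(
#         "CORR-B_correlation_EN", "zero-shot-CoT")
#     # imports opencompass.datasets.calm.evaluation.error.cot.CLADDER
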
def identify_model_errors(items, task, prompt_style, gt_items):
    """
    Identify errors in model responses based on the provided items, task, and prompt style.

    Args:
        items (list): A list of items containing model responses.
        task (str): The task type; note that CEG-O_E-CARE is not supported for error analysis.
        prompt_style (str): The style of prompt used; note that explicit-function is not supported for error analysis.
        gt_items (list): A list of ground truth items.

    Returns:
        dict: A dictionary containing error metrics for the model responses
        (same response to all questions, language inconsistency, limitation of
        instruction-following, repetition, empty response).
    """
    if task == "CEG-O_E-CARE" or prompt_style in ["explicit-function", "explicit-function-CN"]:
        print("CEG-O_E-CARE and explicit-function prompts are not supported for error identification.")
        return

    language_error, nonstandard, repetition, empty = 0., 0., 0., 0.
    error_module = initialize_error_identification_components(task, prompt_style)
    get_gt_label, get_pred_label, compute_acc = initialize_core_metric_evaluation_components(task)
    pred_list = []

    # Accumulate per-item error flags over the whole prediction set.
    for item, gt_item in zip(items, gt_items):
        pred_label = get_pred_label(item, gt_item, prompt_style, task.split("-")[0])
        pred_error = get_item_error(item, task, error_module, prompt_style)

        pred_list.append(pred_label)
        language_error += pred_error["language_error"]
        nonstandard += pred_error["nonstandard"]
        repetition += pred_error["repetition"]
        empty += pred_error["empty"]

    abnormalities = error_module.check_abnormality(pred_list)

    return {
        "Same response to all questions": 1 if abnormalities != 0 else 0,
        "Language inconsistency": language_error / len(pred_list),
        "Limitation of instruction-following": nonstandard / len(pred_list),
        "Repetition": repetition / len(pred_list),
        "Empty response": empty / len(pred_list),
    }

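# Output sketch for a supported task (hypothetical numbers, shown only to
# illustrate the shape of the returned dictionary):
#
#     identify_model_errors(items, "CORR-B_correlation_EN", "basic", gt_items)
#     # {'Same response to all questions': 0,
#     #  'Language inconsistency': 0.02,
#     #  'Limitation of instruction-following': 0.10,
#     #  'Repetition': 0.0,
#     #  'Empty response': 0.0}
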
def get_item_error(model_response, task, error_module, prompt_style):
    """
    Analyze errors in a single model response for a given task and prompt style.

    Args:
        model_response (str): The model's response to analyze.
        task (str): The task type.
        error_module: The error module containing error identification methods.
        prompt_style (str): The style of prompt used.

    Returns:
        dict: A dictionary containing error metrics for the model response
        (language inconsistency, nonstandardization, repetition, empty response).
    """
    model_response = model_response.strip().lower()
    # A response to a Chinese (_CN) task should not contain English, and a
    # response to an English (_EN) task should not contain Chinese.
    if 'CN' in task:
        language_error = error_module.contains_english(model_response)
    else:
        language_error = error_module.contains_chinese(model_response)

    nonstandard = error_module.check_standalization(model_response, prompt_style, type=task.split("-")[0])

    repetition = error_module.check_repetition(model_response)

    empty = error_module.check_empty(model_response)

    return {
        "language_error": language_error,
        "nonstandard": nonstandard,
        "repetition": repetition,
        "empty": empty,
    }

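# Per-item sketch (hypothetical flag values; the real checks come from the
# dynamically imported error module):
#
#     err = get_item_error(response, "ECI-B_CTB_EN", error_module, "basic")
#     # {'language_error': 0, 'nonstandard': 0, 'repetition': 0, 'empty': 0}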