import importlib

from ..evaluation.core_metrics import initialize_core_metric_evaluation_components


def initialize_error_identification_components(task, prompt_style):
    """
    Initialize error identification components.

    Args:
        task (str): The task for which error identification components are being initialized.
        prompt_style (str): The style of prompt for error identification.

    Returns:
        Module: The error identification module corresponding to the provided task and prompt style.
    """
    prompt_style_to_error_module_map = {
        "basic": "basic_adversarial",
        "basic-CN": "basic_adversarial",
        "adversarial-ignore": "basic_adversarial",
        "adversarial-ignore-CN": "basic_adversarial",
        "adversarial-doubt": "basic_adversarial",
        "adversarial-doubt-CN": "basic_adversarial",
        "zero-shot-IcL": "icl",
        "zero-shot-IcL-CN": "icl",
        "one-shot-IcL": "icl",
        "one-shot-IcL-CN": "icl",
        "three-shot-IcL": "icl",
        "three-shot-IcL-CN": "icl",
        "zero-shot-CoT": "cot",
        "zero-shot-CoT-CN": "cot",
        "manual-CoT": "cot",
        "manual-CoT-CN": "cot",
    }
    task_to_error_module_map = {
        # association/
        # correlation/
        "CORR-B_correlation_CN": "CLADDER",
        "CORR-B_correlation_EN": "CLADDER",
        # explaining_away_effect/
        "EAE-B_exp-away_CN": "CLADDER",
        "EAE-B_exp-away_EN": "CLADDER",
        # causal_discovery/
        # abstract_reasoning/
        "AR-B_CaLM-AR_CN": "AR-B_CaLM-AR",
        "AR-B_CaLM-AR_EN": "AR-B_CaLM-AR",
        # causal_attribution/
        "CA-B_FA_CN": "CA-B",
        "CA-B_FA_EN": "CA-B",
        "CA-B_FP_CN": "CA-B",
        "CA-B_FP_EN": "CA-B",
        # event_causality_identification/
        "ECI-B_CTB_CN": "ECI",
        "ECI-B_CTB_EN": "ECI",
        "ECI-B_ESC_CN": "ECI",
        "ECI-B_ESC_EN": "ECI",
        "ECI-B_MAVEN-ERE_CN": "ECI",
        "ECI-B_MAVEN-ERE_EN": "ECI",
        # pairwise_causal_discovery/
        "PCD-B_COPA_CN": "PCD-B",
        "PCD-B_COPA_EN": "PCD-B",
        "PCD-B_E-CARE_CN": "PCD-B",
        "PCD-B_E-CARE_EN": "PCD-B",
        "PCD-C_COPA_CN": "PCD-C",
        "PCD-C_COPA_EN": "PCD-C",
        "PCD-C_E-CARE_CN": "PCD-C",
        "PCD-C_E-CARE_EN": "PCD-C",
        # counterfactual/
        # actual_causality/
        "AC-B_causal_judgement_CN": "AC-B_causal_judgement",
        "AC-B_causal_judgement_EN": "AC-B_causal_judgement",
        # counterfactual_reasoning/
        "CR-B_det-counterfactual_CN": "CLADDER",
        "CR-B_det-counterfactual_EN": "CLADDER",
        "CR-C_CRASS_CN": "CR-C_CRASS",
        "CR-C_CRASS_EN": "CR-C_CRASS",
        # effect_of_the_treatment_on_the_treated/
        "ETT-B_ETT-natural_CN": "Natural",
        "ETT-B_ETT-natural_EN": "Natural",
        "ETT-P_ETT-basic_CN": "Probability",
        "ETT-P_ETT-basic_EN": "Probability",
        "ETT-P_ETT-hard_CN": "Probability",
        "ETT-P_ETT-hard_EN": "Probability",
        # natural_direct_effect/
        "NDE-B_NDE-natural_CN": "Natural",
        "NDE-B_NDE-natural_EN": "Natural",
        "NDE-P_NDE-basic_CN": "Probability",
        "NDE-P_NDE-basic_EN": "Probability",
        "NDE-P_NDE-hard_CN": "Probability",
        "NDE-P_NDE-hard_EN": "Probability",
        # natural_indirect_effect/
        "NIE-B_NIE-natural_CN": "Natural",
        "NIE-B_NIE-natural_EN": "Natural",
        "NIE-P_NIE-basic_CN": "Probability",
        "NIE-P_NIE-basic_EN": "Probability",
        "NIE-P_NIE-hard_CN": "Probability",
        "NIE-P_NIE-hard_EN": "Probability",
        # probability_of_necessity/
        "PN-P_PN-basic_CN": "Probability",
        "PN-P_PN-basic_EN": "Probability",
        "PN-P_PN-hard_CN": "Probability",
        "PN-P_PN-hard_EN": "Probability",
        # probability_of_sufficiency/
        "PS-P_PS-basic_CN": "Probability",
        "PS-P_PS-basic_EN": "Probability",
        "PS-P_PS-hard_CN": "Probability",
        "PS-P_PS-hard_EN": "Probability",
        # intervention/
        # average_treatment_effect/
        "ATE-B_ATE-natural_CN": "Natural",
        "ATE-B_ATE-natural_EN": "Natural",
        "ATE-P_ATE-basic_CN": "Probability",
        "ATE-P_ATE-basic_EN": "Probability",
        "ATE-P_ATE-hard_CN": "Probability",
        "ATE-P_ATE-hard_EN": "Probability",
        # backdoor_adjustment_set/
"BAS-B_backadj_CN":"CLADDER", "BAS-B_backadj_EN":"CLADDER", "BAS-C_max-BAS_CN":"AS", "BAS-C_max-BAS_EN":"AS", "BAS-C_min-BAS_CN":"AS", "BAS-C_min-BAS_EN":"AS", "BAS-C_mix-BAS_CN":"AS", "BAS-C_mix-BAS_EN":"AS", # causal_effect_identification/ "CEI-B_0.2-UC_CN":"CEI-B", "CEI-B_0.2-UC_EN":"CEI-B", "CEI-B_0.4-UC_CN":"CEI-B", "CEI-B_0.4-UC_EN":"CEI-B", "CEI-B_0.6-UC_CN":"CEI-B", "CEI-B_0.6-UC_EN":"CEI-B", "CEI-B_0.8-UC_CN":"CEI-B", "CEI-B_0.8-UC_EN":"CEI-B", # collider_bias/ "CB-B_collider-bias_CN":"CLADDER", "CB-B_collider-bias_EN":"CLADDER", # controlled_direct_effect/ "CDE-B_CDE-natural_CN":"Natural", "CDE-B_CDE-natural_EN":"Natural", "CDE-P_CDE-basic_CN":"Probability", "CDE-P_CDE-basic_EN":"Probability", "CDE-P_CDE-hard_CN":"Probability", "CDE-P_CDE-hard_EN":"Probability", # frontdoor_adjustment_set/ "FAS-C_FAS_CN":"AS", "FAS-C_FAS_EN":"AS", # instrumental_variable/ "IV-C_CaLM-IV_CN":"AS", "IV-C_CaLM-IV_EN":"AS", } error_task_module_name = task_to_error_module_map.get(task) error_prompt_module_name = prompt_style_to_error_module_map.get(prompt_style) if error_task_module_name and error_prompt_module_name: error_module = importlib.import_module(f"opencompass.datasets.calm.evaluation.error.{error_prompt_module_name}.{error_task_module_name}") return error_module else: raise NotImplementedError(f"No get_score function found for task {task} and prompt {prompt_style}.") def identify_model_errors(items, task, prompt_style, gt_items): """ Identify errors in model responses based on provided items, task, and prompt style. Args: items (list): A list of items containing model responses. task (str): The task type, note that CEG-O_E-CARE is not supported for error analysis. prompt_style (str): The style of prompt used, note that explicit-function is not supported for error analysis. gt_items (list): A list of ground truth items. Returns: dict: A dictionary containing error metrics for the model responses. (Same response to all questions, language inconsistency, limitation of instruction-following, repetition, empty response.) """ if task == "CEG-O_E-CARE" or prompt_style in ["explicit-function", "explicit-function-CN"]: print("CEG-O_E-CARE and explicit-function prompts are not supported for error identification.") return language_error, nonstandrad, repetition, empty = 0., 0., 0., 0. error_module = initialize_error_identification_components(task, prompt_style) get_gt_label, get_pred_label, compute_acc = initialize_core_metric_evaluation_components(task) pred_list = [] for item, gt_item in zip(items, gt_items): pred_label = get_pred_label(item, gt_item, prompt_style, task.split("-")[0]) pred_error = get_item_error(item, task, error_module, prompt_style) pred_list.append(pred_label) language_error += pred_error["language_error"] nonstandrad += pred_error["nonstandrad"] repetition += pred_error["repetition"] empty += pred_error["empty"] abnormalities = error_module.check_abnormality(pred_list) return { "Same response to all questions": 1 if abnormalities!=0 else 0, "Language inconsistency": language_error / len(pred_list), "Limitation of instruction-following": nonstandrad / len(pred_list), "Repetition": repetition / len(pred_list), "Empty response": empty / len(pred_list), } def get_item_error(model_response, task, error_module, prompt_style): """ Analyze errors in a single model response for a given task and prompt style. Args: model_response (str): The model's response to analyze. task (str): The task type. error_module: The error module containing error identification methods. 
        prompt_style (str): The style of prompt used.

    Returns:
        dict: A dictionary containing error metrics for the model response
            (language inconsistency, nonstandardization, repetition, empty response).
    """
    model_response = model_response.strip().lower()

    if 'CN' in task:
        language_error = error_module.contains_english(model_response)
    else:
        language_error = error_module.contains_chinese(model_response)
    nonstandrad = error_module.check_standalization(model_response, prompt_style, type=task.split("-")[0])
    repetition = error_module.check_repetition(model_response)
    empty = error_module.check_empty(model_response)

    return {
        "language_error": language_error,
        "nonstandrad": nonstandrad,
        "repetition": repetition,
        "empty": empty,
    }
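

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The JSON file names, the
# assumption that `items` is a list of raw model-response strings, and the
# chosen task / prompt_style values are hypothetical; in practice this module
# is driven by the CALM evaluation pipeline rather than run directly. If run
# standalone, invoke it with `python -m <package path>` so the relative
# imports above resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import json

    # Hypothetical outputs and gold labels produced by an earlier inference step.
    with open("model_responses.json", encoding="utf-8") as f:
        items = json.load(f)
    with open("ground_truth.json", encoding="utf-8") as f:
        gt_items = json.load(f)

    error_metrics = identify_model_errors(
        items,
        task="CORR-B_correlation_EN",
        prompt_style="basic",
        gt_items=gt_items,
    )
    # Expected shape of the result, e.g.:
    # {"Same response to all questions": 0, "Language inconsistency": 0.02, ...}
    print(error_metrics)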