import importlib
import json
import os
from pathlib import Path

from ..utils.load_items import load_query_instances

# Package that holds one prompt module per task family; each module is
# expected to expose a ``get_prompt`` callable.
_PROMPT_PACKAGE = "opencompass.datasets.calm.data_processing.prompt"

# Maps a task name (dataset file name without extension) to the prompt module
# that builds its questions. CN and EN variants of a task share one module.
# Built once at import time rather than on every lookup.
_TASK_TO_MODULE_MAP = {
    # association/
    # correlation/
    "CORR-B_correlation_CN": "CORR-B_correlation",
    "CORR-B_correlation_EN": "CORR-B_correlation",
    # explaining_away_effect/
    "EAE-B_exp-away_CN": "EAE-B_exp-away",
    "EAE-B_exp-away_EN": "EAE-B_exp-away",
    # causal_discovery/
    # abstract_reasoning/
    "AR-B_CaLM-AR_CN": "AR-B_CaLM-AR",
    "AR-B_CaLM-AR_EN": "AR-B_CaLM-AR",
    # causal_attribution/
    "CA-B_FA_CN": "CA-B_FA",
    "CA-B_FA_EN": "CA-B_FA",
    "CA-B_FP_CN": "CA-B_FP",
    "CA-B_FP_EN": "CA-B_FP",
    # event_causality_identification/
    "ECI-B_CTB_CN": "ECI-B_CTB",
    "ECI-B_CTB_EN": "ECI-B_CTB",
    "ECI-B_ESC_CN": "ECI-B_ESC",
    "ECI-B_ESC_EN": "ECI-B_ESC",
    "ECI-B_MAVEN-ERE_CN": "ECI-B_MAVEN-ERE",
    "ECI-B_MAVEN-ERE_EN": "ECI-B_MAVEN-ERE",
    # pairwise_causal_discovery/
    "PCD-B_COPA_CN": "PCD-B_COPA",
    "PCD-B_COPA_EN": "PCD-B_COPA",
    "PCD-B_E-CARE_CN": "PCD-B_E-CARE",
    "PCD-B_E-CARE_EN": "PCD-B_E-CARE",
    "PCD-C_COPA_CN": "PCD-C_COPA",
    "PCD-C_COPA_EN": "PCD-C_COPA",
    "PCD-C_E-CARE_CN": "PCD-C_E-CARE",
    "PCD-C_E-CARE_EN": "PCD-C_E-CARE",
    # counterfactual/
    # actual_causality/
    "AC-B_causal_judgement_CN": "AC-B_causal_judgement",
    "AC-B_causal_judgement_EN": "AC-B_causal_judgement",
    # causal_explanation_generation/
    "CEG-O_E-CARE_CN": "CEG-O_E-CARE",
    "CEG-O_E-CARE_EN": "CEG-O_E-CARE",
    # counterfactual_reasoning/
    "CR-B_det-counterfactual_CN": "CR-B_det-counterfactual",
    "CR-B_det-counterfactual_EN": "CR-B_det-counterfactual",
    "CR-C_CRASS_CN": "CR-C_CRASS",
    "CR-C_CRASS_EN": "CR-C_CRASS",
    # effect_of_the_treatment_on_the_treated/
    "ETT-B_ETT-natural_CN": "ETT",
    "ETT-B_ETT-natural_EN": "ETT",
    "ETT-P_ETT-basic_CN": "ETT",
    "ETT-P_ETT-basic_EN": "ETT",
    "ETT-P_ETT-hard_CN": "ETT",
    "ETT-P_ETT-hard_EN": "ETT",
    # natural_direct_effect/
    "NDE-B_NDE-natural_CN": "NDE",
    "NDE-B_NDE-natural_EN": "NDE",
    "NDE-P_NDE-basic_CN": "NDE",
    "NDE-P_NDE-basic_EN": "NDE",
    "NDE-P_NDE-hard_CN": "NDE",
    "NDE-P_NDE-hard_EN": "NDE",
    # natural_indirect_effect/
    "NIE-B_NIE-natural_CN": "NIE",
    "NIE-B_NIE-natural_EN": "NIE",
    "NIE-P_NIE-basic_CN": "NIE",
    "NIE-P_NIE-basic_EN": "NIE",
    "NIE-P_NIE-hard_CN": "NIE",
    "NIE-P_NIE-hard_EN": "NIE",
    # probability_of_necessity/
    "PN-P_PN-basic_CN": "PN",
    "PN-P_PN-basic_EN": "PN",
    "PN-P_PN-hard_CN": "PN",
    "PN-P_PN-hard_EN": "PN",
    # probability_of_sufficiency/
    "PS-P_PS-basic_CN": "PS",
    "PS-P_PS-basic_EN": "PS",
    "PS-P_PS-hard_CN": "PS",
    "PS-P_PS-hard_EN": "PS",
    # intervention/
    # average_treatment_effect/
    "ATE-B_ATE-natural_CN": "ATE",
    "ATE-B_ATE-natural_EN": "ATE",
    "ATE-P_ATE-basic_CN": "ATE",
    "ATE-P_ATE-basic_EN": "ATE",
    "ATE-P_ATE-hard_CN": "ATE",
    "ATE-P_ATE-hard_EN": "ATE",
    # backdoor_adjustment_set/
    "BAS-B_backadj_CN": "BAS-B_backadj",
    "BAS-B_backadj_EN": "BAS-B_backadj",
    "BAS-C_max-BAS_CN": "BAS-C_max-BAS",
    "BAS-C_max-BAS_EN": "BAS-C_max-BAS",
    "BAS-C_min-BAS_CN": "BAS-C_min-BAS",
    "BAS-C_min-BAS_EN": "BAS-C_min-BAS",
    "BAS-C_mix-BAS_CN": "BAS-C_mix-BAS",
    "BAS-C_mix-BAS_EN": "BAS-C_mix-BAS",
    # causal_effect_identification/
    "CEI-B_0.2-UC_CN": "CEI-B",
    "CEI-B_0.2-UC_EN": "CEI-B",
    "CEI-B_0.4-UC_CN": "CEI-B",
    "CEI-B_0.4-UC_EN": "CEI-B",
    "CEI-B_0.6-UC_CN": "CEI-B",
    "CEI-B_0.6-UC_EN": "CEI-B",
    "CEI-B_0.8-UC_CN": "CEI-B",
    "CEI-B_0.8-UC_EN": "CEI-B",
    # collider_bias/
    "CB-B_collider-bias_CN": "CB-B_collider-bias",
    "CB-B_collider-bias_EN": "CB-B_collider-bias",
    # controlled_direct_effect/
    "CDE-B_CDE-natural_CN": "CDE",
    "CDE-B_CDE-natural_EN": "CDE",
    "CDE-P_CDE-basic_CN": "CDE",
    "CDE-P_CDE-basic_EN": "CDE",
    "CDE-P_CDE-hard_CN": "CDE",
    "CDE-P_CDE-hard_EN": "CDE",
    # frontdoor_adjustment_set/
    "FAS-C_FAS_CN": "FAS-C_FAS",
    "FAS-C_FAS_EN": "FAS-C_FAS",
    # instrumental_variable/
    "IV-C_CaLM-IV_CN": "IV-C_CaLM-IV",
    "IV-C_CaLM-IV_EN": "IV-C_CaLM-IV",
}


def get_get_prompt_func(task):
    """
    Returns the appropriate prompt generation function based on the given task.

    The task name is looked up in ``_TASK_TO_MODULE_MAP``. The matching module
    under ``_PROMPT_PACKAGE`` is imported, and its ``get_prompt`` attribute is
    returned.

    Args:
        task (str): The name of the task for which the prompt function is required.

    Returns:
        function: The prompt generation function for the specified task.

    Raises:
        NotImplementedError: If no prompt function is found for the given task.
    """
    try:
        module_name = _TASK_TO_MODULE_MAP[task]
    except KeyError:
        # Hide the internal KeyError; callers only see the documented error.
        raise NotImplementedError(
            f"No get_prompt function found for task {task}.") from None
    module = importlib.import_module(f"{_PROMPT_PACKAGE}.{module_name}")
    return module.get_prompt


def generate_question_list(dataset_path, prompt_style):
    """
    Generates a list of questions from the dataset based on the specified prompt style.

    The task name is the dataset file name without its extension (e.g.
    ``CORR-B_correlation_CN.json`` -> ``CORR-B_correlation_CN``).

    Args:
        dataset_path (str): The path to the dataset JSON file.
        prompt_style (str): The style of prompt to be used for generating questions.

    Returns:
        list: A list of dicts, one per dataset item, each with keys
        ``"question"`` (the generated prompt) and ``"gt_item"`` (the original
        item), in dataset order.

    Raises:
        AssertionError: If the task name and prompt style do not match the expected language suffix.
        NotImplementedError: If no prompt function exists for the task
            (propagated from ``get_get_prompt_func``).
    """
    dataset_path = Path(dataset_path)
    # ``stem`` strips only the final suffix, so dotted task names such as
    # "CEI-B_0.2-UC_CN.json" keep their inner dots.
    task_name = dataset_path.stem

    # Validate the prompt style against the task language. An explicit raise is
    # used rather than ``assert`` so the check still runs under ``python -O``.
    is_cn_task = task_name.endswith("CN")
    is_cn_style = prompt_style.endswith("-CN")
    if is_cn_task != is_cn_style:
        raise AssertionError(
            f"Prompt style {prompt_style!r} does not match the language of "
            f"task {task_name!r}: CN tasks require a '-CN' prompt style and "
            f"other tasks must not use one.")

    get_prompt_func = get_get_prompt_func(task=task_name)
    item_list = load_query_instances(dataset_path)

    return [
        {
            "question": get_prompt_func(
                task_name=task_name, prompt_style=prompt_style, item=item),
            "gt_item": item,
        }
        for item in item_list
    ]