OpenCompass/opencompass/datasets/calm/data_processing/generate_questions.py
Peng Bo edd0ffdf70
Calm dataset (#1287)
* add calm dataset

* modify config max_out_len

* update README

* Modify README

* update README

* update README

* update README

* update README

* update README

* add summarizer and modify readme

* delete summarizer config comment

* update summarizer

* modify same response to all questions

* update README
2024-07-26 11:48:16 +08:00

193 lines
7.6 KiB
Python

import importlib
from pathlib import Path
import json
import os
from ..utils.load_items import load_query_instances
def get_get_prompt_func(task):
    """
    Look up the ``get_prompt`` callable that builds prompts for a CaLM task.

    Every known task comes in a Chinese (``_CN``) and an English (``_EN``)
    variant, and both variants are served by the same prompt module under
    ``opencompass.datasets.calm.data_processing.prompt``.

    Args:
        task (str): Full task name, e.g. ``"CORR-B_correlation_CN"``.

    Returns:
        function: The ``get_prompt`` attribute of the task's prompt module.

    Raises:
        NotImplementedError: If ``task`` is not one of the known task names.
    """
    # Prompt modules whose name equals the task name minus its language suffix.
    dedicated_modules = (
        # association/ : correlation/, explaining_away_effect/
        "CORR-B_correlation",
        "EAE-B_exp-away",
        # causal_discovery/ : abstract_reasoning/, causal_attribution/,
        # event_causality_identification/, pairwise_causal_discovery/
        "AR-B_CaLM-AR",
        "CA-B_FA",
        "CA-B_FP",
        "ECI-B_CTB",
        "ECI-B_ESC",
        "ECI-B_MAVEN-ERE",
        "PCD-B_COPA",
        "PCD-B_E-CARE",
        "PCD-C_COPA",
        "PCD-C_E-CARE",
        # counterfactual/ : actual_causality/, causal_explanation_generation/,
        # counterfactual_reasoning/
        "AC-B_causal_judgement",
        "CEG-O_E-CARE",
        "CR-B_det-counterfactual",
        "CR-C_CRASS",
        # intervention/ : backdoor_adjustment_set/, collider_bias/,
        # frontdoor_adjustment_set/, instrumental_variable/
        "BAS-B_backadj",
        "BAS-C_max-BAS",
        "BAS-C_min-BAS",
        "BAS-C_mix-BAS",
        "CB-B_collider-bias",
        "FAS-C_FAS",
        "IV-C_CaLM-IV",
    )
    # Prompt modules shared by several task bases (module name -> task bases).
    shared_modules = {
        # counterfactual/ : ETT, NDE, NIE, probability of necessity/sufficiency
        "ETT": ("ETT-B_ETT-natural", "ETT-P_ETT-basic", "ETT-P_ETT-hard"),
        "NDE": ("NDE-B_NDE-natural", "NDE-P_NDE-basic", "NDE-P_NDE-hard"),
        "NIE": ("NIE-B_NIE-natural", "NIE-P_NIE-basic", "NIE-P_NIE-hard"),
        "PN": ("PN-P_PN-basic", "PN-P_PN-hard"),
        "PS": ("PS-P_PS-basic", "PS-P_PS-hard"),
        # intervention/ : ATE, causal effect identification, CDE
        "ATE": ("ATE-B_ATE-natural", "ATE-P_ATE-basic", "ATE-P_ATE-hard"),
        "CEI-B": ("CEI-B_0.2-UC", "CEI-B_0.4-UC", "CEI-B_0.6-UC", "CEI-B_0.8-UC"),
        "CDE": ("CDE-B_CDE-natural", "CDE-P_CDE-basic", "CDE-P_CDE-hard"),
    }

    base_to_module = {base: base for base in dedicated_modules}
    for shared_name, bases in shared_modules.items():
        base_to_module.update(dict.fromkeys(bases, shared_name))

    # Expand every task base into its Chinese and English task names.
    task_to_module = {
        f"{base}_{lang}": target
        for base, target in base_to_module.items()
        for lang in ("CN", "EN")
    }

    module_name = task_to_module.get(task)
    if not module_name:
        raise NotImplementedError(f"No get_prompt function found for task {task}.")
    prompt_module = importlib.import_module(
        "opencompass.datasets.calm.data_processing.prompt." + module_name)
    return prompt_module.get_prompt
def generate_question_list(dataset_path, prompt_style):
    """
    Generates a list of questions from the dataset based on the specified prompt style.

    Args:
        dataset_path (str | Path): The path to the dataset JSON file. The task
            name is the file name with its ``.json`` suffix removed.
        prompt_style (str): The style of prompt to be used for generating
            questions. It must end with ``-CN`` exactly when the task name
            ends with ``CN``.

    Returns:
        list: One dict per dataset item, in dataset order, with keys
            ``"question"`` (the prompt built by the task's ``get_prompt``)
            and ``"gt_item"`` (the item as loaded from the dataset).

    Raises:
        AssertionError: If the task name and prompt style do not match the
            expected language suffix.
        NotImplementedError: Propagated from ``get_get_prompt_func`` when no
            prompt function is registered for the task.
    """
    dataset_path = Path(dataset_path)

    # Task name is the file name without ".json". Strip the suffix only when
    # it is actually present, so other names are not silently truncated.
    file_name = dataset_path.name
    json_suffix = ".json"
    if file_name.endswith(json_suffix):
        task_name = file_name[:-len(json_suffix)]
    else:
        task_name = file_name

    # A Chinese task must use a "-CN" prompt style, and any other task must not.
    # Raise explicitly rather than via `assert`, which `python -O` strips; the
    # exception type stays AssertionError for existing callers.
    if task_name.endswith("CN") != prompt_style.endswith("-CN"):
        raise AssertionError(
            f"Prompt style {prompt_style!r} does not match the language of "
            f"task {task_name!r}: tasks ending in 'CN' require a '-CN' prompt "
            f"style and other tasks must not use one.")

    # Resolve the prompt builder before loading items, as in the original flow.
    get_prompt_func = get_get_prompt_func(task=task_name)
    item_list = load_query_instances(dataset_path)

    return [
        {
            "question": get_prompt_func(task_name=task_name,
                                        prompt_style=prompt_style,
                                        item=item),
            "gt_item": item,
        }
        for item in item_list
    ]