from numpy import mean from mmengine import load from ..utils.format_load import format_load import itertools import networkx as nx import numpy as np import copy import re from tqdm import tqdm from ..schema import ResponseDataSample from sentence_transformers import SentenceTransformer, util class PlanningEvaluator: """Planning Evaluation Args: dataset_path(str): File path of evaluation dataset name_weight(float): the weight of action_name in bert_score match, default = 0.9 args_weight(float): the weight of action_args in bert_score match, default = 0.1 match_threshold(float): the threshold of matching match_strategy(str): matching method, can choose 'bertscore' or 'permutation' bert_score_model(str): the bert_score model for sentence similarity, default = "all-mpnet-base-v2". Refer to https://www.sbert.net/docs/pretrained_models.html for more models. """ def __init__( self, dataset_path: str, name_weight = 0.75, args_weight = 0.25, match_threshold = 0.7, match_strategy: str = 'bertscore', # ["bertscore", "permutation"] bert_score_model: str = "all-mpnet-base-v2", # ['thenlper/gte-large-zh', 'all-mpnet-base-v2'] default_prompt_type: str = 'json', # ["json", "ReWOO"] **kwargs, ) -> None: self.bert_score_model = bert_score_model print(bert_score_model) self.dataset_path = dataset_path self.name_weight = name_weight self.args_weight = args_weight self.match_threshold = match_threshold self.default_prompt_type = default_prompt_type # ["json", "ReWOO"] assert match_strategy in ["bertscore", "permutation"], f"match strategy must in [\"bertscore\", \"permutation\"], but get {match_strategy}" self.match_strategy = match_strategy self.valid_data_count = None self.sentence_model = SentenceTransformer(self.bert_score_model) def _load_dataset(self): self.dataset = [] dataset = load(self.dataset_path) total_error = 0 total_count = 0 for key in dataset.keys(): datum = dataset[key] data_sample, error = self._process_response(datum) total_error += error total_count += 1 self.dataset.append( dict(response_data_sample=data_sample)) self.num_samples = len(self.dataset) print("total_data_count:", total_count, "valid_data_count:", total_count - total_error) self.valid_data_count = total_count - total_error def format_load(self, data): r''' ensure evaluator can work correctly under any data input ''' try: json_format = format_load(data, start_character='[', end_character=']') except Exception as e: return [] if type(json_format) != list: return [] for i in range(len(json_format)): try: json_format[i] = { 'name': str(json_format[i]['name']), 'id': int(json_format[i]['id']), 'args': str(json_format[i]['args']) } except Exception as e: return [] return json_format def _process_response( self, datum, ) -> ResponseDataSample: """Process the response to needed format. Args: datum(dict): inputs. Returns: dict: Processed response data sample. """ # Generated response, which can be a string or list pred_data = datum['prediction'] # Response of ground truth, which can be a string or list gt_data = datum['ground_truth'] # prompt_type: The type of planning prompt, supporting "json" and "ReWOO" if "meta" in datum: prompt_type = datum["meta"].get("prompt_type", self.default_prompt_type) else: prompt_type = self.default_prompt_type error = 0 pred = dict() gt = dict() gt['planning'] = self.format_load(gt_data) if prompt_type == 'json': pred['planning'] = self.format_load(pred_data) if pred['planning'] == [] or gt['planning'] == []: error = 1 elif prompt_type == 'ReWOO': """ This type is deprecated The planning prediction data should in this format: Plan 1: description about the first action Dependency 1: the first action depends on which previous actions Action 1: #E1 = api_name1(args1) ... Which will be passed only if "number of plan lines == number of dependency lines == number of action lines" The passed data's format is: [ dict( id = i, name = curr_name, args = args_str ) ... ] The golden answer prediction is a json that is the same as the json format. """ thoughts = re.findall(r'(Plan [0-9]+: .+)', pred_data) dependencies = re.findall(r'(Dependency [0-9]+: .+)', pred_data) action_units = re.findall(r'Action [0-9]+: (.+)', pred_data) if not (len(thoughts) == len(dependencies) and len(thoughts) == len(action_units)): pred['planning'] = [] gt['planning'] = [] return ResponseDataSample(template = '', pred=pred, gt=gt), 1 plan_action = [] for i in range(len(action_units)): dependency_list = re.findall(r'Dependency [0-9]+: (.+)', dependencies[i]) if action_units[i][0] == '#': # The action has a return #E args_str_list = re.findall(r'#E[0-9]+ = .+\((.+)\)', action_units[i]) name_list = re.findall(r'#E[0-9]+ = (.+)\(', action_units[i]) else: # The action does not have a return args_str_list = re.findall(r'.+\((.+)\)', action_units[i]) name_list = re.findall(r'(.+)\(', action_units[i]) if (len(name_list) > 0): curr_name = name_list[0] else: curr_name = "" if (len(args_str_list) > 0): args_str = "{" + args_str_list[0] + "}" else: args_str = "{}" if (len(dependency_list) > 0): dependency_str = dependency_list[0] else: dependency_str = "" dependency = re.findall('([0-9]+)', dependency_str) dependency = list(set([int(x) - 1 for x in dependency])) plan_action.append( dict( id = i, name = curr_name, prev = dependency, args = args_str )) pred['planning'] = plan_action #Turn dict into args str for i in range(len(gt['planning'])): args_str = "" if type(gt['planning'][i]['args']) == str: args_dict = eval(gt['planning'][i]['args']) else: assert type(gt['planning'][i]['args']) == dict args_dict = gt['planning'][i]['args'] for it in args_dict: if args_str == "": args_str += f"{it}=\"{args_dict[it]}\"" else: args_str += f", {it}=\"{args_dict[it]}\"" gt['planning'][i]['args'] = '{' + args_str + '}' elif prompt_type == 'str': pred_data_format = pred_data.replace('. ', '\n').split('\n') pred_actions = [] for pred_step in pred_data_format: first_occur_time = 1e9 pred_action = "" for api_name in datum['meta']['API_list']: occur_time = pred_step.find(api_name) if occur_time != -1 and occur_time < first_occur_time: first_occur_time = occur_time pred_action = api_name if pred_action != "": pred_actions.append({ 'id': len(pred_actions), 'name': pred_action, 'args': pred_step }) pred['planning'] = pred_actions if len(pred['planning']) == 0: error = 1 else: raise NotImplementedError(f"Currently, we only support json and ReWOO format, but get {prompt_type}") return ResponseDataSample(template = '', pred=pred, gt=gt), error def _evaluate(self, data_sample) -> dict: if self.match_strategy == 'bertscore': metrics_result = self.bertscore_match( data_sample.pred['planning'], data_sample.gt['planning']) elif self.match_strategy == 'permutation': metrics_result = self.permutation_match( data_sample.pred['planning'], data_sample.gt['planning']) else: raise NotImplementedError if len(data_sample.pred['planning']) == 0 or len(data_sample.gt['planning']) == 0: metrics_result['parse_rate'] = 0 else: metrics_result['parse_rate'] = 1 return metrics_result def evaluate(self): self._load_dataset() results_list = [] for data_sample in tqdm(self.dataset): metrics_result = self._evaluate( data_sample['response_data_sample']) results_list.append(metrics_result) return self._post_process(results_list) def permutation_match(self, pred_plan, gt_plan) -> dict: ''' The function calculates all the permutation matches' score and selects the max f1_score; Since permutation is time consuming, we truncate the length of plans to 9 ''' if pred_plan[-1]['name'] != 'FinishAction': pred_plan.append( {'id': len(pred_plan), 'prev': [], 'name': 'FinishAction', 'args': r'\{\}'} ) if gt_plan[-1]['name'] != 'FinishAction': gt_plan.append( {'id': len(gt_plan), 'prev': [], 'name': 'FinishAction', 'args': r'\{\}'} ) # truncate plans to 9 since it is too long for permutation. if len(pred_plan) > 9: pred_plan = pred_plan[:9] if len(gt_plan) > 9: gt_plan = pred_plan[:9] pred_plan = sorted(pred_plan, key=lambda x: x['id']) gt_plan = sorted(gt_plan, key=lambda x: x['id']) len_pred = len(pred_plan) len_gt = len(gt_plan) map_id_max = max(len_pred, len_gt) numbers = [i for i in range(map_id_max)] perms = itertools.permutations(numbers, len_pred) gt_prev_count, pred_prev_count = 0, 0 for i in range(len_gt): gt_plan[i]['prev'].append(i) gt_prev_count += len(gt_plan[i]['prev']) for i in range(len_pred): pred_plan[i]['prev'].append(i) pred_prev_count += len(pred_plan[i]['prev']) if gt_prev_count == 0 or pred_prev_count == 0: return { 'precision': 0, 'recall': 0, 'f1_score': 0 } max_recall, max_precision, max_f1 = 0, 0, 0 for perm in perms: correct_count = 0 for i in range(len_pred): if perm[i] >= len_gt: continue for j in pred_plan[i]['prev']: if perm[j] in gt_plan[perm[i]]['prev']: correct_count += 1 now_recall, now_precision = correct_count / gt_prev_count, correct_count / pred_prev_count if now_recall + now_precision == 0: continue now_f1 = 2 * now_recall * now_precision / (now_recall + now_precision) if now_f1 > max_f1: max_f1, max_recall, max_precision = now_f1, now_recall, now_precision return { 'precision': max_precision, 'recall': max_recall, 'f1_score': max_f1 } def bertscore_match(self, pred_plan, gt_plan) -> dict: """ Calculate the similarity between predicted plan and golden answer, A plan can be regarded a sequence of actions, and each action has a name and args. Firstly, use bertscore to calculate pointwise similarity by: similarity(u, v) = bertscore(u.name, v.name) * name_weight + bertscore(u.args, v.args) * args_weight; Secondly, use Hungarian matching to match the points; Finally, use LIS to calculate the number of matched nodes. """ if len(pred_plan) == 0 or len(gt_plan) == 0: return { 'precision': 0, 'recall': 0, 'f1_score': 0 } pred_plan = copy.deepcopy(sorted(pred_plan, key=lambda x: x['id'])) gt_plan = copy.deepcopy(sorted(gt_plan, key=lambda x: x['id'])) #Add end action #Currently it is hard-code if pred_plan[-1]['name'] == 'FinishAction': pred_plan = pred_plan[:-1] if gt_plan[-1]['name'] == 'FinishAction': gt_plan = gt_plan[:-1] #The total counts of nodes and edges. len_pred = len(pred_plan) len_gt = len(gt_plan) bert_score_matrix = np.zeros((len_pred, len_gt)) name_pred, args_pred = [], [] name_gt, args_gt = [], [] for i in range(len_pred): name_pred.append(pred_plan[i]['name']) args_pred.append(str(pred_plan[i]['args'])) for i in range(len_gt): name_gt.append(gt_plan[i]['name']) args_gt.append(str(gt_plan[i]['args'])) name_pred_emb = self.sentence_model.encode(name_pred, convert_to_tensor=True) name_gt_emb = self.sentence_model.encode(name_gt, convert_to_tensor=True) args_pred_emb = self.sentence_model.encode(args_pred, convert_to_tensor=True) args_gt_emb = self.sentence_model.encode(args_gt, convert_to_tensor=True) name_cosine_scores = np.maximum(util.cos_sim(name_pred_emb, name_gt_emb).cpu().numpy(), 0) args_cosine_scores = np.maximum(util.cos_sim(args_pred_emb, args_gt_emb).cpu().numpy(), 0) for i in range(len_pred): for j in range(len_gt): bert_score_matrix[i][j] = \ name_cosine_scores[i][j] * self.name_weight \ + args_cosine_scores[i][j] * self.args_weight G = nx.Graph() for i in range(len_pred): for j in range(len_gt): if bert_score_matrix[i][j] > self.match_threshold: G.add_edge(i, str(j), weight=bert_score_matrix[i][j]) max_weight_matching = nx.max_weight_matching(G) pred_to_gt_mapping = dict() for key in max_weight_matching: if type(key[0]) == int: pred_to_gt_mapping[int(key[0])] = int(key[1]) else: pred_to_gt_mapping[int(key[1])] = int(key[0]) #If a prediction node does not match any golden answer node, we mark the node as -1. for i in range(len_pred): if i not in pred_to_gt_mapping: pred_to_gt_mapping[i] = -1 #Calculate how many nodes are matched by Longest Increasing Subsequence (LIS) dp = np.ones(len_pred) for i in range(len_pred): for j in range(i): if pred_to_gt_mapping[i] == -1 or pred_to_gt_mapping[j] == -1: continue if pred_to_gt_mapping[i] > pred_to_gt_mapping[j]: dp[i] = max(dp[i], dp[j] + 1) correct_count = int(max(dp)) recall, precision = correct_count / len(gt_plan), correct_count / len(pred_plan) f1_score = 2 * recall * precision / (recall + precision) result = { 'precision': precision, 'recall': recall, 'f1_score': f1_score } return result def _post_process(self, results_list): # list of dict to dict of list results = dict() planning_metric_keys = ["precision", "recall", "f1_score", 'parse_rate'] for key in planning_metric_keys: results[key] = mean([result[key] for result in results_list]) return results