From 4dd9a3fc10b24babcc49bda757ab4f9b7781fa80 Mon Sep 17 00:00:00 2001 From: Leymore Date: Wed, 18 Oct 2023 23:37:35 -0500 Subject: [PATCH] [Sync] sync with internal codes 20231019 (#488) --- .pre-commit-config-zh-cn.yaml | 4 +- .../lawbench/evaluation_functions/cjft.py | 38 +- .../lawbench/evaluation_functions/flzx.py | 36 +- .../lawbench/evaluation_functions/ftcs.py | 38 +- .../lawbench/evaluation_functions/jdzy.py | 72 +- .../lawbench/evaluation_functions/jec_ac.py | 58 +- .../lawbench/evaluation_functions/jec_kd.py | 58 +- .../lawbench/evaluation_functions/jetq.py | 86 +- .../lawbench/evaluation_functions/lblj.py | 58 +- .../evaluation_functions/ljp_accusation.py | 152 +-- .../evaluation_functions/ljp_article.py | 140 +-- .../evaluation_functions/ljp_imprison.py | 98 +- .../lawbench/evaluation_functions/sjjc.py | 128 +-- .../lawbench/evaluation_functions/wbfl.py | 84 +- .../lawbench/evaluation_functions/wsjd.py | 100 +- .../lawbench/evaluation_functions/xxcq.py | 34 +- .../lawbench/evaluation_functions/ydlj.py | 34 +- .../lawbench/evaluation_functions/yqzy.py | 36 +- .../lawbench/evaluation_functions/zxfl.py | 54 +- .../datasets/lawbench/utils/char_smi.py | 910 +++++++++--------- .../utils/compare_m2_for_evaluation.py | 866 ++++++++--------- .../datasets/lawbench/utils/function_utils.py | 98 +- .../lawbench/utils/modules/alignment.py | 666 ++++++------- .../lawbench/utils/modules/annotator.py | 152 +-- .../lawbench/utils/modules/classifier.py | 302 +++--- .../datasets/lawbench/utils/modules/merger.py | 544 +++++------ .../lawbench/utils/modules/tokenizer.py | 184 ++-- .../datasets/lawbench/utils/parallel_to_m2.py | 442 ++++----- opencompass/runners/dlc.py | 117 ++- 29 files changed, 2833 insertions(+), 2756 deletions(-) diff --git a/.pre-commit-config-zh-cn.yaml b/.pre-commit-config-zh-cn.yaml index e114dc61..6b9be079 100644 --- a/.pre-commit-config-zh-cn.yaml +++ b/.pre-commit-config-zh-cn.yaml @@ -3,7 +3,9 @@ exclude: | tests/data/| opencompass/models/internal/| opencompass/utils/internal/| - opencompass/openicl/icl_evaluator/hf_metrics/ + opencompass/openicl/icl_evaluator/hf_metrics/| + opencompass/datasets/lawbench/utils| + opencompass/datasets/lawbench/evaluation_functions/ ) repos: - repo: https://gitee.com/openmmlab/mirrors-flake8 diff --git a/opencompass/datasets/lawbench/evaluation_functions/cjft.py b/opencompass/datasets/lawbench/evaluation_functions/cjft.py index bf149db8..71d6c1dd 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/cjft.py +++ b/opencompass/datasets/lawbench/evaluation_functions/cjft.py @@ -1,19 +1,19 @@ -from ..utils.function_utils import compute_rouge - -#情景法条识别 - -def compute_cjft(data_dict): - """ - Compute the ROUGE-L score between the prediction and the reference - """ - references, predictions = [], [] - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - predictions.append(prediction) - references.append(answer) - - # compute the accuracy of score_list - rouge_scores = compute_rouge(predictions, references) - rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores] - average_rouge_l = sum(rouge_ls) / len(rouge_ls) - return {"score": average_rouge_l} +from ..utils.function_utils import compute_rouge + +#情景法条识别 + +def compute_cjft(data_dict): + """ + Compute the ROUGE-L score between the prediction and the reference + """ + references, predictions = [], [] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], 
example["prediction"], example["refr"] + predictions.append(prediction) + references.append(answer) + + # compute the accuracy of score_list + rouge_scores = compute_rouge(predictions, references) + rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores] + average_rouge_l = sum(rouge_ls) / len(rouge_ls) + return {"score": average_rouge_l} diff --git a/opencompass/datasets/lawbench/evaluation_functions/flzx.py b/opencompass/datasets/lawbench/evaluation_functions/flzx.py index 9d0f6ec7..376c7733 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/flzx.py +++ b/opencompass/datasets/lawbench/evaluation_functions/flzx.py @@ -1,18 +1,18 @@ -from ..utils.function_utils import compute_rouge - -#法律咨询 -def compute_flzx(data_dict): - """ - Compute the ROUGE-L score between the prediction and the reference - """ - references, predictions = [], [] - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - predictions.append(prediction) - references.append(answer) - - # compute the accuracy of score_list - rouge_scores = compute_rouge(predictions, references) - rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores] - average_rouge_l = sum(rouge_ls) / len(rouge_ls) - return {"score": average_rouge_l} +from ..utils.function_utils import compute_rouge + +#法律咨询 +def compute_flzx(data_dict): + """ + Compute the ROUGE-L score between the prediction and the reference + """ + references, predictions = [], [] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + predictions.append(prediction) + references.append(answer) + + # compute the accuracy of score_list + rouge_scores = compute_rouge(predictions, references) + rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores] + average_rouge_l = sum(rouge_ls) / len(rouge_ls) + return {"score": average_rouge_l} diff --git a/opencompass/datasets/lawbench/evaluation_functions/ftcs.py b/opencompass/datasets/lawbench/evaluation_functions/ftcs.py index 5b21b632..009099e7 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/ftcs.py +++ b/opencompass/datasets/lawbench/evaluation_functions/ftcs.py @@ -1,19 +1,19 @@ -from ..utils.function_utils import compute_rouge - -#法条记忆问答 -def compute_ftcs(data_dict): - """ - Compute the ROUGE-L score between the prediction and the reference - """ - references, predictions = [], [] - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - answer = answer.replace("答案:", "") - predictions.append(prediction) - references.append(answer) - - # compute the accuracy of score_list - rouge_scores = compute_rouge(predictions, references) - rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores] - average_rouge_l = sum(rouge_ls) / len(rouge_ls) - return {"score": average_rouge_l} +from ..utils.function_utils import compute_rouge + +#法条记忆问答 +def compute_ftcs(data_dict): + """ + Compute the ROUGE-L score between the prediction and the reference + """ + references, predictions = [], [] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + answer = answer.replace("答案:", "") + predictions.append(prediction) + references.append(answer) + + # compute the accuracy of score_list + rouge_scores = compute_rouge(predictions, references) + rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores] + average_rouge_l = sum(rouge_ls) / 
len(rouge_ls) + return {"score": average_rouge_l} diff --git a/opencompass/datasets/lawbench/evaluation_functions/jdzy.py b/opencompass/datasets/lawbench/evaluation_functions/jdzy.py index 498df762..1129d58b 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/jdzy.py +++ b/opencompass/datasets/lawbench/evaluation_functions/jdzy.py @@ -1,36 +1,36 @@ -from ..utils.function_utils import multi_choice_judge - -""" -multi-choice single-label selection -metric: accuracy -争议焦点:识别案件涉及的争议焦点 -""" - -def compute_jdzy(data_dict): - """ - Compute the Accuracy - The JEC dataset has 16 possible answers for each question, stored in the option_list - A prediction is correct if - 1. The correct answer appears in the prediction, and - 2. Options other than the answer do not appear in the prediction. - """ - - score_list, abstentions = [], 0 - option_list = ["诉讼主体", "租金情况", "利息", "本金争议", "责任认定", "责任划分", "损失认定及处理", - "原审判决是否适当", "合同效力", "财产分割", "责任承担", "鉴定结论采信问题", "诉讼时效", "违约", "合同解除", "肇事逃逸"] - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - if answer[7:-1] == "赔偿": - # todo: dataset imperfection - continue - assert answer.startswith("争议焦点类别:") and answer[7:-1] in option_list, \ - f"answer: {answer} \n question: {question}" - - answer_letter = answer[7:-1] - judge = multi_choice_judge(prediction, option_list, answer_letter) - score_list.append(judge["score"]) - abstentions += judge["abstention"] - - # compute the accuracy of score_list - accuracy = sum(score_list) / len(score_list) - return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)} +from ..utils.function_utils import multi_choice_judge + +""" +multi-choice single-label selection +metric: accuracy +争议焦点:识别案件涉及的争议焦点 +""" + +def compute_jdzy(data_dict): + """ + Compute the Accuracy + The JEC dataset has 16 possible answers for each question, stored in the option_list + A prediction is correct if + 1. The correct answer appears in the prediction, and + 2. Options other than the answer do not appear in the prediction. + """ + + score_list, abstentions = [], 0 + option_list = ["诉讼主体", "租金情况", "利息", "本金争议", "责任认定", "责任划分", "损失认定及处理", + "原审判决是否适当", "合同效力", "财产分割", "责任承担", "鉴定结论采信问题", "诉讼时效", "违约", "合同解除", "肇事逃逸"] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + if answer[7:-1] == "赔偿": + # todo: dataset imperfection + continue + assert answer.startswith("争议焦点类别:") and answer[7:-1] in option_list, \ + f"answer: {answer} \n question: {question}" + + answer_letter = answer[7:-1] + judge = multi_choice_judge(prediction, option_list, answer_letter) + score_list.append(judge["score"]) + abstentions += judge["abstention"] + + # compute the accuracy of score_list + accuracy = sum(score_list) / len(score_list) + return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)} diff --git a/opencompass/datasets/lawbench/evaluation_functions/jec_ac.py b/opencompass/datasets/lawbench/evaluation_functions/jec_ac.py index 45a7f0f6..f6c98ad7 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/jec_ac.py +++ b/opencompass/datasets/lawbench/evaluation_functions/jec_ac.py @@ -1,29 +1,29 @@ -from ..utils.function_utils import multi_choice_judge - -""" -Task: multi-choice selection -Metric: Accuracy -司法考试-案例分析 -""" -def compute_jec_ac(data_dict): - """ - Compute the Accuracy - The JEC dataset has 4 options for each question: A, B, C, D - A prediction is correct if - 1. 
The correct answer appears in the prediction, and - 2. Options other than the answer do not appear in the prediction. - """ - score_list, abstentions = [], 0 - option_list = ["A", "B", "C", "D"] - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - assert answer.startswith("正确答案:") and answer[5] in option_list, f"answer[5]: {answer}, question: {question}" - - answer_letter = answer[5] - judge = multi_choice_judge(prediction, option_list, answer_letter) - score_list.append(judge["score"]) - abstentions += judge["abstention"] - - # compute the accuracy of score_list - accuracy = sum(score_list) / len(score_list) - return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)} +from ..utils.function_utils import multi_choice_judge + +""" +Task: multi-choice selection +Metric: Accuracy +司法考试-案例分析 +""" +def compute_jec_ac(data_dict): + """ + Compute the Accuracy + The JEC dataset has 4 options for each question: A, B, C, D + A prediction is correct if + 1. The correct answer appears in the prediction, and + 2. Options other than the answer do not appear in the prediction. + """ + score_list, abstentions = [], 0 + option_list = ["A", "B", "C", "D"] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + assert answer.startswith("正确答案:") and answer[5] in option_list, f"answer[5]: {answer}, question: {question}" + + answer_letter = answer[5] + judge = multi_choice_judge(prediction, option_list, answer_letter) + score_list.append(judge["score"]) + abstentions += judge["abstention"] + + # compute the accuracy of score_list + accuracy = sum(score_list) / len(score_list) + return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)} diff --git a/opencompass/datasets/lawbench/evaluation_functions/jec_kd.py b/opencompass/datasets/lawbench/evaluation_functions/jec_kd.py index f68dfad1..3afe4ef9 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/jec_kd.py +++ b/opencompass/datasets/lawbench/evaluation_functions/jec_kd.py @@ -1,29 +1,29 @@ -from ..utils.function_utils import multi_choice_judge - -""" -Task: multi-choice selection -Metric: Accuracy -司法考试 -""" -def compute_jec_kd(data_dict): - """ - Compute the Accuracy - The JEC_KD dataset has 4 options for each question: A, B, C, D - A prediction is correct if - 1. The correct answer appears in the prediction, and - 2. Options other than the answer do not appear in the prediction. - """ - score_list, abstentions = [], 0 - option_list = ["A", "B", "C", "D"] - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - assert answer.startswith("正确答案:") and answer[5] in option_list, f"answer[5]: {answer}, question: {question}" - - answer_letter = answer[5] - judge = multi_choice_judge(prediction, option_list, answer_letter) - score_list.append(judge["score"]) - abstentions += judge["abstention"] - - # compute the accuracy of score_list - accuracy = sum(score_list) / len(score_list) - return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)} +from ..utils.function_utils import multi_choice_judge + +""" +Task: multi-choice selection +Metric: Accuracy +司法考试 +""" +def compute_jec_kd(data_dict): + """ + Compute the Accuracy + The JEC_KD dataset has 4 options for each question: A, B, C, D + A prediction is correct if + 1. The correct answer appears in the prediction, and + 2. 
Options other than the answer do not appear in the prediction. + """ + score_list, abstentions = [], 0 + option_list = ["A", "B", "C", "D"] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + assert answer.startswith("正确答案:") and answer[5] in option_list, f"answer[5]: {answer}, question: {question}" + + answer_letter = answer[5] + judge = multi_choice_judge(prediction, option_list, answer_letter) + score_list.append(judge["score"]) + abstentions += judge["abstention"] + + # compute the accuracy of score_list + accuracy = sum(score_list) / len(score_list) + return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)} diff --git a/opencompass/datasets/lawbench/evaluation_functions/jetq.py b/opencompass/datasets/lawbench/evaluation_functions/jetq.py index 48b4afab..936de7c5 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/jetq.py +++ b/opencompass/datasets/lawbench/evaluation_functions/jetq.py @@ -1,43 +1,43 @@ -import re - -""" -number prediction -metric: accuracy -金额提取 -""" -def compute_jetq(data_dict): - """ - Compute the Accuracy - we extract the total amount of cost involved in the crime from the prediction and compare it with the reference - The prediction is correct if - the total amount of cost provided in the reference, appears in the prediction. - """ - score_list, abstentions = [], 0 - - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - assert answer.startswith("上文涉及到的犯罪金额:"), f"answer: {answer}, question: {question}" - assert answer.endswith("元。"), f"answer: {answer}, question: {question}" - answer = answer.replace("上文涉及到的犯罪金额:", "") - - assert "千元" not in answer, f"answer: {answer}, question: {question}" - assert "万" not in answer, f"answer: {answer}, question: {question}" - - # remove "元" - answer = answer.replace("元。", "") - answer = float(answer) - - prediction_digits = re.findall(r"\d+\.?\d*", prediction) - prediction_digits = [float(digit) for digit in prediction_digits] - - if len(prediction_digits) == 0: - abstentions += 1 - if answer in prediction_digits: - score_list.append(1) - else: - score_list.append(0) - - - # compute the accuracy of score_list - accuracy = sum(score_list) / len(score_list) - return {"score": accuracy, "abstention_rate": abstentions/len(data_dict)} +import re + +""" +number prediction +metric: accuracy +金额提取 +""" +def compute_jetq(data_dict): + """ + Compute the Accuracy + we extract the total amount of cost involved in the crime from the prediction and compare it with the reference + The prediction is correct if + the total amount of cost provided in the reference, appears in the prediction. 
+ """ + score_list, abstentions = [], 0 + + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + assert answer.startswith("上文涉及到的犯罪金额:"), f"answer: {answer}, question: {question}" + assert answer.endswith("元。"), f"answer: {answer}, question: {question}" + answer = answer.replace("上文涉及到的犯罪金额:", "") + + assert "千元" not in answer, f"answer: {answer}, question: {question}" + assert "万" not in answer, f"answer: {answer}, question: {question}" + + # remove "元" + answer = answer.replace("元。", "") + answer = float(answer) + + prediction_digits = re.findall(r"\d+\.?\d*", prediction) + prediction_digits = [float(digit) for digit in prediction_digits] + + if len(prediction_digits) == 0: + abstentions += 1 + if answer in prediction_digits: + score_list.append(1) + else: + score_list.append(0) + + + # compute the accuracy of score_list + accuracy = sum(score_list) / len(score_list) + return {"score": accuracy, "abstention_rate": abstentions/len(data_dict)} diff --git a/opencompass/datasets/lawbench/evaluation_functions/lblj.py b/opencompass/datasets/lawbench/evaluation_functions/lblj.py index 0bc20e24..7675ec99 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/lblj.py +++ b/opencompass/datasets/lawbench/evaluation_functions/lblj.py @@ -1,29 +1,29 @@ -from ..utils.function_utils import multi_choice_judge - -""" -Task: multi-choice selection -Metric: Accuracy -论辩挖掘 -""" -def compute_lblj(data_dict): - """ - Compute the Accuracy - The LBLJ dataset has 5 options for each question: A, B, C, D, E - A prediction is correct if - 1. The correct answer appears in the prediction, and - 2. Options other than the answer do not appear in the prediction. - """ - score_list, abstentions = [], 0 - option_list = ["A", "B", "C", "D", "E"] - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - assert answer.startswith("[正确答案]") and answer[6] in option_list, f"answer[6]: {answer}, question: {question}" - - answer_letter = answer[6] - judge = multi_choice_judge(prediction, option_list, answer_letter) - score_list.append(judge["score"]) - abstentions += judge["abstention"] - - # compute the accuracy of score_list - accuracy = sum(score_list) / len(score_list) - return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)} +from ..utils.function_utils import multi_choice_judge + +""" +Task: multi-choice selection +Metric: Accuracy +论辩挖掘 +""" +def compute_lblj(data_dict): + """ + Compute the Accuracy + The LBLJ dataset has 5 options for each question: A, B, C, D, E + A prediction is correct if + 1. The correct answer appears in the prediction, and + 2. Options other than the answer do not appear in the prediction. 
+ """ + score_list, abstentions = [], 0 + option_list = ["A", "B", "C", "D", "E"] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + assert answer.startswith("[正确答案]") and answer[6] in option_list, f"answer[6]: {answer}, question: {question}" + + answer_letter = answer[6] + judge = multi_choice_judge(prediction, option_list, answer_letter) + score_list.append(judge["score"]) + abstentions += judge["abstention"] + + # compute the accuracy of score_list + accuracy = sum(score_list) / len(score_list) + return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)} diff --git a/opencompass/datasets/lawbench/evaluation_functions/ljp_accusation.py b/opencompass/datasets/lawbench/evaluation_functions/ljp_accusation.py index 93690a9f..dc16a7c4 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/ljp_accusation.py +++ b/opencompass/datasets/lawbench/evaluation_functions/ljp_accusation.py @@ -1,76 +1,76 @@ -from ..utils.function_utils import compute_f1_two_sets -""" -task: legal accusation prediction -metric: f1 score -法律判决预测-罪名预测 -""" - -option_list = ["侮辱", "违法发放贷款", "失火", "票据诈骗", "帮助犯罪分子逃避处罚", "重大责任事故", "对非国家工作人员行贿", - "非法制造、销售非法制造的注册商标标识", "非法制造、买卖、运输、邮寄、储存枪支、弹药、爆炸物", "非法获取公民个人信息", - "扰乱无线电通讯管理秩序", "非法持有、私藏枪支、弹药", "拒不执行判决、裁定", "虚开发票", "巨额财产来源不明", - "组织、领导、参加黑社会性质组织", "非法获取国家秘密", "以危险方法危害公共安全", "非法持有毒品", - "聚众扰乱公共场所秩序、交通秩序", "包庇毒品犯罪分子", "滥伐林木", "伪造公司、企业、事业单位、人民团体印章", - "非法占用农用地", "走私废物", "串通投标", "非法采伐、毁坏国家重点保护植物", "冒充军人招摇撞骗", "玩忽职守", - "重婚", "招收公务员、学生徇私舞弊", "组织、领导传销活动", "非法猎捕、杀害珍贵、濒危野生动物", "侵犯著作权", - "非法种植毒品原植物", "伪造、变造、买卖武装部队公文、证件、印章", "倒卖文物", "伪造、变造居民身份证", "滥用职权", - "诽谤", "猥亵儿童", "非法转让、倒卖土地使用权", "挪用公款", "污染环境", "出售、购买、运输假币", "敲诈勒索", - "高利转贷", "故意伤害", "持有、使用假币", "单位受贿", "强奸", "引诱、容留、介绍卖淫", "虐待", - "生产、销售伪劣农药、兽药、化肥、种子", "妨害公务", "容留他人吸毒", "拐骗儿童", "强制猥亵、侮辱妇女", - "非法处置查封、扣押、冻结的财产", "骗取贷款、票据承兑、金融票证", "强迫他人吸毒", "非法拘禁", - "非法携带枪支、弹药、管制刀具、危险物品危及公共安全", "绑架", "聚众斗殴", "破坏计算机信息系统", - "制造、贩卖、传播淫秽物品", "虐待被监管人", "贷款诈骗", "赌博", "徇私舞弊不征、少征税款", - "盗窃、抢夺枪支、弹药、爆炸物、危险物质", "故意杀人", "介绍贿赂", "提供侵入、非法控制计算机信息系统程序、工具", - "编造、故意传播虚假恐怖信息", "妨害作证", "强迫卖淫", "走私、贩卖、运输、制造毒品", "伪证", "拐卖妇女、儿童", - "过失损坏武器装备、军事设施、军事通信", "破坏广播电视设施、公用电信设施", "洗钱", "职务侵占", "倒卖车票、船票", - "抢劫", "侵占", "掩饰、隐瞒犯罪所得、犯罪所得收益", "徇私舞弊不移交刑事案件", "引诱、教唆、欺骗他人吸毒", "遗弃", - "生产、销售伪劣产品", "放火", "非法采矿", "对单位行贿", "盗窃、抢夺枪支、弹药、爆炸物", "破坏易燃易爆设备", - "妨害信用卡管理", "制作、复制、出版、贩卖、传播淫秽物品牟利", "金融凭证诈骗", "私分国有资产", - "走私国家禁止进出口的货物、物品", "假冒注册商标", "危险物品肇事", "走私普通货物、物品", "经济犯", "虚报注册资本", - "盗掘古文化遗址、古墓葬", "传播淫秽物品", "窝藏、包庇", "拒不支付劳动报酬", "行贿", "开设赌场", "传授犯罪方法", - "协助组织卖淫", "保险诈骗", "破坏生产经营", "破坏交通设施", "打击报复证人", "非法侵入住宅", "非国家工作人员受贿", - "过失致人重伤", "伪造、变造金融票证", "窝藏、转移、隐瞒毒品、毒赃", "帮助毁灭、伪造证据", "走私珍贵动物、珍贵动物制品", - "生产、销售假药", "逃税", "挪用特定款物", "聚众扰乱社会秩序", "组织、强迫、引诱、容留、介绍卖淫", "合同诈骗", - "非法生产、销售间谍专用器材", "破坏交通工具", "传播性病", "强迫交易", "隐匿、故意销毁会计凭证、会计帐簿、财务会计报告", - "非法组织卖血", "强迫劳动", "破坏电力设备", "销售假冒注册商标的商品", "收买被拐卖的妇女、儿童", "诬告陷害", "脱逃", - "非法经营", "徇私枉法", "信用卡诈骗", "生产、销售不符合安全标准的食品", "非法行医", "伪造货币", "动植物检疫徇私舞弊", - "单位行贿", "破坏监管秩序", "盗窃", "盗伐林木", "重大劳动安全事故", "非法吸收公众存款", - "非法制造、出售非法制造的发票", "非法狩猎", "组织卖淫", "非法买卖、运输、携带、持有毒品原植物种子、幼苗", "挪用资金", - "诈骗", "伪造、变造、买卖国家机关公文、证件、印章", "持有伪造的发票", "贪污", "非法生产、买卖警用装备", - "投放危险物质", "伪造、倒卖伪造的有价票证", "集资诈骗", "抢夺", "生产、销售有毒、有害食品", "非法捕捞水产品", - "过失致人死亡", "非法买卖制毒物品", "虚开增值税专用发票、用于骗取出口退税、抵扣税款发票", "寻衅滋事", "危险驾驶", - "故意毁坏财物", "招摇撞骗", "盗窃、侮辱尸体", "走私武器、弹药", - "非法收购、运输、加工、出售国家重点保护植物、国家重点保护植物制品", "非法出售发票", "劫持船只、汽车", - "受贿", "聚众哄抢", "交通肇事"] - - -def compute_ljp_accusation(data_dict): - """ - Compute the 
F1-score - The LJP_Accusation dataset a set of 189 different accusation types. - A question may involve one or more accusation types. - Given a list of accusation types from both the ground truth and the prediction, we compute the F1-score between - these two lists. - """ - score_list, abstentions = [], 0 - - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - - assert answer.startswith("罪名:"), f"answer: {answer} \n question: {question}" - answer = answer.replace("罪名:", "") - answers = answer.split(";") - - prediction_list =[] - for option in option_list: - if option in prediction: - prediction_list.append(option) - - if len(prediction_list) == 0: - abstentions += 1 - gt_set = set(answers) - pred_set = set(prediction_list) - score = compute_f1_two_sets(gt_set, pred_set) - score_list.append(score) - - f1_score_average = sum(score_list) / len(score_list) - return {"score": f1_score_average, "abstention_rate": abstentions/len(data_dict)} +from ..utils.function_utils import compute_f1_two_sets +""" +task: legal accusation prediction +metric: f1 score +法律判决预测-罪名预测 +""" + +option_list = ["侮辱", "违法发放贷款", "失火", "票据诈骗", "帮助犯罪分子逃避处罚", "重大责任事故", "对非国家工作人员行贿", + "非法制造、销售非法制造的注册商标标识", "非法制造、买卖、运输、邮寄、储存枪支、弹药、爆炸物", "非法获取公民个人信息", + "扰乱无线电通讯管理秩序", "非法持有、私藏枪支、弹药", "拒不执行判决、裁定", "虚开发票", "巨额财产来源不明", + "组织、领导、参加黑社会性质组织", "非法获取国家秘密", "以危险方法危害公共安全", "非法持有毒品", + "聚众扰乱公共场所秩序、交通秩序", "包庇毒品犯罪分子", "滥伐林木", "伪造公司、企业、事业单位、人民团体印章", + "非法占用农用地", "走私废物", "串通投标", "非法采伐、毁坏国家重点保护植物", "冒充军人招摇撞骗", "玩忽职守", + "重婚", "招收公务员、学生徇私舞弊", "组织、领导传销活动", "非法猎捕、杀害珍贵、濒危野生动物", "侵犯著作权", + "非法种植毒品原植物", "伪造、变造、买卖武装部队公文、证件、印章", "倒卖文物", "伪造、变造居民身份证", "滥用职权", + "诽谤", "猥亵儿童", "非法转让、倒卖土地使用权", "挪用公款", "污染环境", "出售、购买、运输假币", "敲诈勒索", + "高利转贷", "故意伤害", "持有、使用假币", "单位受贿", "强奸", "引诱、容留、介绍卖淫", "虐待", + "生产、销售伪劣农药、兽药、化肥、种子", "妨害公务", "容留他人吸毒", "拐骗儿童", "强制猥亵、侮辱妇女", + "非法处置查封、扣押、冻结的财产", "骗取贷款、票据承兑、金融票证", "强迫他人吸毒", "非法拘禁", + "非法携带枪支、弹药、管制刀具、危险物品危及公共安全", "绑架", "聚众斗殴", "破坏计算机信息系统", + "制造、贩卖、传播淫秽物品", "虐待被监管人", "贷款诈骗", "赌博", "徇私舞弊不征、少征税款", + "盗窃、抢夺枪支、弹药、爆炸物、危险物质", "故意杀人", "介绍贿赂", "提供侵入、非法控制计算机信息系统程序、工具", + "编造、故意传播虚假恐怖信息", "妨害作证", "强迫卖淫", "走私、贩卖、运输、制造毒品", "伪证", "拐卖妇女、儿童", + "过失损坏武器装备、军事设施、军事通信", "破坏广播电视设施、公用电信设施", "洗钱", "职务侵占", "倒卖车票、船票", + "抢劫", "侵占", "掩饰、隐瞒犯罪所得、犯罪所得收益", "徇私舞弊不移交刑事案件", "引诱、教唆、欺骗他人吸毒", "遗弃", + "生产、销售伪劣产品", "放火", "非法采矿", "对单位行贿", "盗窃、抢夺枪支、弹药、爆炸物", "破坏易燃易爆设备", + "妨害信用卡管理", "制作、复制、出版、贩卖、传播淫秽物品牟利", "金融凭证诈骗", "私分国有资产", + "走私国家禁止进出口的货物、物品", "假冒注册商标", "危险物品肇事", "走私普通货物、物品", "经济犯", "虚报注册资本", + "盗掘古文化遗址、古墓葬", "传播淫秽物品", "窝藏、包庇", "拒不支付劳动报酬", "行贿", "开设赌场", "传授犯罪方法", + "协助组织卖淫", "保险诈骗", "破坏生产经营", "破坏交通设施", "打击报复证人", "非法侵入住宅", "非国家工作人员受贿", + "过失致人重伤", "伪造、变造金融票证", "窝藏、转移、隐瞒毒品、毒赃", "帮助毁灭、伪造证据", "走私珍贵动物、珍贵动物制品", + "生产、销售假药", "逃税", "挪用特定款物", "聚众扰乱社会秩序", "组织、强迫、引诱、容留、介绍卖淫", "合同诈骗", + "非法生产、销售间谍专用器材", "破坏交通工具", "传播性病", "强迫交易", "隐匿、故意销毁会计凭证、会计帐簿、财务会计报告", + "非法组织卖血", "强迫劳动", "破坏电力设备", "销售假冒注册商标的商品", "收买被拐卖的妇女、儿童", "诬告陷害", "脱逃", + "非法经营", "徇私枉法", "信用卡诈骗", "生产、销售不符合安全标准的食品", "非法行医", "伪造货币", "动植物检疫徇私舞弊", + "单位行贿", "破坏监管秩序", "盗窃", "盗伐林木", "重大劳动安全事故", "非法吸收公众存款", + "非法制造、出售非法制造的发票", "非法狩猎", "组织卖淫", "非法买卖、运输、携带、持有毒品原植物种子、幼苗", "挪用资金", + "诈骗", "伪造、变造、买卖国家机关公文、证件、印章", "持有伪造的发票", "贪污", "非法生产、买卖警用装备", + "投放危险物质", "伪造、倒卖伪造的有价票证", "集资诈骗", "抢夺", "生产、销售有毒、有害食品", "非法捕捞水产品", + "过失致人死亡", "非法买卖制毒物品", "虚开增值税专用发票、用于骗取出口退税、抵扣税款发票", "寻衅滋事", "危险驾驶", + "故意毁坏财物", "招摇撞骗", "盗窃、侮辱尸体", "走私武器、弹药", + "非法收购、运输、加工、出售国家重点保护植物、国家重点保护植物制品", "非法出售发票", "劫持船只、汽车", + "受贿", "聚众哄抢", "交通肇事"] + + +def compute_ljp_accusation(data_dict): + """ + 
Compute the F1-score + The LJP_Accusation dataset a set of 189 different accusation types. + A question may involve one or more accusation types. + Given a list of accusation types from both the ground truth and the prediction, we compute the F1-score between + these two lists. + """ + score_list, abstentions = [], 0 + + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + + assert answer.startswith("罪名:"), f"answer: {answer} \n question: {question}" + answer = answer.replace("罪名:", "") + answers = answer.split(";") + + prediction_list =[] + for option in option_list: + if option in prediction: + prediction_list.append(option) + + if len(prediction_list) == 0: + abstentions += 1 + gt_set = set(answers) + pred_set = set(prediction_list) + score = compute_f1_two_sets(gt_set, pred_set) + score_list.append(score) + + f1_score_average = sum(score_list) / len(score_list) + return {"score": f1_score_average, "abstention_rate": abstentions/len(data_dict)} diff --git a/opencompass/datasets/lawbench/evaluation_functions/ljp_article.py b/opencompass/datasets/lawbench/evaluation_functions/ljp_article.py index dc9afb35..e12a1ac4 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/ljp_article.py +++ b/opencompass/datasets/lawbench/evaluation_functions/ljp_article.py @@ -1,70 +1,70 @@ -import re -import cn2an - -""" -task: law article prediction -metric: F1 score -法律判决预测-法条预测 -""" -def replace_match(match): - return match.group(1) - -def compute_ljp_article(data_dict): - """ - Compute the F1-score - A reference contains a list of articles of the Criminal Law of the People's Republic of China. - We compute the F1-score between the prediction and the reference. - """ - - score_list, abstentions = [], 0 - - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - assert answer.startswith("法条:刑法第"), f"answer: {answer}" - assert answer.endswith("条"), f"answer: {answer}" - - answer = answer.replace("法条:刑法第", "") - answer = answer.replace("条", "") - - answer_law_indices = answer.split("、") - answer_law_index_digit_list = [] - for answer_law_index in answer_law_indices: - assert answer_law_index.isdigit(), f"answer_law_index: {answer_law_index}" - answer_law_index_digit = int(answer_law_index) - assert answer_law_index_digit <= 490, "刑法总共只有490条" - answer_law_index_digit_list.append(answer_law_index_digit) - - prediction_law_chunks = prediction.split("、") - prediction_law_index_digit_list = [] - - for prediction_law_chunk in prediction_law_chunks: - prediction_law_chunk = prediction_law_chunk.replace("万元", "元") - - # delete phrase starts with "第" and ends with "款", we don't care about it in the answer - prediction_law_chunk = re.sub(r'第(.*?)款', "", prediction_law_chunk) - # keep only the digits in the phrase starts with "第" and ends with "条", otherwise cn may fail to convert - prediction_law_chunk = re.sub(r'第(.*?)条', replace_match, prediction_law_chunk) - prediction_law_chunk = cn2an.transform(prediction_law_chunk, "cn2an") - # find digtis in prediction_law_chunk - prediction_law_section_numbers = re.findall(r"\d+", prediction_law_chunk) - if len(prediction_law_section_numbers) == 0: - continue - if len(prediction_law_section_numbers) != 1: - # in this case, we only take the first number, and reject the others - pass - - prediction_law_index_digit = int(prediction_law_section_numbers[0]) - prediction_law_index_digit_list.append(prediction_law_index_digit) - - 
gt_set = set(answer_law_index_digit_list) - pred_set = set(prediction_law_index_digit_list) - if len(pred_set) == 0: - abstentions += 1 - precision = len(gt_set.intersection(pred_set)) / len(pred_set) if len(pred_set) != 0 else 0 - recall = len(gt_set.intersection(pred_set)) / len(gt_set) if len(gt_set) != 0 else 0 - f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0 - score_list.append(f1_score) - - # compute the accuracy of score_list - average_f1 = sum(score_list) / len(score_list) - return {'score': average_f1, 'abstention_rate': abstentions/len(data_dict)} +import re +import cn2an + +""" +task: law article prediction +metric: F1 score +法律判决预测-法条预测 +""" +def replace_match(match): + return match.group(1) + +def compute_ljp_article(data_dict): + """ + Compute the F1-score + A reference contains a list of articles of the Criminal Law of the People's Republic of China. + We compute the F1-score between the prediction and the reference. + """ + + score_list, abstentions = [], 0 + + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + assert answer.startswith("法条:刑法第"), f"answer: {answer}" + assert answer.endswith("条"), f"answer: {answer}" + + answer = answer.replace("法条:刑法第", "") + answer = answer.replace("条", "") + + answer_law_indices = answer.split("、") + answer_law_index_digit_list = [] + for answer_law_index in answer_law_indices: + assert answer_law_index.isdigit(), f"answer_law_index: {answer_law_index}" + answer_law_index_digit = int(answer_law_index) + assert answer_law_index_digit <= 490, "刑法总共只有490条" + answer_law_index_digit_list.append(answer_law_index_digit) + + prediction_law_chunks = prediction.split("、") + prediction_law_index_digit_list = [] + + for prediction_law_chunk in prediction_law_chunks: + prediction_law_chunk = prediction_law_chunk.replace("万元", "元") + + # delete phrase starts with "第" and ends with "款", we don't care about it in the answer + prediction_law_chunk = re.sub(r'第(.*?)款', "", prediction_law_chunk) + # keep only the digits in the phrase starts with "第" and ends with "条", otherwise cn may fail to convert + prediction_law_chunk = re.sub(r'第(.*?)条', replace_match, prediction_law_chunk) + prediction_law_chunk = cn2an.transform(prediction_law_chunk, "cn2an") + # find digtis in prediction_law_chunk + prediction_law_section_numbers = re.findall(r"\d+", prediction_law_chunk) + if len(prediction_law_section_numbers) == 0: + continue + if len(prediction_law_section_numbers) != 1: + # in this case, we only take the first number, and reject the others + pass + + prediction_law_index_digit = int(prediction_law_section_numbers[0]) + prediction_law_index_digit_list.append(prediction_law_index_digit) + + gt_set = set(answer_law_index_digit_list) + pred_set = set(prediction_law_index_digit_list) + if len(pred_set) == 0: + abstentions += 1 + precision = len(gt_set.intersection(pred_set)) / len(pred_set) if len(pred_set) != 0 else 0 + recall = len(gt_set.intersection(pred_set)) / len(gt_set) if len(gt_set) != 0 else 0 + f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0 + score_list.append(f1_score) + + # compute the accuracy of score_list + average_f1 = sum(score_list) / len(score_list) + return {'score': average_f1, 'abstention_rate': abstentions/len(data_dict)} diff --git a/opencompass/datasets/lawbench/evaluation_functions/ljp_imprison.py b/opencompass/datasets/lawbench/evaluation_functions/ljp_imprison.py index 
0ded2c98..fc5bc0da 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/ljp_imprison.py +++ b/opencompass/datasets/lawbench/evaluation_functions/ljp_imprison.py @@ -1,49 +1,49 @@ -import math -import cn2an -import re - -#法律判决预测-刑期预测 -def compute_ljp_imprison(data_dict): - score_list, abstentions = [], 0 - - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - # get answer digit, which is the number between "刑期:" and "个月" - if "死刑" in answer or "无期" in answer: - # TODO: data imperfection - continue - - assert answer.startswith("刑期:") and answer.endswith("个月"), f"answer: {answer}, question: {question}" - answer = answer.replace("刑期:", "") - answer = answer.replace("个月", "") - answer_digit = int(answer) - prediction = cn2an.transform(prediction, "cn2an") - - # use regular expression to extract the digits from prediction, only consider digits before "个月" or "月" - prediction_digit_month_list = re.findall(r"\d+个月", prediction) - prediction_digit_month_list = [int(digit.replace("个月", "")) for digit in prediction_digit_month_list] - prediction_digit_month_list2 = re.findall(r"\d+月", prediction) - prediction_digit_month_list2 = [int(digit.replace("月", "")) for digit in prediction_digit_month_list2] - prediction_digit_month_list.extend(prediction_digit_month_list2) - # catches the digits before "年" - prediction_digit_year_list = re.findall(r"\d+年", prediction) - prediction_digit_year_list = [int(digit.replace("年", "")) for digit in prediction_digit_year_list] - - if len(prediction_digit_month_list) > 0: - prediction_digit_month = int(prediction_digit_month_list[0]) - elif len(prediction_digit_year_list) > 0: - prediction_digit_month = int(prediction_digit_year_list[0]) * 12 - else: - abstentions += 1 - prediction_digit_month = -1 - - if prediction_digit_month != -1: - score_list.append(abs(math.log(answer_digit + 1) - math.log(prediction_digit_month + 1))) - else: - score_list.append(math.log(216)) - - # compute the average of score_list (log distance) - log_distance = sum(score_list) / len(score_list) - # normalize the score to between 0 and 1 - log_distance = (math.log(216) - log_distance)/math.log(216) - return {"score": log_distance, "abstention_rate": abstentions/len(data_dict)} +import math +import cn2an +import re + +#法律判决预测-刑期预测 +def compute_ljp_imprison(data_dict): + score_list, abstentions = [], 0 + + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + # get answer digit, which is the number between "刑期:" and "个月" + if "死刑" in answer or "无期" in answer: + # TODO: data imperfection + continue + + assert answer.startswith("刑期:") and answer.endswith("个月"), f"answer: {answer}, question: {question}" + answer = answer.replace("刑期:", "") + answer = answer.replace("个月", "") + answer_digit = int(answer) + prediction = cn2an.transform(prediction, "cn2an") + + # use regular expression to extract the digits from prediction, only consider digits before "个月" or "月" + prediction_digit_month_list = re.findall(r"\d+个月", prediction) + prediction_digit_month_list = [int(digit.replace("个月", "")) for digit in prediction_digit_month_list] + prediction_digit_month_list2 = re.findall(r"\d+月", prediction) + prediction_digit_month_list2 = [int(digit.replace("月", "")) for digit in prediction_digit_month_list2] + prediction_digit_month_list.extend(prediction_digit_month_list2) + # catches the digits before "年" + prediction_digit_year_list = re.findall(r"\d+年", 
prediction) + prediction_digit_year_list = [int(digit.replace("年", "")) for digit in prediction_digit_year_list] + + if len(prediction_digit_month_list) > 0: + prediction_digit_month = int(prediction_digit_month_list[0]) + elif len(prediction_digit_year_list) > 0: + prediction_digit_month = int(prediction_digit_year_list[0]) * 12 + else: + abstentions += 1 + prediction_digit_month = -1 + + if prediction_digit_month != -1: + score_list.append(abs(math.log(answer_digit + 1) - math.log(prediction_digit_month + 1))) + else: + score_list.append(math.log(216)) + + # compute the average of score_list (log distance) + log_distance = sum(score_list) / len(score_list) + # normalize the score to between 0 and 1 + log_distance = (math.log(216) - log_distance)/math.log(216) + return {"score": log_distance, "abstention_rate": abstentions/len(data_dict)} diff --git a/opencompass/datasets/lawbench/evaluation_functions/sjjc.py b/opencompass/datasets/lawbench/evaluation_functions/sjjc.py index 0ff6f1de..d5d9b7e3 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/sjjc.py +++ b/opencompass/datasets/lawbench/evaluation_functions/sjjc.py @@ -1,64 +1,64 @@ -from ..utils.function_utils import compute_f1_two_sets -from ..utils.rc_f1 import CJRCEvaluator - - -""" -task: event detection -metric: F1 score -事件检测 -""" -option_list = ["支付/给付", "欺骗", "搜查/扣押", "要求/请求", "卖出", "买入", "获利", "拘捕", "鉴定", "同意/接受", "供述", "联络", "帮助/救助", "租用/借用", "受伤", "伪造", "卖淫", "伤害人身", "赔偿", "归还/偿还"] - -def compute_sjjc(data_dict): - """ - Compute the F1-score - The sjjc task covers 20 event types. - A question may involve one or more event types. - Given a list of event types from both the ground truth and the prediction, we compute the F1-score between - these two lists. - """ - score_list, abstentions = [], 0 - - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - - answers = answer.split(";") - - prediction_list =[] - for option in option_list: - if option in prediction: - prediction_list.append(option) - - if len(prediction_list) == 0: - abstentions += 1 - gt_set = set(answers) - pred_set = set(prediction_list) - score = compute_f1_two_sets(gt_set, pred_set) - score_list.append(score) - - f1_score_average = sum(score_list) / len(score_list) - return {"score": f1_score_average, "abstention_rate": abstentions/len(data_dict)} - -""" -task: trigger word extraction -metric: F1 score -触发词抽取 -""" -def compute_cfcy(data_dict): - - scores = 0 - - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - - answers = answer.split(";") - predictions = prediction.split(";") - intersected = [CJRCEvaluator.compute_f1(r, h) for r, h in zip(answers, predictions)] - - prec = sum(intersected) / len(predictions) if len(predictions) > 0 else 0 - rec = sum(intersected) / len(answers) if len(answers) > 0 else 0 - # print(prec, rec, intersected) - scores += 2 * prec * rec / (prec + rec + 1e-10) - - f1_score_average = scores / len(data_dict) - return {"score": f1_score_average} +from ..utils.function_utils import compute_f1_two_sets +from ..utils.rc_f1 import CJRCEvaluator + + +""" +task: event detection +metric: F1 score +事件检测 +""" +option_list = ["支付/给付", "欺骗", "搜查/扣押", "要求/请求", "卖出", "买入", "获利", "拘捕", "鉴定", "同意/接受", "供述", "联络", "帮助/救助", "租用/借用", "受伤", "伪造", "卖淫", "伤害人身", "赔偿", "归还/偿还"] + +def compute_sjjc(data_dict): + """ + Compute the F1-score + The sjjc task covers 20 event types. 
+ A question may involve one or more event types. + Given a list of event types from both the ground truth and the prediction, we compute the F1-score between + these two lists. + """ + score_list, abstentions = [], 0 + + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + + answers = answer.split(";") + + prediction_list =[] + for option in option_list: + if option in prediction: + prediction_list.append(option) + + if len(prediction_list) == 0: + abstentions += 1 + gt_set = set(answers) + pred_set = set(prediction_list) + score = compute_f1_two_sets(gt_set, pred_set) + score_list.append(score) + + f1_score_average = sum(score_list) / len(score_list) + return {"score": f1_score_average, "abstention_rate": abstentions/len(data_dict)} + +""" +task: trigger word extraction +metric: F1 score +触发词抽取 +""" +def compute_cfcy(data_dict): + + scores = 0 + + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + + answers = answer.split(";") + predictions = prediction.split(";") + intersected = [CJRCEvaluator.compute_f1(r, h) for r, h in zip(answers, predictions)] + + prec = sum(intersected) / len(predictions) if len(predictions) > 0 else 0 + rec = sum(intersected) / len(answers) if len(answers) > 0 else 0 + # print(prec, rec, intersected) + scores += 2 * prec * rec / (prec + rec + 1e-10) + + f1_score_average = scores / len(data_dict) + return {"score": f1_score_average} diff --git a/opencompass/datasets/lawbench/evaluation_functions/wbfl.py b/opencompass/datasets/lawbench/evaluation_functions/wbfl.py index 7ed4334b..edde3eb9 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/wbfl.py +++ b/opencompass/datasets/lawbench/evaluation_functions/wbfl.py @@ -1,42 +1,42 @@ -""" -task: multiple choice classification -metric: F1 score -婚姻文本分类 -""" - -def compute_wbfl(data_dict): - """ - A reference (R) contains a list of options, each option is from the option_list. - We will extract the options appearing in the prediction and convert them into a set (P). - We compute the F1 score between the prediction (P) and the reference (R). 
- """ - - - score_list, abstentions = [], 0 - option_list = ["婚后有子女", "限制行为能力子女抚养", "有夫妻共同财产", "支付抚养费", "不动产分割", "婚后分局", - "二次起诉离婚", "按月给付抚养费", "准予离婚", "有夫妻共同债务", "婚前个人财产", "法定离婚", "不履行家庭义务", - "存在非婚生子", "适当帮助", "不履行离婚协议", "损害赔偿", "感情不和分居满二年", "子女随非抚养权人生活", "婚后个人财产"] - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - assert answer.startswith("类别:") and answer.endswith("。"), f"answer: {answer}, question: {question}" - - gt_list = (answer[3:-1].split("、")) - for gt in gt_list: - assert gt in option_list, f"gt: {gt}, question: {question}" - gt_set = set(gt_list) - - prediction_list = [] - for option in option_list: - if option in prediction: - prediction_list.append(option) - if len(prediction_list) == 0: - abstentions += 1 - predict_set = set(prediction_list) - precision = len(gt_set.intersection(predict_set)) / len(predict_set) if len(predict_set) != 0 else 0 - recall = len(gt_set.intersection(predict_set)) / len(gt_set) if len(gt_set) != 0 else 0 - f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0 - score_list.append(f1_score) - - # compute the accuracy of score_list - final_f1_score = sum(score_list) / len(score_list) - return {'score': final_f1_score, 'abstention_rate': abstentions / len(data_dict)} +""" +task: multiple choice classification +metric: F1 score +婚姻文本分类 +""" + +def compute_wbfl(data_dict): + """ + A reference (R) contains a list of options, each option is from the option_list. + We will extract the options appearing in the prediction and convert them into a set (P). + We compute the F1 score between the prediction (P) and the reference (R). + """ + + + score_list, abstentions = [], 0 + option_list = ["婚后有子女", "限制行为能力子女抚养", "有夫妻共同财产", "支付抚养费", "不动产分割", "婚后分局", + "二次起诉离婚", "按月给付抚养费", "准予离婚", "有夫妻共同债务", "婚前个人财产", "法定离婚", "不履行家庭义务", + "存在非婚生子", "适当帮助", "不履行离婚协议", "损害赔偿", "感情不和分居满二年", "子女随非抚养权人生活", "婚后个人财产"] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + assert answer.startswith("类别:") and answer.endswith("。"), f"answer: {answer}, question: {question}" + + gt_list = (answer[3:-1].split("、")) + for gt in gt_list: + assert gt in option_list, f"gt: {gt}, question: {question}" + gt_set = set(gt_list) + + prediction_list = [] + for option in option_list: + if option in prediction: + prediction_list.append(option) + if len(prediction_list) == 0: + abstentions += 1 + predict_set = set(prediction_list) + precision = len(gt_set.intersection(predict_set)) / len(predict_set) if len(predict_set) != 0 else 0 + recall = len(gt_set.intersection(predict_set)) / len(gt_set) if len(gt_set) != 0 else 0 + f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0 + score_list.append(f1_score) + + # compute the accuracy of score_list + final_f1_score = sum(score_list) / len(score_list) + return {'score': final_f1_score, 'abstention_rate': abstentions / len(data_dict)} diff --git a/opencompass/datasets/lawbench/evaluation_functions/wsjd.py b/opencompass/datasets/lawbench/evaluation_functions/wsjd.py index d334a4bd..231ea77e 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/wsjd.py +++ b/opencompass/datasets/lawbench/evaluation_functions/wsjd.py @@ -1,50 +1,50 @@ -import re -import os -import subprocess - -""" -Task: legal document grammar correction -Metric: F0.5 score -文书校对 -""" -def compute_wsjd(data_dict): - origins, references, predictions = [], [], [] 
- for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - if isinstance(question, list): - question = question[0]['prompt'] - start = question.index('句子:\n') + 4 - origins.append(re.sub(r'\n|\t', '', question[start:].split('\n')[0])) - # truncate predictions >5 tokens longer than the reference - prediction = re.sub(r'\n|\t', '', prediction) - if len(prediction) - len(answer) > 5: - prediction = prediction[:len(answer) + 5] - if len(prediction) == 0: - prediction = "无内容" - predictions.append(prediction) - references.append(re.sub(r'\n|\t', '', answer)) - - #generate input files for ChERRANT - preds = [f'{i} \t {origin} \t {prediction} \n' for i, (origin, prediction) in enumerate(zip(origins, predictions))] - golds = [f'{i} \t {origin} \t {reference} \n' for i, (origin, reference) in enumerate(zip(origins, references))] - - now_path = os.path.abspath(os.getcwd()) - utils_path = os.path.abspath(os.path.join(__file__, '..', '..', 'utils')) - uid = os.getuid() - os.chdir(utils_path) - with open(f'/tmp/tmp_pred_{uid}.para', 'w') as f: - f.writelines(preds) - with open(f'/tmp/tmp_gold_{uid}.para', 'w') as f: - f.writelines(golds) - os.environ['KMP_DUPLICATE_LIB_OK']='True' - os.system(f'python3 parallel_to_m2.py -f /tmp/tmp_pred_{uid}.para -o /tmp/tmp_pred_{uid}.para.m2 -g char') - os.system(f'python3 parallel_to_m2.py -f /tmp/tmp_gold_{uid}.para -o /tmp/tmp_gold_{uid}.para.m2 -g char') - output = subprocess.check_output(f"python3 compare_m2_for_evaluation.py -hyp /tmp/tmp_pred_{uid}.para.m2 -ref /tmp/tmp_gold_{uid}.para.m2", shell = True) - score = float(output.decode().split('\t')[-1].split('\n')[0]) - #remove prediction files - os.remove(f'/tmp/tmp_pred_{uid}.para') - os.remove(f'/tmp/tmp_gold_{uid}.para') - os.remove(f'/tmp/tmp_pred_{uid}.para.m2') - os.remove(f'/tmp/tmp_gold_{uid}.para.m2') - os.chdir(now_path) - return {"score": score} +import re +import os +import subprocess + +""" +Task: legal document grammar correction +Metric: F0.5 score +文书校对 +""" +def compute_wsjd(data_dict): + origins, references, predictions = [], [], [] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + if isinstance(question, list): + question = question[0]['prompt'] + start = question.index('句子:\n') + 4 + origins.append(re.sub(r'\n|\t', '', question[start:].split('\n')[0])) + # truncate predictions >5 tokens longer than the reference + prediction = re.sub(r'\n|\t', '', prediction) + if len(prediction) - len(answer) > 5: + prediction = prediction[:len(answer) + 5] + if len(prediction) == 0: + prediction = "无内容" + predictions.append(prediction) + references.append(re.sub(r'\n|\t', '', answer)) + + #generate input files for ChERRANT + preds = [f'{i} \t {origin} \t {prediction} \n' for i, (origin, prediction) in enumerate(zip(origins, predictions))] + golds = [f'{i} \t {origin} \t {reference} \n' for i, (origin, reference) in enumerate(zip(origins, references))] + + now_path = os.path.abspath(os.getcwd()) + utils_path = os.path.abspath(os.path.join(__file__, '..', '..', 'utils')) + uid = os.getuid() + os.chdir(utils_path) + with open(f'/tmp/tmp_pred_{uid}.para', 'w') as f: + f.writelines(preds) + with open(f'/tmp/tmp_gold_{uid}.para', 'w') as f: + f.writelines(golds) + os.environ['KMP_DUPLICATE_LIB_OK']='True' + os.system(f'python3 parallel_to_m2.py -f /tmp/tmp_pred_{uid}.para -o /tmp/tmp_pred_{uid}.para.m2 -g char') + os.system(f'python3 parallel_to_m2.py -f 
/tmp/tmp_gold_{uid}.para -o /tmp/tmp_gold_{uid}.para.m2 -g char') + output = subprocess.check_output(f"python3 compare_m2_for_evaluation.py -hyp /tmp/tmp_pred_{uid}.para.m2 -ref /tmp/tmp_gold_{uid}.para.m2", shell = True) + score = float(output.decode().split('\t')[-1].split('\n')[0]) + #remove prediction files + os.remove(f'/tmp/tmp_pred_{uid}.para') + os.remove(f'/tmp/tmp_gold_{uid}.para') + os.remove(f'/tmp/tmp_pred_{uid}.para.m2') + os.remove(f'/tmp/tmp_gold_{uid}.para.m2') + os.chdir(now_path) + return {"score": score} diff --git a/opencompass/datasets/lawbench/evaluation_functions/xxcq.py b/opencompass/datasets/lawbench/evaluation_functions/xxcq.py index c504c730..679d94d7 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/xxcq.py +++ b/opencompass/datasets/lawbench/evaluation_functions/xxcq.py @@ -1,17 +1,17 @@ -from ..utils.comprehension_scores import compute_ie_f1 - - -""" -task: information extraction -metric: F1 score -信息抽取 -""" -def compute_xxcq(data_dict): - references, predictions = [], [] - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - predictions.append(prediction) - references.append(answer) - - return compute_ie_f1(predictions, references, {"犯罪嫌疑人", "受害人", "被盗货币", "物品价值", "盗窃获利", - "被盗物品", "作案工具", "时间", "地点", "组织机构"}) +from ..utils.comprehension_scores import compute_ie_f1 + + +""" +task: information extraction +metric: F1 score +信息抽取 +""" +def compute_xxcq(data_dict): + references, predictions = [], [] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + predictions.append(prediction) + references.append(answer) + + return compute_ie_f1(predictions, references, {"犯罪嫌疑人", "受害人", "被盗货币", "物品价值", "盗窃获利", + "被盗物品", "作案工具", "时间", "地点", "组织机构"}) diff --git a/opencompass/datasets/lawbench/evaluation_functions/ydlj.py b/opencompass/datasets/lawbench/evaluation_functions/ydlj.py index 1065959f..5081e027 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/ydlj.py +++ b/opencompass/datasets/lawbench/evaluation_functions/ydlj.py @@ -1,17 +1,17 @@ -from ..utils.comprehension_scores import compute_rc_f1 - -""" -Task: machine reading comprehension -Metric: F1 score -法律阅读理解 -""" -def compute_ydlj(data_dict): - references, predictions = [], [] - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - answer = answer.replace("回答:", "") - predictions.append(prediction) - references.append(answer) - - f1_score = compute_rc_f1(predictions, references) - return f1_score +from ..utils.comprehension_scores import compute_rc_f1 + +""" +Task: machine reading comprehension +Metric: F1 score +法律阅读理解 +""" +def compute_ydlj(data_dict): + references, predictions = [], [] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + answer = answer.replace("回答:", "") + predictions.append(prediction) + references.append(answer) + + f1_score = compute_rc_f1(predictions, references) + return f1_score diff --git a/opencompass/datasets/lawbench/evaluation_functions/yqzy.py b/opencompass/datasets/lawbench/evaluation_functions/yqzy.py index 57b62466..1568050d 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/yqzy.py +++ b/opencompass/datasets/lawbench/evaluation_functions/yqzy.py @@ -1,18 +1,18 @@ -from ..utils.function_utils import compute_rouge - -#舆情摘要 -def compute_yqzy(data_dict): 
- """ - Compute the ROUGE-L score between the prediction and the reference - """ - references, predictions = [], [] - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - predictions.append(prediction) - references.append(answer) - - # compute the accuracy of score_list - rouge_scores = compute_rouge(predictions, references) - rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores] - average_rouge_l = sum(rouge_ls) / len(rouge_ls) - return {"score": average_rouge_l} +from ..utils.function_utils import compute_rouge + +#舆情摘要 +def compute_yqzy(data_dict): + """ + Compute the ROUGE-L score between the prediction and the reference + """ + references, predictions = [], [] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + predictions.append(prediction) + references.append(answer) + + # compute the accuracy of score_list + rouge_scores = compute_rouge(predictions, references) + rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores] + average_rouge_l = sum(rouge_ls) / len(rouge_ls) + return {"score": average_rouge_l} diff --git a/opencompass/datasets/lawbench/evaluation_functions/zxfl.py b/opencompass/datasets/lawbench/evaluation_functions/zxfl.py index 4cb0ec00..7e4b0bc4 100644 --- a/opencompass/datasets/lawbench/evaluation_functions/zxfl.py +++ b/opencompass/datasets/lawbench/evaluation_functions/zxfl.py @@ -1,27 +1,27 @@ -from ..utils.function_utils import multi_choice_judge - -""" -task: multiple choice classification -metric: accuracy -咨询分类 -""" - -def compute_zxfl(data_dict): - """ - A reference (R) contains a list of options, each option is from the option_list. - We will extract the options appearing in the prediction and convert them into a set (P). - We compute the accuracy between the prediction (P) and the reference (R). - """ - - - score_list, abstentions = [], 0 - option_list = ['婚姻家庭', '劳动纠纷', '交通事故', '债权债务', '刑事辩护', '合同纠纷', '房产纠纷', '侵权', '公司法', '医疗纠纷', '拆迁安置', '行政诉讼', '建设工程', '知识产权', '综合咨询', '人身损害', '涉外法律', '海事海商', '消费权益', '抵押担保'] - for example in data_dict: - question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] - judge = multi_choice_judge(prediction, option_list, answer) - score_list.append(judge["score"]) - abstentions += judge["abstention"] - - # compute the accuracy of score_list - final_accuracy_score = sum(score_list) / len(score_list) - return {'score': final_accuracy_score, 'abstention_rate': abstentions / len(data_dict)} +from ..utils.function_utils import multi_choice_judge + +""" +task: multiple choice classification +metric: accuracy +咨询分类 +""" + +def compute_zxfl(data_dict): + """ + A reference (R) contains a list of options, each option is from the option_list. + We will extract the options appearing in the prediction and convert them into a set (P). + We compute the accuracy between the prediction (P) and the reference (R). 
+ """ + + + score_list, abstentions = [], 0 + option_list = ['婚姻家庭', '劳动纠纷', '交通事故', '债权债务', '刑事辩护', '合同纠纷', '房产纠纷', '侵权', '公司法', '医疗纠纷', '拆迁安置', '行政诉讼', '建设工程', '知识产权', '综合咨询', '人身损害', '涉外法律', '海事海商', '消费权益', '抵押担保'] + for example in data_dict: + question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] + judge = multi_choice_judge(prediction, option_list, answer) + score_list.append(judge["score"]) + abstentions += judge["abstention"] + + # compute the accuracy of score_list + final_accuracy_score = sum(score_list) / len(score_list) + return {'score': final_accuracy_score, 'abstention_rate': abstentions / len(data_dict)} diff --git a/opencompass/datasets/lawbench/utils/char_smi.py b/opencompass/datasets/lawbench/utils/char_smi.py index 0d257601..54bb4790 100644 --- a/opencompass/datasets/lawbench/utils/char_smi.py +++ b/opencompass/datasets/lawbench/utils/char_smi.py @@ -1,456 +1,456 @@ -### Copy from https://github.com/iqiyi/FASPell ### - -""" -Requirements: - - java (required only if tree edit distance is used) - - numpy -""" -import numpy as np -from subprocess import Popen, PIPE, STDOUT -import os -import argparse - -IDCS = {'\u2ff0': 2, # 12 ideographic description characters and their capacity of son nodes - '\u2ff1': 2, - '\u2ff2': 3, - '\u2ff3': 3, - '\u2ff4': 2, - '\u2ff5': 2, - '\u2ff6': 2, - '\u2ff7': 2, - '\u2ff8': 2, - '\u2ff9': 2, - '\u2ffa': 2, - '\u2ffb': 2, } - -PINYIN = {'ā': ['a', 1], 'á': ['a', 2], 'ǎ': ['a', 3], 'à': ['a', 4], - 'ē': ['e', 1], 'é': ['e', 2], 'ě': ['e', 3], 'è': ['e', 4], - 'ī': ['i', 1], 'í': ['i', 2], 'ǐ': ['i', 3], 'ì': ['i', 4], - 'ō': ['o', 1], 'ó': ['o', 2], 'ǒ': ['o', 3], 'ò': ['o', 4], - 'ū': ['u', 1], 'ú': ['u', 2], 'ǔ': ['u', 3], 'ù': ['u', 4], - 'ǖ': ['ü', 1], 'ǘ': ['ü', 2], 'ǚ': ['ü', 3], 'ǜ': ['ü', 4], - '': ['m', 2], 'ń': ['n', 2], 'ň': ['n', 3], 'ǹ': ['n', 4], - } - -# APTED_JAR_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'apted.jar') -APTED_JAR_PATH = 'apted.jar' - - -def tree_edit_distance(tree_a, tree_b): - """ - We use APTED algorithm proposed by M. Pawlik and N. 
Augsten - github link: https://github.com/DatabaseGroup/apted - """ - p = Popen(['java', '-jar', APTED_JAR_PATH, '-t', tree_a, tree_b], stdout=PIPE, stderr=STDOUT) - - res = [line for line in p.stdout] - res = res[0] - res = res.strip() - res = float(res) - - return res - - -def edit_distance(string_a, string_b, name='Levenshtein'): - """ - >>> edit_distance('abcde', 'avbcude') - 2 - >>> edit_distance(['至', '刂'], ['亻', '至', '刂']) - 1 - >>> edit_distance('fang', 'qwe') - 4 - >>> edit_distance('fang', 'hen') - 3 - """ - size_x = len(string_a) + 1 - size_y = len(string_b) + 1 - matrix = np.zeros((size_x, size_y), dtype=int) - for x in range(size_x): - matrix[x, 0] = x - for y in range(size_y): - matrix[0, y] = y - - for x in range(1, size_x): - for y in range(1, size_y): - if string_a[x - 1] == string_b[y - 1]: - matrix[x, y] = min( - matrix[x - 1, y] + 1, - matrix[x - 1, y - 1], - matrix[x, y - 1] + 1 - ) - else: - if name == 'Levenshtein': - matrix[x, y] = min( - matrix[x - 1, y] + 1, - matrix[x - 1, y - 1] + 1, - matrix[x, y - 1] + 1 - ) - else: # Canonical - matrix[x, y] = min( - matrix[x - 1, y] + 1, - matrix[x - 1, y - 1] + 2, - matrix[x, y - 1] + 1 - ) - - return matrix[size_x - 1, size_y - 1] - - -class CharFuncs(object): - def __init__(self, char_meta_fname): - self.data = self.load_char_meta(char_meta_fname) - self.char_dict = dict([(c, 0) for c in self.data]) - - self.safe = {'\u2ff0': 'A', - # to eliminate the bug that, in Windows CMD, char ⿻ and ⿵ are encoded to be the same. - '\u2ff1': 'B', - '\u2ff2': 'C', - '\u2ff3': 'D', - '\u2ff4': 'E', - '\u2ff5': 'F', - '\u2ff6': 'G', - '\u2ff7': 'H', - '\u2ff8': 'I', - '\u2ff9': 'J', - '\u2ffa': 'L', - '\u2ffb': 'M', } - - @staticmethod - def load_char_meta(fname): - data = {} - f = open(fname, 'r', encoding='utf-8') - for line in f: - items = line.strip().split('\t') - code_point = items[0] - char = items[1] - pronunciation = items[2] - decompositions = items[3:] - assert char not in data - data[char] = {"code_point": code_point, "pronunciation": pronunciation, "decompositions": decompositions} - return data - - def shape_distance(self, char1, char2, safe=True, as_tree=False): - """ - >>> c = CharFuncs('data/char_meta.txt') - >>> c.shape_distance('田', '由') - 1 - >>> c.shape_distance('牛', '午') - 1 - """ - assert char1 in self.data - assert char2 in self.data - - def safe_encode(decomp): - tree = '' - for c in string_to_tree(decomp): - if c not in self.safe: - tree += c - else: - tree += self.safe[c] - return tree - - def safe_encode_string(decomp): - tree = '' - for c in decomp: - if c not in self.safe: - tree += c - else: - tree += self.safe[c] - return tree - - decomps_1 = self.data[char1]["decompositions"] - decomps_2 = self.data[char2]["decompositions"] - - distance = 1e5 - if as_tree: - for decomp1 in decomps_1: - for decomp2 in decomps_2: - if not safe: - ted = tree_edit_distance(string_to_tree(decomp1), string_to_tree(decomp2)) - else: - ted = tree_edit_distance(safe_encode(decomp1), safe_encode(decomp2)) - distance = min(distance, ted) - else: - for decomp1 in decomps_1: - for decomp2 in decomps_2: - if not safe: - ed = edit_distance(decomp1, decomp2) - else: - ed = edit_distance(safe_encode_string(decomp1), safe_encode_string(decomp2)) - distance = min(distance, ed) - - return distance - - def pronunciation_distance(self, char1, char2): - """ - >>> c = CharFuncs('data/char_meta.txt') - >>> c.pronunciation_distance('田', '由') - 3.4 - >>> c.pronunciation_distance('牛', '午') - 2.6 - """ - assert char1 in self.data - assert char2 in 
self.data - pronunciations1 = self.data[char1]["pronunciation"] - pronunciations2 = self.data[char2]["pronunciation"] - - if pronunciations1[0] == 'null' or pronunciations2 == 'null': - return 0.0 - else: - - pronunciations1 = pronunciations1.split(';') # separate by lan - pronunciations2 = pronunciations2.split(';') # separate by lan - - distance = 0.0 - count = 0 - for pron_lan1, pron_lan2 in zip(pronunciations1, pronunciations2): - if (pron_lan1 == 'null') or (pron_lan2 == 'null'): - pass - else: - distance_lan = 1e5 - for p1 in pron_lan1.split(','): - for p2 in pron_lan2.split(','): - distance_lan = min(distance_lan, edit_distance(p1, p2)) - distance += distance_lan - count += 1 - - return distance / count - - @staticmethod - def load_dict(fname): - data = {} - f = open(fname, 'r', encoding='utf-8') - for line in f: - char, freq = line.strip().split('\t') - assert char not in data - data[char] = freq - - return data - - def similarity(self, char1, char2, weights=(0.8, 0.2, 0.0), as_tree=False): - """ - this function returns weighted similarity. When used in FASPell, each weight can only be 0 or 1. - """ - - # assert char1 in self.char_dict - # assert char2 in self.char_dict - shape_w, sound_w, freq_w = weights - - if char1 in self.char_dict and char2 in self.char_dict: - - shape_sim = self.shape_similarity(char1, char2, as_tree=as_tree) - sound_sim = self.pronunciation_similarity(char1, char2) - freq_sim = 1.0 - self.char_dict[char2] / len(self.char_dict) - - return shape_sim * shape_w + sound_sim * sound_w + freq_sim * freq_w - else: - return 0.0 - - def shape_similarity(self, char1, char2, safe=True, as_tree=False): - """ - >>> c = CharFuncs('data/char_meta.txt') - >>> c.shape_similarity('牛', '午') - 0.8571428571428572 - >>> c.shape_similarity('田', '由') - 0.8888888888888888 - """ - assert char1 in self.data - assert char2 in self.data - - def safe_encode(decomp): - tree = '' - for c in string_to_tree(decomp): - if c not in self.safe: - tree += c - else: - tree += self.safe[c] - return tree - - def safe_encode_string(decomp): - tree = '' - for c in decomp: - if c not in self.safe: - tree += c - else: - tree += self.safe[c] - return tree - - decomps_1 = self.data[char1]["decompositions"] - decomps_2 = self.data[char2]["decompositions"] - - similarity = 0.0 - if as_tree: - for decomp1 in decomps_1: - for decomp2 in decomps_2: - if not safe: - ted = tree_edit_distance(string_to_tree(decomp1), string_to_tree(decomp2)) - else: - ted = tree_edit_distance(safe_encode(decomp1), safe_encode(decomp2)) - normalized_ted = 2 * ted / (len(decomp1) + len(decomp2) + ted) - similarity = max(similarity, 1 - normalized_ted) - else: - for decomp1 in decomps_1: - for decomp2 in decomps_2: - if not safe: - ed = edit_distance(decomp1, decomp2) - else: - ed = edit_distance(safe_encode_string(decomp1), safe_encode_string(decomp2)) - normalized_ed = ed / max(len(decomp1), len(decomp2)) - similarity = max(similarity, 1 - normalized_ed) - - return similarity - - def pronunciation_similarity(self, char1, char2): - """ - >>> c = CharFuncs('data/char_meta.txt') - >>> c.pronunciation_similarity('牛', '午') - 0.27999999999999997 - >>> c.pronunciation_similarity('由', '田') - 0.09 - - """ - assert char1 in self.data - assert char2 in self.data - pronunciations1 = self.data[char1]["pronunciation"] - pronunciations2 = self.data[char2]["pronunciation"] - - if pronunciations1[0] == 'null' or pronunciations2 == 'null': - return 0.0 - else: - - pronunciations1 = pronunciations1.split(';') # separate by lan - pronunciations2 = 
pronunciations2.split(';') # separate by lan - - similarity = 0.0 - count = 0 - for pron_lan1, pron_lan2 in zip(pronunciations1, pronunciations2): - if (pron_lan1 == 'null') or (pron_lan2 == 'null'): - pass - else: - similarity_lan = 0.0 - for p1 in pron_lan1.split(','): - for p2 in pron_lan2.split(','): - tmp_sim = 1 - edit_distance(p1, p2) / max(len(p1), len(p2)) - similarity_lan = max(similarity_lan, tmp_sim) - similarity += similarity_lan - count += 1 - - return similarity / count if count else 0 - - -def string_to_tree(string): - """ - This function converts ids string to a string that can be used as a tree input to APTED. - Any Error raised by this function implies that the input string is invalid. - >>> string_to_tree('⿱⿱⿰丿㇏⿰丿㇏⿱⿰丿㇏⿰丿㇏') # 炎 - '{⿱{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}}' - >>> string_to_tree('⿱⿰丿㇏⿱一⿱⿻一丨一') # 全 - '{⿱{⿰{丿}{㇏}}{⿱{一}{⿱{⿻{一}{丨}}{一}}}}' - >>> string_to_tree('⿱⿰丿㇏⿻⿱一⿱⿻一丨一丷') # 金 - '{⿱{⿰{丿}{㇏}}{⿻{⿱{一}{⿱{⿻{一}{丨}}{一}}}{丷}}}' - >>> string_to_tree('⿻⿻⿻一丨一⿴⿱⿰丨𠃌一一') # 車 - '{⿻{⿻{⿻{一}{丨}}{一}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}' - >>> string_to_tree('⿻⿻⿻一丨⿰丿㇏⿴⿱⿰丨𠃌一一') # 東 - '{⿻{⿻{⿻{一}{丨}}{⿰{丿}{㇏}}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}' - >>> string_to_tree('丿') # 丿 - '{丿}' - >>> string_to_tree('⿻') # ⿻ - '{⿻}' - """ - if string[0] in IDCS and len(string) != 1: - bracket_stack = [] - tree = [] - - def add_brackets(num): - if num == 2: - bracket_stack.extend(['}', '{', '}']) - else: - bracket_stack.extend(['}', '{', '}', '{', '}']) - tree.append('{') - - global_just_put = '{' - - for c in string: - tree.append(c) - if c in IDCS: - assert global_just_put != '}' - add_brackets(IDCS[c]) - global_just_put = '{' - else: - just_put = '' - while just_put != '{' and bracket_stack: - just_put = bracket_stack.pop(-1) - tree.append(just_put) - global_just_put = just_put - - res = ''.join(tree) - assert res[-1] == '}' - else: - assert len(string) == 1 or string == 'null' - res = string[0] - - return '{' + res + '}' - - -def pinyin_map(standard_pinyin): - """ - >>> pinyin_map('xuě') - 'xue3' - >>> pinyin_map('xue') - 'xue' - >>> pinyin_map('lǜ') - 'lü4' - >>> pinyin_map('fá') - 'fa2' - """ - tone = '' - pinyin = '' - - assert ' ' not in standard_pinyin - for c in standard_pinyin: - if c in PINYIN: - pinyin += PINYIN[c][0] - assert tone == '' - tone = str(PINYIN[c][1]) - else: - pinyin += c - pinyin += tone - return pinyin - - -def parse_args(): - usage = '\n1. You can compute character similarity by:\n' \ - 'python char_sim.py 午 牛 年 千\n' \ - '\n' \ - '2. 
You can use ted in computing character similarity by:\n' \ - 'python char_sim.py 午 牛 年 千 -t\n' \ - '\n' - parser = argparse.ArgumentParser( - description='A script to compute Chinese character (Kanji) similarity', usage=usage) - - parser.add_argument('multiargs', nargs='*', type=str, default=None, - help='Chinese characters in question') - parser.add_argument('--ted', '-t', action="store_true", default=False, - help='True=to use tree edit distence (TED)' - 'False=to use string edit distance') - - args = parser.parse_args() - return args - - -if __name__ == '__main__': - args = parse_args() - c = CharFuncs('data/char_meta.txt') - if not args.ted: - for i, c1 in enumerate(args.multiargs): - for c2 in args.multiargs[i:]: - if c1 != c2: - print(f'For character pair ({c1}, {c2}):') - print(f' v-sim = {c.shape_similarity(c1, c2)}') - print(f' p-sim = {c.pronunciation_similarity(c1, c2)}\n') - else: - for i, c1 in enumerate(args.multiargs): - for c2 in args.multiargs[i:]: - if c1 != c2: - print(f'For character pair ({c1}, {c2}):') - print(f' v-sim = {c.shape_similarity(c1, c2, as_tree=True)}') +### Copy from https://github.com/iqiyi/FASPell ### + +""" +Requirements: + - java (required only if tree edit distance is used) + - numpy +""" +import numpy as np +from subprocess import Popen, PIPE, STDOUT +import os +import argparse + +IDCS = {'\u2ff0': 2, # 12 ideographic description characters and their capacity of son nodes + '\u2ff1': 2, + '\u2ff2': 3, + '\u2ff3': 3, + '\u2ff4': 2, + '\u2ff5': 2, + '\u2ff6': 2, + '\u2ff7': 2, + '\u2ff8': 2, + '\u2ff9': 2, + '\u2ffa': 2, + '\u2ffb': 2, } + +PINYIN = {'ā': ['a', 1], 'á': ['a', 2], 'ǎ': ['a', 3], 'à': ['a', 4], + 'ē': ['e', 1], 'é': ['e', 2], 'ě': ['e', 3], 'è': ['e', 4], + 'ī': ['i', 1], 'í': ['i', 2], 'ǐ': ['i', 3], 'ì': ['i', 4], + 'ō': ['o', 1], 'ó': ['o', 2], 'ǒ': ['o', 3], 'ò': ['o', 4], + 'ū': ['u', 1], 'ú': ['u', 2], 'ǔ': ['u', 3], 'ù': ['u', 4], + 'ǖ': ['ü', 1], 'ǘ': ['ü', 2], 'ǚ': ['ü', 3], 'ǜ': ['ü', 4], + '': ['m', 2], 'ń': ['n', 2], 'ň': ['n', 3], 'ǹ': ['n', 4], + } + +# APTED_JAR_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'apted.jar') +APTED_JAR_PATH = 'apted.jar' + + +def tree_edit_distance(tree_a, tree_b): + """ + We use APTED algorithm proposed by M. Pawlik and N. 
Augsten + github link: https://github.com/DatabaseGroup/apted + """ + p = Popen(['java', '-jar', APTED_JAR_PATH, '-t', tree_a, tree_b], stdout=PIPE, stderr=STDOUT) + + res = [line for line in p.stdout] + res = res[0] + res = res.strip() + res = float(res) + + return res + + +def edit_distance(string_a, string_b, name='Levenshtein'): + """ + >>> edit_distance('abcde', 'avbcude') + 2 + >>> edit_distance(['至', '刂'], ['亻', '至', '刂']) + 1 + >>> edit_distance('fang', 'qwe') + 4 + >>> edit_distance('fang', 'hen') + 3 + """ + size_x = len(string_a) + 1 + size_y = len(string_b) + 1 + matrix = np.zeros((size_x, size_y), dtype=int) + for x in range(size_x): + matrix[x, 0] = x + for y in range(size_y): + matrix[0, y] = y + + for x in range(1, size_x): + for y in range(1, size_y): + if string_a[x - 1] == string_b[y - 1]: + matrix[x, y] = min( + matrix[x - 1, y] + 1, + matrix[x - 1, y - 1], + matrix[x, y - 1] + 1 + ) + else: + if name == 'Levenshtein': + matrix[x, y] = min( + matrix[x - 1, y] + 1, + matrix[x - 1, y - 1] + 1, + matrix[x, y - 1] + 1 + ) + else: # Canonical + matrix[x, y] = min( + matrix[x - 1, y] + 1, + matrix[x - 1, y - 1] + 2, + matrix[x, y - 1] + 1 + ) + + return matrix[size_x - 1, size_y - 1] + + +class CharFuncs(object): + def __init__(self, char_meta_fname): + self.data = self.load_char_meta(char_meta_fname) + self.char_dict = dict([(c, 0) for c in self.data]) + + self.safe = {'\u2ff0': 'A', + # to eliminate the bug that, in Windows CMD, char ⿻ and ⿵ are encoded to be the same. + '\u2ff1': 'B', + '\u2ff2': 'C', + '\u2ff3': 'D', + '\u2ff4': 'E', + '\u2ff5': 'F', + '\u2ff6': 'G', + '\u2ff7': 'H', + '\u2ff8': 'I', + '\u2ff9': 'J', + '\u2ffa': 'L', + '\u2ffb': 'M', } + + @staticmethod + def load_char_meta(fname): + data = {} + f = open(fname, 'r', encoding='utf-8') + for line in f: + items = line.strip().split('\t') + code_point = items[0] + char = items[1] + pronunciation = items[2] + decompositions = items[3:] + assert char not in data + data[char] = {"code_point": code_point, "pronunciation": pronunciation, "decompositions": decompositions} + return data + + def shape_distance(self, char1, char2, safe=True, as_tree=False): + """ + >>> c = CharFuncs('data/char_meta.txt') + >>> c.shape_distance('田', '由') + 1 + >>> c.shape_distance('牛', '午') + 1 + """ + assert char1 in self.data + assert char2 in self.data + + def safe_encode(decomp): + tree = '' + for c in string_to_tree(decomp): + if c not in self.safe: + tree += c + else: + tree += self.safe[c] + return tree + + def safe_encode_string(decomp): + tree = '' + for c in decomp: + if c not in self.safe: + tree += c + else: + tree += self.safe[c] + return tree + + decomps_1 = self.data[char1]["decompositions"] + decomps_2 = self.data[char2]["decompositions"] + + distance = 1e5 + if as_tree: + for decomp1 in decomps_1: + for decomp2 in decomps_2: + if not safe: + ted = tree_edit_distance(string_to_tree(decomp1), string_to_tree(decomp2)) + else: + ted = tree_edit_distance(safe_encode(decomp1), safe_encode(decomp2)) + distance = min(distance, ted) + else: + for decomp1 in decomps_1: + for decomp2 in decomps_2: + if not safe: + ed = edit_distance(decomp1, decomp2) + else: + ed = edit_distance(safe_encode_string(decomp1), safe_encode_string(decomp2)) + distance = min(distance, ed) + + return distance + + def pronunciation_distance(self, char1, char2): + """ + >>> c = CharFuncs('data/char_meta.txt') + >>> c.pronunciation_distance('田', '由') + 3.4 + >>> c.pronunciation_distance('牛', '午') + 2.6 + """ + assert char1 in self.data + assert char2 in 
self.data + pronunciations1 = self.data[char1]["pronunciation"] + pronunciations2 = self.data[char2]["pronunciation"] + + if pronunciations1[0] == 'null' or pronunciations2 == 'null': + return 0.0 + else: + + pronunciations1 = pronunciations1.split(';') # separate by lan + pronunciations2 = pronunciations2.split(';') # separate by lan + + distance = 0.0 + count = 0 + for pron_lan1, pron_lan2 in zip(pronunciations1, pronunciations2): + if (pron_lan1 == 'null') or (pron_lan2 == 'null'): + pass + else: + distance_lan = 1e5 + for p1 in pron_lan1.split(','): + for p2 in pron_lan2.split(','): + distance_lan = min(distance_lan, edit_distance(p1, p2)) + distance += distance_lan + count += 1 + + return distance / count + + @staticmethod + def load_dict(fname): + data = {} + f = open(fname, 'r', encoding='utf-8') + for line in f: + char, freq = line.strip().split('\t') + assert char not in data + data[char] = freq + + return data + + def similarity(self, char1, char2, weights=(0.8, 0.2, 0.0), as_tree=False): + """ + this function returns weighted similarity. When used in FASPell, each weight can only be 0 or 1. + """ + + # assert char1 in self.char_dict + # assert char2 in self.char_dict + shape_w, sound_w, freq_w = weights + + if char1 in self.char_dict and char2 in self.char_dict: + + shape_sim = self.shape_similarity(char1, char2, as_tree=as_tree) + sound_sim = self.pronunciation_similarity(char1, char2) + freq_sim = 1.0 - self.char_dict[char2] / len(self.char_dict) + + return shape_sim * shape_w + sound_sim * sound_w + freq_sim * freq_w + else: + return 0.0 + + def shape_similarity(self, char1, char2, safe=True, as_tree=False): + """ + >>> c = CharFuncs('data/char_meta.txt') + >>> c.shape_similarity('牛', '午') + 0.8571428571428572 + >>> c.shape_similarity('田', '由') + 0.8888888888888888 + """ + assert char1 in self.data + assert char2 in self.data + + def safe_encode(decomp): + tree = '' + for c in string_to_tree(decomp): + if c not in self.safe: + tree += c + else: + tree += self.safe[c] + return tree + + def safe_encode_string(decomp): + tree = '' + for c in decomp: + if c not in self.safe: + tree += c + else: + tree += self.safe[c] + return tree + + decomps_1 = self.data[char1]["decompositions"] + decomps_2 = self.data[char2]["decompositions"] + + similarity = 0.0 + if as_tree: + for decomp1 in decomps_1: + for decomp2 in decomps_2: + if not safe: + ted = tree_edit_distance(string_to_tree(decomp1), string_to_tree(decomp2)) + else: + ted = tree_edit_distance(safe_encode(decomp1), safe_encode(decomp2)) + normalized_ted = 2 * ted / (len(decomp1) + len(decomp2) + ted) + similarity = max(similarity, 1 - normalized_ted) + else: + for decomp1 in decomps_1: + for decomp2 in decomps_2: + if not safe: + ed = edit_distance(decomp1, decomp2) + else: + ed = edit_distance(safe_encode_string(decomp1), safe_encode_string(decomp2)) + normalized_ed = ed / max(len(decomp1), len(decomp2)) + similarity = max(similarity, 1 - normalized_ed) + + return similarity + + def pronunciation_similarity(self, char1, char2): + """ + >>> c = CharFuncs('data/char_meta.txt') + >>> c.pronunciation_similarity('牛', '午') + 0.27999999999999997 + >>> c.pronunciation_similarity('由', '田') + 0.09 + + """ + assert char1 in self.data + assert char2 in self.data + pronunciations1 = self.data[char1]["pronunciation"] + pronunciations2 = self.data[char2]["pronunciation"] + + if pronunciations1[0] == 'null' or pronunciations2 == 'null': + return 0.0 + else: + + pronunciations1 = pronunciations1.split(';') # separate by lan + pronunciations2 = 
pronunciations2.split(';') # separate by lan + + similarity = 0.0 + count = 0 + for pron_lan1, pron_lan2 in zip(pronunciations1, pronunciations2): + if (pron_lan1 == 'null') or (pron_lan2 == 'null'): + pass + else: + similarity_lan = 0.0 + for p1 in pron_lan1.split(','): + for p2 in pron_lan2.split(','): + tmp_sim = 1 - edit_distance(p1, p2) / max(len(p1), len(p2)) + similarity_lan = max(similarity_lan, tmp_sim) + similarity += similarity_lan + count += 1 + + return similarity / count if count else 0 + + +def string_to_tree(string): + """ + This function converts ids string to a string that can be used as a tree input to APTED. + Any Error raised by this function implies that the input string is invalid. + >>> string_to_tree('⿱⿱⿰丿㇏⿰丿㇏⿱⿰丿㇏⿰丿㇏') # 炎 + '{⿱{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}}' + >>> string_to_tree('⿱⿰丿㇏⿱一⿱⿻一丨一') # 全 + '{⿱{⿰{丿}{㇏}}{⿱{一}{⿱{⿻{一}{丨}}{一}}}}' + >>> string_to_tree('⿱⿰丿㇏⿻⿱一⿱⿻一丨一丷') # 金 + '{⿱{⿰{丿}{㇏}}{⿻{⿱{一}{⿱{⿻{一}{丨}}{一}}}{丷}}}' + >>> string_to_tree('⿻⿻⿻一丨一⿴⿱⿰丨𠃌一一') # 車 + '{⿻{⿻{⿻{一}{丨}}{一}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}' + >>> string_to_tree('⿻⿻⿻一丨⿰丿㇏⿴⿱⿰丨𠃌一一') # 東 + '{⿻{⿻{⿻{一}{丨}}{⿰{丿}{㇏}}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}' + >>> string_to_tree('丿') # 丿 + '{丿}' + >>> string_to_tree('⿻') # ⿻ + '{⿻}' + """ + if string[0] in IDCS and len(string) != 1: + bracket_stack = [] + tree = [] + + def add_brackets(num): + if num == 2: + bracket_stack.extend(['}', '{', '}']) + else: + bracket_stack.extend(['}', '{', '}', '{', '}']) + tree.append('{') + + global_just_put = '{' + + for c in string: + tree.append(c) + if c in IDCS: + assert global_just_put != '}' + add_brackets(IDCS[c]) + global_just_put = '{' + else: + just_put = '' + while just_put != '{' and bracket_stack: + just_put = bracket_stack.pop(-1) + tree.append(just_put) + global_just_put = just_put + + res = ''.join(tree) + assert res[-1] == '}' + else: + assert len(string) == 1 or string == 'null' + res = string[0] + + return '{' + res + '}' + + +def pinyin_map(standard_pinyin): + """ + >>> pinyin_map('xuě') + 'xue3' + >>> pinyin_map('xue') + 'xue' + >>> pinyin_map('lǜ') + 'lü4' + >>> pinyin_map('fá') + 'fa2' + """ + tone = '' + pinyin = '' + + assert ' ' not in standard_pinyin + for c in standard_pinyin: + if c in PINYIN: + pinyin += PINYIN[c][0] + assert tone == '' + tone = str(PINYIN[c][1]) + else: + pinyin += c + pinyin += tone + return pinyin + + +def parse_args(): + usage = '\n1. You can compute character similarity by:\n' \ + 'python char_sim.py 午 牛 年 千\n' \ + '\n' \ + '2. 
You can use ted in computing character similarity by:\n' \ + 'python char_sim.py 午 牛 年 千 -t\n' \ + '\n' + parser = argparse.ArgumentParser( + description='A script to compute Chinese character (Kanji) similarity', usage=usage) + + parser.add_argument('multiargs', nargs='*', type=str, default=None, + help='Chinese characters in question') + parser.add_argument('--ted', '-t', action="store_true", default=False, + help='True=to use tree edit distence (TED)' + 'False=to use string edit distance') + + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + c = CharFuncs('data/char_meta.txt') + if not args.ted: + for i, c1 in enumerate(args.multiargs): + for c2 in args.multiargs[i:]: + if c1 != c2: + print(f'For character pair ({c1}, {c2}):') + print(f' v-sim = {c.shape_similarity(c1, c2)}') + print(f' p-sim = {c.pronunciation_similarity(c1, c2)}\n') + else: + for i, c1 in enumerate(args.multiargs): + for c2 in args.multiargs[i:]: + if c1 != c2: + print(f'For character pair ({c1}, {c2}):') + print(f' v-sim = {c.shape_similarity(c1, c2, as_tree=True)}') print(f' p-sim = {c.pronunciation_similarity(c1, c2)}\n') \ No newline at end of file diff --git a/opencompass/datasets/lawbench/utils/compare_m2_for_evaluation.py b/opencompass/datasets/lawbench/utils/compare_m2_for_evaluation.py index 41f6e818..2e7567e8 100644 --- a/opencompass/datasets/lawbench/utils/compare_m2_for_evaluation.py +++ b/opencompass/datasets/lawbench/utils/compare_m2_for_evaluation.py @@ -1,433 +1,433 @@ -import argparse -from collections import Counter - -def main(): - # Parse command line args - args = parse_args() - # Open hypothesis and reference m2 files and split into chunks - hyp_m2 = open(args.hyp).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.hyp).read().strip().split("\n\n") - ref_m2 = open(args.ref).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.ref).read().strip().split("\n\n") - # Make sure they have the same number of sentences - assert len(hyp_m2) == len(ref_m2), print(len(hyp_m2), len(ref_m2)) - - # Store global corpus level best counts here - best_dict = Counter({"tp":0, "fp":0, "fn":0}) - best_cats = {} - # Process each sentence - sents = zip(hyp_m2, ref_m2) - for sent_id, sent in enumerate(sents): - # Simplify the edits into lists of lists - # if "A1" in sent[0] or "A1" in sent[1] or sent_id in sent_id_cons: - # sent_id_cons.append(sent_id) - src = sent[0].split("\n")[0] - hyp_edits = simplify_edits(sent[0], args.max_answer_num) - ref_edits = simplify_edits(sent[1], args.max_answer_num) - # Process the edits for detection/correction based on args - hyp_dict = process_edits(hyp_edits, args) - ref_dict = process_edits(ref_edits, args) - if args.reference_num is None or len(ref_dict.keys()) == args.reference_num: - # Evaluate edits and get best TP, FP, FN hyp+ref combo. 
- count_dict, cat_dict = evaluate_edits(src, - hyp_dict, ref_dict, best_dict, sent_id, args) - # Merge these dicts with best_dict and best_cats - best_dict += Counter(count_dict) - best_cats = merge_dict(best_cats, cat_dict) - # Print results - print_results(best_dict, best_cats, args) - -# Parse command line args -def parse_args(): - parser = argparse.ArgumentParser( - description="Calculate F-scores for error detection and/or correction.\n" - "Flags let you evaluate at different levels of granularity.", - formatter_class=argparse.RawTextHelpFormatter, - usage="%(prog)s [options] -hyp HYP -ref REF") - parser.add_argument( - "-hyp", - help="A hypothesis M2 file.", - required=True) - parser.add_argument( - "-ref", - help="A reference M2 file.", - required=True) - parser.add_argument( - "--start", - type=int, - default=None - ) - parser.add_argument( - "--end", - type=int, - default=None - ) - parser.add_argument( - "--max_answer_num", - type=int, - default=None - ) - parser.add_argument( - "--reference_num", - type=int, - default=None - ) - parser.add_argument( - "-b", - "--beta", - help="Value of beta in F-score. (default: 0.5)", - default=0.5, - type=float) - parser.add_argument( - "-v", - "--verbose", - help="Print verbose output.", - action="store_true") - eval_type = parser.add_mutually_exclusive_group() - eval_type.add_argument( - "-dt", - help="Evaluate Detection in terms of Tokens.", - action="store_true") - eval_type.add_argument( - "-ds", - help="Evaluate Detection in terms of Spans.", - action="store_true") - eval_type.add_argument( - "-cs", - help="Evaluate Correction in terms of Spans. (default)", - action="store_true") - eval_type.add_argument( - "-cse", - help="Evaluate Correction in terms of Spans and Error types.", - action="store_true") - parser.add_argument( - "-single", - help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1", - action="store_true") - parser.add_argument( - "-multi", - help="Only evaluate multi token edits; i.e. 2+:n or n:2+", - action="store_true") - parser.add_argument( - "-multi_hyp_avg", - help="When get multiple hypotheses for a sentence, calculate their average F-scores for this sentence.", - action="store_true") # For IAA calculation - parser.add_argument( - "-multi_hyp_max", - help="When get multiple hypotheses for a sentence, calculate their F-scores and select the max one for this sentence.", - action="store_true") # For multiple hypotheses system evaluation - parser.add_argument( - "-filt", - help="Do not evaluate the specified error types.", - nargs="+", - default=[]) - parser.add_argument( - "-cat", - help="Show error category scores.\n" - "1: Only show operation tier scores; e.g. R.\n" - "2: Only show main tier scores; e.g. NOUN.\n" - "3: Show all category scores; e.g. R:NOUN.", - choices=[1, 2, 3], - type=int) - args = parser.parse_args() - return args - -# Input: An m2 format sentence with edits. -# Output: A list of lists. Each edit: [start, end, cat, cor, coder] -def simplify_edits(sent, max_answer_num): - out_edits = [] - # Get the edit lines from an m2 block. - edits = sent.split("\n") - # Loop through the edits - for edit in edits: - # Preprocessing - if edit.startswith("A "): - edit = edit[2:].split("|||") # Ignore "A " then split. 
- span = edit[0].split() - start = int(span[0]) - end = int(span[1]) - cat = edit[1] - cor = edit[2].replace(" ", "") - coder = int(edit[-1]) - out_edit = [start, end, cat, cor, coder] - out_edits.append(out_edit) - # return [edit for edit in out_edits if edit[-1] in [0,1]] - if max_answer_num is None: - return out_edits - elif max_answer_num == 1: - return [edit for edit in out_edits if edit[-1] == 0] - elif max_answer_num == 2: - return [edit for edit in out_edits if edit[-1] in [0, 1]] - elif max_answer_num == 3: - return [edit for edit in out_edits if edit[-1] in [0, 1, 2]] - -# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder] -# Input 2: Command line args -# Output: A dict; key is coder, value is edit dict. -def process_edits(edits, args): - coder_dict = {} - # Add an explicit noop edit if there are no edits. - if not edits: edits = [[-1, -1, "noop", "-NONE-", 0]] - # Loop through the edits - for edit in edits: - # Name the edit elements for clarity - start = edit[0] - end = edit[1] - cat = edit[2] - cor = edit[3] - coder = edit[4] - # Add the coder to the coder_dict if necessary - if coder not in coder_dict: coder_dict[coder] = {} - - # Optionally apply filters based on args - # 1. UNK type edits are only useful for detection, not correction. - if not args.dt and not args.ds and cat == "UNK": continue - # 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1 - if args.single and (end-start >= 2 or len(cor.split()) >= 2): continue - # 3. Only evaluate multi token edits; i.e. 2+:n or n:2+ - if args.multi and end-start < 2 and len(cor.split()) < 2: continue - # 4. If there is a filter, ignore the specified error types - if args.filt and cat in args.filt: continue - - # Token Based Detection - if args.dt: - # Preserve noop edits. - if start == -1: - if (start, start) in coder_dict[coder].keys(): - coder_dict[coder][(start, start)].append(cat) - else: - coder_dict[coder][(start, start)] = [cat] - # Insertions defined as affecting the token on the right - elif start == end and start >= 0: - if (start, start+1) in coder_dict[coder].keys(): - coder_dict[coder][(start, start+1)].append(cat) - else: - coder_dict[coder][(start, start+1)] = [cat] - # Edit spans are split for each token in the range. - else: - for tok_id in range(start, end): - if (tok_id, tok_id+1) in coder_dict[coder].keys(): - coder_dict[coder][(tok_id, tok_id+1)].append(cat) - else: - coder_dict[coder][(tok_id, tok_id+1)] = [cat] - - # Span Based Detection - elif args.ds: - if (start, end) in coder_dict[coder].keys(): - coder_dict[coder][(start, end)].append(cat) - else: - coder_dict[coder][(start, end)] = [cat] - - # Span Based Correction - else: - # With error type classification - if args.cse: - if (start, end, cat, cor) in coder_dict[coder].keys(): - coder_dict[coder][(start, end, cat, cor)].append(cat) - else: - coder_dict[coder][(start, end, cat, cor)] = [cat] - # Without error type classification - else: - if (start, end, cor) in coder_dict[coder].keys(): - coder_dict[coder][(start, end, cor)].append(cat) - else: - coder_dict[coder][(start, end, cor)] = [cat] - return coder_dict - -# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits. -# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits. -# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far. -# Input 4: Sentence ID (for verbose output only) -# Input 5: Command line args -# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence. 
-# Output 2: The corresponding error type dict for the above dict. -def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args): - # Store the best sentence level scores and hyp+ref combination IDs - # best_f is initialised as -1 cause 0 is a valid result. - best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0 - best_cat = {} - # skip not annotatable sentence - if len(ref_dict.keys()) == 1: - ref_id = list(ref_dict.keys())[0] - if len(ref_dict[ref_id].keys()) == 1: - cat = list(ref_dict[ref_id].values())[0][0] - if cat == "NA": - best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn} - return best_dict, best_cat - - # Compare each hyp and ref combination - for hyp_id in hyp_dict.keys(): - for ref_id in ref_dict.keys(): - # Get the local counts for the current combination. - tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id]) - # Compute the local sentence scores (for verbose output only) - loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta) - # Compute the global sentence scores - p, r, f = computeFScore( - tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta) - # Save the scores if they are better in terms of: - # 1. Higher F-score - # 2. Same F-score, higher TP - # 3. Same F-score and TP, lower FP - # 4. Same F-score, TP and FP, lower FN - if (f > best_f) or \ - (f == best_f and tp > best_tp) or \ - (f == best_f and tp == best_tp and fp < best_fp) or \ - (f == best_f and tp == best_tp and fp == best_fp and fn < best_fn): - best_tp, best_fp, best_fn = tp, fp, fn - best_f, best_hyp, best_ref = f, hyp_id, ref_id - best_cat = cat_dict - # Verbose output - if args.verbose: - # Prepare verbose output edits. - hyp_verb = list(sorted(hyp_dict[hyp_id].keys())) - ref_verb = list(sorted(ref_dict[ref_id].keys())) - # Ignore noop edits - if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = [] - if not ref_verb or ref_verb[0][0] == -1: ref_verb = [] - # Print verbose info - print('{:-^40}'.format("")) - print("SENTENCE "+str(sent_id)+src[1:]) - print('{:-^40}'.format("")) - print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id)) - print("HYPOTHESIS EDITS :", hyp_verb) - print("REFERENCE EDITS :", ref_verb) - print("Local TP/FP/FN :", str(tp), str(fp), str(fn)) - print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f)) - print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"])) - print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f)) - # Verbose output: display the best hyp+ref combination - if args.verbose: - print('{:-^40}'.format("")) - print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id)) - # Save the best TP, FP and FNs as a dict, and return this and the best_cat dict - best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn} - return best_dict, best_cat - -# Input 1: A dictionary of hypothesis edits for a single system. -# Input 2: A dictionary of reference edits for a single annotator. -# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator. -# Output 4: A dictionary of the error type counts. -def compareEdits(hyp_edits, ref_edits): - tp = 0 # True Positives - fp = 0 # False Positives - fn = 0 # False Negatives - cat_dict = {} # {cat: [tp, fp, fn], ...} - - for h_edit, h_cats in hyp_edits.items(): - # noop hyp edits cannot be TP or FP - if h_cats[0] == "noop": continue - # TRUE POSITIVES - if h_edit in ref_edits.keys(): - # On occasion, multiple tokens at same span. 
- for h_cat in ref_edits[h_edit]: # Use ref dict for TP - tp += 1 - # Each dict value [TP, FP, FN] - if h_cat in cat_dict.keys(): - cat_dict[h_cat][0] += 1 - else: - cat_dict[h_cat] = [1, 0, 0] - # FALSE POSITIVES - else: - # On occasion, multiple tokens at same span. - for h_cat in h_cats: - fp += 1 - # Each dict value [TP, FP, FN] - if h_cat in cat_dict.keys(): - cat_dict[h_cat][1] += 1 - else: - cat_dict[h_cat] = [0, 1, 0] - for r_edit, r_cats in ref_edits.items(): - # noop ref edits cannot be FN - if r_cats[0] == "noop": continue - # FALSE NEGATIVES - if r_edit not in hyp_edits.keys(): - # On occasion, multiple tokens at same span. - for r_cat in r_cats: - fn += 1 - # Each dict value [TP, FP, FN] - if r_cat in cat_dict.keys(): - cat_dict[r_cat][2] += 1 - else: - cat_dict[r_cat] = [0, 0, 1] - return tp, fp, fn, cat_dict - -# Input 1-3: True positives, false positives, false negatives -# Input 4: Value of beta in F-score. -# Output 1-3: Precision, Recall and F-score rounded to 4dp. -def computeFScore(tp, fp, fn, beta): - p = float(tp)/(tp+fp) if fp else 1.0 - r = float(tp)/(tp+fn) if fn else 1.0 - f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0 - return round(p, 4), round(r, 4), round(f, 4) - -# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN. -# Output: The dictionaries combined with cumulative TP, FP, FN. -def merge_dict(dict1, dict2): - for cat, stats in dict2.items(): - if cat in dict1.keys(): - dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)] - else: - dict1[cat] = stats - return dict1 - -# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn] -# Input 2: Integer value denoting level of error category granularity. -# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything. -# Output: A dictionary of category TP, FP and FN based on Input 2. -def processCategories(cat_dict, setting): - # Otherwise, do some processing. - proc_cat_dict = {} - for cat, cnt in cat_dict.items(): - if cat == "UNK": - proc_cat_dict[cat] = cnt - continue - # M, U, R or UNK combined only. - if setting == 1: - if cat[0] in proc_cat_dict.keys(): - proc_cat_dict[cat[0]] = [x+y for x, y in zip(proc_cat_dict[cat[0]], cnt)] - else: - proc_cat_dict[cat[0]] = cnt - # Everything without M, U or R. - elif setting == 2: - if cat[2:] in proc_cat_dict.keys(): - proc_cat_dict[cat[2:]] = [x+y for x, y in zip(proc_cat_dict[cat[2:]], cnt)] - else: - proc_cat_dict[cat[2:]] = cnt - # All error category combinations - else: - return cat_dict - return proc_cat_dict - -# Input 1: A dict of global best TP, FP and FNs -# Input 2: A dict of error types and counts for those TP, FP and FNs -# Input 3: Command line args -def print_results(best, best_cats, args): - # Prepare output title. - if args.dt: title = " Token-Based Detection " - elif args.ds: title = " Span-Based Detection " - elif args.cse: title = " Span-Based Correction + Classification " - else: title = " Span-Based Correction " - - # Category Scores - if args.cat: - best_cats = processCategories(best_cats, args.cat) - print("") - print('{:=^66}'.format(title)) - print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8), - "P".ljust(8), "R".ljust(8), "F"+str(args.beta)) - for cat, cnts in sorted(best_cats.items()): - cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta) - print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8), - str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f) - - # Print the overall results. 
- print("") - print('{:=^46}'.format(title)) - print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)])) - print("\t".join(map(str, [best["tp"], best["fp"], - best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta))))) - print('{:=^46}'.format("")) - print("") - -if __name__ == "__main__": - # Run the program - main() +import argparse +from collections import Counter + +def main(): + # Parse command line args + args = parse_args() + # Open hypothesis and reference m2 files and split into chunks + hyp_m2 = open(args.hyp).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.hyp).read().strip().split("\n\n") + ref_m2 = open(args.ref).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.ref).read().strip().split("\n\n") + # Make sure they have the same number of sentences + assert len(hyp_m2) == len(ref_m2), print(len(hyp_m2), len(ref_m2)) + + # Store global corpus level best counts here + best_dict = Counter({"tp":0, "fp":0, "fn":0}) + best_cats = {} + # Process each sentence + sents = zip(hyp_m2, ref_m2) + for sent_id, sent in enumerate(sents): + # Simplify the edits into lists of lists + # if "A1" in sent[0] or "A1" in sent[1] or sent_id in sent_id_cons: + # sent_id_cons.append(sent_id) + src = sent[0].split("\n")[0] + hyp_edits = simplify_edits(sent[0], args.max_answer_num) + ref_edits = simplify_edits(sent[1], args.max_answer_num) + # Process the edits for detection/correction based on args + hyp_dict = process_edits(hyp_edits, args) + ref_dict = process_edits(ref_edits, args) + if args.reference_num is None or len(ref_dict.keys()) == args.reference_num: + # Evaluate edits and get best TP, FP, FN hyp+ref combo. + count_dict, cat_dict = evaluate_edits(src, + hyp_dict, ref_dict, best_dict, sent_id, args) + # Merge these dicts with best_dict and best_cats + best_dict += Counter(count_dict) + best_cats = merge_dict(best_cats, cat_dict) + # Print results + print_results(best_dict, best_cats, args) + +# Parse command line args +def parse_args(): + parser = argparse.ArgumentParser( + description="Calculate F-scores for error detection and/or correction.\n" + "Flags let you evaluate at different levels of granularity.", + formatter_class=argparse.RawTextHelpFormatter, + usage="%(prog)s [options] -hyp HYP -ref REF") + parser.add_argument( + "-hyp", + help="A hypothesis M2 file.", + required=True) + parser.add_argument( + "-ref", + help="A reference M2 file.", + required=True) + parser.add_argument( + "--start", + type=int, + default=None + ) + parser.add_argument( + "--end", + type=int, + default=None + ) + parser.add_argument( + "--max_answer_num", + type=int, + default=None + ) + parser.add_argument( + "--reference_num", + type=int, + default=None + ) + parser.add_argument( + "-b", + "--beta", + help="Value of beta in F-score. (default: 0.5)", + default=0.5, + type=float) + parser.add_argument( + "-v", + "--verbose", + help="Print verbose output.", + action="store_true") + eval_type = parser.add_mutually_exclusive_group() + eval_type.add_argument( + "-dt", + help="Evaluate Detection in terms of Tokens.", + action="store_true") + eval_type.add_argument( + "-ds", + help="Evaluate Detection in terms of Spans.", + action="store_true") + eval_type.add_argument( + "-cs", + help="Evaluate Correction in terms of Spans. 
(default)", + action="store_true") + eval_type.add_argument( + "-cse", + help="Evaluate Correction in terms of Spans and Error types.", + action="store_true") + parser.add_argument( + "-single", + help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1", + action="store_true") + parser.add_argument( + "-multi", + help="Only evaluate multi token edits; i.e. 2+:n or n:2+", + action="store_true") + parser.add_argument( + "-multi_hyp_avg", + help="When get multiple hypotheses for a sentence, calculate their average F-scores for this sentence.", + action="store_true") # For IAA calculation + parser.add_argument( + "-multi_hyp_max", + help="When get multiple hypotheses for a sentence, calculate their F-scores and select the max one for this sentence.", + action="store_true") # For multiple hypotheses system evaluation + parser.add_argument( + "-filt", + help="Do not evaluate the specified error types.", + nargs="+", + default=[]) + parser.add_argument( + "-cat", + help="Show error category scores.\n" + "1: Only show operation tier scores; e.g. R.\n" + "2: Only show main tier scores; e.g. NOUN.\n" + "3: Show all category scores; e.g. R:NOUN.", + choices=[1, 2, 3], + type=int) + args = parser.parse_args() + return args + +# Input: An m2 format sentence with edits. +# Output: A list of lists. Each edit: [start, end, cat, cor, coder] +def simplify_edits(sent, max_answer_num): + out_edits = [] + # Get the edit lines from an m2 block. + edits = sent.split("\n") + # Loop through the edits + for edit in edits: + # Preprocessing + if edit.startswith("A "): + edit = edit[2:].split("|||") # Ignore "A " then split. + span = edit[0].split() + start = int(span[0]) + end = int(span[1]) + cat = edit[1] + cor = edit[2].replace(" ", "") + coder = int(edit[-1]) + out_edit = [start, end, cat, cor, coder] + out_edits.append(out_edit) + # return [edit for edit in out_edits if edit[-1] in [0,1]] + if max_answer_num is None: + return out_edits + elif max_answer_num == 1: + return [edit for edit in out_edits if edit[-1] == 0] + elif max_answer_num == 2: + return [edit for edit in out_edits if edit[-1] in [0, 1]] + elif max_answer_num == 3: + return [edit for edit in out_edits if edit[-1] in [0, 1, 2]] + +# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder] +# Input 2: Command line args +# Output: A dict; key is coder, value is edit dict. +def process_edits(edits, args): + coder_dict = {} + # Add an explicit noop edit if there are no edits. + if not edits: edits = [[-1, -1, "noop", "-NONE-", 0]] + # Loop through the edits + for edit in edits: + # Name the edit elements for clarity + start = edit[0] + end = edit[1] + cat = edit[2] + cor = edit[3] + coder = edit[4] + # Add the coder to the coder_dict if necessary + if coder not in coder_dict: coder_dict[coder] = {} + + # Optionally apply filters based on args + # 1. UNK type edits are only useful for detection, not correction. + if not args.dt and not args.ds and cat == "UNK": continue + # 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1 + if args.single and (end-start >= 2 or len(cor.split()) >= 2): continue + # 3. Only evaluate multi token edits; i.e. 2+:n or n:2+ + if args.multi and end-start < 2 and len(cor.split()) < 2: continue + # 4. If there is a filter, ignore the specified error types + if args.filt and cat in args.filt: continue + + # Token Based Detection + if args.dt: + # Preserve noop edits. 
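+            # (noop edits were added above as [-1, -1, "noop", "-NONE-", 0] when a
+            # coder proposed no changes; they are kept under the dummy span (-1, -1)
+            # instead of being expanded per token like real edits, e.g. an insertion
+            # span 3 3 maps to (3, 4) and a replacement span 3 5 to (3, 4) and (4, 5).)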
+ if start == -1: + if (start, start) in coder_dict[coder].keys(): + coder_dict[coder][(start, start)].append(cat) + else: + coder_dict[coder][(start, start)] = [cat] + # Insertions defined as affecting the token on the right + elif start == end and start >= 0: + if (start, start+1) in coder_dict[coder].keys(): + coder_dict[coder][(start, start+1)].append(cat) + else: + coder_dict[coder][(start, start+1)] = [cat] + # Edit spans are split for each token in the range. + else: + for tok_id in range(start, end): + if (tok_id, tok_id+1) in coder_dict[coder].keys(): + coder_dict[coder][(tok_id, tok_id+1)].append(cat) + else: + coder_dict[coder][(tok_id, tok_id+1)] = [cat] + + # Span Based Detection + elif args.ds: + if (start, end) in coder_dict[coder].keys(): + coder_dict[coder][(start, end)].append(cat) + else: + coder_dict[coder][(start, end)] = [cat] + + # Span Based Correction + else: + # With error type classification + if args.cse: + if (start, end, cat, cor) in coder_dict[coder].keys(): + coder_dict[coder][(start, end, cat, cor)].append(cat) + else: + coder_dict[coder][(start, end, cat, cor)] = [cat] + # Without error type classification + else: + if (start, end, cor) in coder_dict[coder].keys(): + coder_dict[coder][(start, end, cor)].append(cat) + else: + coder_dict[coder][(start, end, cor)] = [cat] + return coder_dict + +# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits. +# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits. +# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far. +# Input 4: Sentence ID (for verbose output only) +# Input 5: Command line args +# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence. +# Output 2: The corresponding error type dict for the above dict. +def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args): + # Store the best sentence level scores and hyp+ref combination IDs + # best_f is initialised as -1 cause 0 is a valid result. + best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0 + best_cat = {} + # skip not annotatable sentence + if len(ref_dict.keys()) == 1: + ref_id = list(ref_dict.keys())[0] + if len(ref_dict[ref_id].keys()) == 1: + cat = list(ref_dict[ref_id].values())[0][0] + if cat == "NA": + best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn} + return best_dict, best_cat + + # Compare each hyp and ref combination + for hyp_id in hyp_dict.keys(): + for ref_id in ref_dict.keys(): + # Get the local counts for the current combination. + tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id]) + # Compute the local sentence scores (for verbose output only) + loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta) + # Compute the global sentence scores + p, r, f = computeFScore( + tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta) + # Save the scores if they are better in terms of: + # 1. Higher F-score + # 2. Same F-score, higher TP + # 3. Same F-score and TP, lower FP + # 4. Same F-score, TP and FP, lower FN + if (f > best_f) or \ + (f == best_f and tp > best_tp) or \ + (f == best_f and tp == best_tp and fp < best_fp) or \ + (f == best_f and tp == best_tp and fp == best_fp and fn < best_fn): + best_tp, best_fp, best_fn = tp, fp, fn + best_f, best_hyp, best_ref = f, hyp_id, ref_id + best_cat = cat_dict + # Verbose output + if args.verbose: + # Prepare verbose output edits. 
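+                # (edit keys are sorted purely for readable printing; noop entries,
+                # whose span starts at -1, are blanked out just below)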
+ hyp_verb = list(sorted(hyp_dict[hyp_id].keys())) + ref_verb = list(sorted(ref_dict[ref_id].keys())) + # Ignore noop edits + if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = [] + if not ref_verb or ref_verb[0][0] == -1: ref_verb = [] + # Print verbose info + print('{:-^40}'.format("")) + print("SENTENCE "+str(sent_id)+src[1:]) + print('{:-^40}'.format("")) + print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id)) + print("HYPOTHESIS EDITS :", hyp_verb) + print("REFERENCE EDITS :", ref_verb) + print("Local TP/FP/FN :", str(tp), str(fp), str(fn)) + print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f)) + print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"])) + print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f)) + # Verbose output: display the best hyp+ref combination + if args.verbose: + print('{:-^40}'.format("")) + print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id)) + # Save the best TP, FP and FNs as a dict, and return this and the best_cat dict + best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn} + return best_dict, best_cat + +# Input 1: A dictionary of hypothesis edits for a single system. +# Input 2: A dictionary of reference edits for a single annotator. +# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator. +# Output 4: A dictionary of the error type counts. +def compareEdits(hyp_edits, ref_edits): + tp = 0 # True Positives + fp = 0 # False Positives + fn = 0 # False Negatives + cat_dict = {} # {cat: [tp, fp, fn], ...} + + for h_edit, h_cats in hyp_edits.items(): + # noop hyp edits cannot be TP or FP + if h_cats[0] == "noop": continue + # TRUE POSITIVES + if h_edit in ref_edits.keys(): + # On occasion, multiple tokens at same span. + for h_cat in ref_edits[h_edit]: # Use ref dict for TP + tp += 1 + # Each dict value [TP, FP, FN] + if h_cat in cat_dict.keys(): + cat_dict[h_cat][0] += 1 + else: + cat_dict[h_cat] = [1, 0, 0] + # FALSE POSITIVES + else: + # On occasion, multiple tokens at same span. + for h_cat in h_cats: + fp += 1 + # Each dict value [TP, FP, FN] + if h_cat in cat_dict.keys(): + cat_dict[h_cat][1] += 1 + else: + cat_dict[h_cat] = [0, 1, 0] + for r_edit, r_cats in ref_edits.items(): + # noop ref edits cannot be FN + if r_cats[0] == "noop": continue + # FALSE NEGATIVES + if r_edit not in hyp_edits.keys(): + # On occasion, multiple tokens at same span. + for r_cat in r_cats: + fn += 1 + # Each dict value [TP, FP, FN] + if r_cat in cat_dict.keys(): + cat_dict[r_cat][2] += 1 + else: + cat_dict[r_cat] = [0, 0, 1] + return tp, fp, fn, cat_dict + +# Input 1-3: True positives, false positives, false negatives +# Input 4: Value of beta in F-score. +# Output 1-3: Precision, Recall and F-score rounded to 4dp. +def computeFScore(tp, fp, fn, beta): + p = float(tp)/(tp+fp) if fp else 1.0 + r = float(tp)/(tp+fn) if fn else 1.0 + f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0 + return round(p, 4), round(r, 4), round(f, 4) + +# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN. +# Output: The dictionaries combined with cumulative TP, FP, FN. +def merge_dict(dict1, dict2): + for cat, stats in dict2.items(): + if cat in dict1.keys(): + dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)] + else: + dict1[cat] = stats + return dict1 + +# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn] +# Input 2: Integer value denoting level of error category granularity. 
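+#          (this is the value of the -cat command line option; see parse_args)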
+# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything. +# Output: A dictionary of category TP, FP and FN based on Input 2. +def processCategories(cat_dict, setting): + # Otherwise, do some processing. + proc_cat_dict = {} + for cat, cnt in cat_dict.items(): + if cat == "UNK": + proc_cat_dict[cat] = cnt + continue + # M, U, R or UNK combined only. + if setting == 1: + if cat[0] in proc_cat_dict.keys(): + proc_cat_dict[cat[0]] = [x+y for x, y in zip(proc_cat_dict[cat[0]], cnt)] + else: + proc_cat_dict[cat[0]] = cnt + # Everything without M, U or R. + elif setting == 2: + if cat[2:] in proc_cat_dict.keys(): + proc_cat_dict[cat[2:]] = [x+y for x, y in zip(proc_cat_dict[cat[2:]], cnt)] + else: + proc_cat_dict[cat[2:]] = cnt + # All error category combinations + else: + return cat_dict + return proc_cat_dict + +# Input 1: A dict of global best TP, FP and FNs +# Input 2: A dict of error types and counts for those TP, FP and FNs +# Input 3: Command line args +def print_results(best, best_cats, args): + # Prepare output title. + if args.dt: title = " Token-Based Detection " + elif args.ds: title = " Span-Based Detection " + elif args.cse: title = " Span-Based Correction + Classification " + else: title = " Span-Based Correction " + + # Category Scores + if args.cat: + best_cats = processCategories(best_cats, args.cat) + print("") + print('{:=^66}'.format(title)) + print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8), + "P".ljust(8), "R".ljust(8), "F"+str(args.beta)) + for cat, cnts in sorted(best_cats.items()): + cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta) + print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8), + str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f) + + # Print the overall results. 
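+    # One tab-separated row: TP, FP, FN, then precision, recall and F-beta computed
+    # from the corpus-level best counts via computeFScore.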
+ print("") + print('{:=^46}'.format(title)) + print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)])) + print("\t".join(map(str, [best["tp"], best["fp"], + best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta))))) + print('{:=^46}'.format("")) + print("") + +if __name__ == "__main__": + # Run the program + main() diff --git a/opencompass/datasets/lawbench/utils/function_utils.py b/opencompass/datasets/lawbench/utils/function_utils.py index a5c469a2..e4c6659d 100644 --- a/opencompass/datasets/lawbench/utils/function_utils.py +++ b/opencompass/datasets/lawbench/utils/function_utils.py @@ -1,49 +1,49 @@ -from rouge_chinese import Rouge -import jieba -from nltk.translate.gleu_score import corpus_gleu - -def compute_f1_two_sets(pred_set, gt_set): - precision = len(pred_set.intersection(gt_set)) / len(pred_set) if len(pred_set) > 0 else 0 - recall = len(pred_set.intersection(gt_set)) / len(gt_set) if len(gt_set) > 0 else 0 - f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0 - return f1 - -def multi_choice_judge(prediction, option_list, answer_token): - # a dict, key: letters in the option list, value: count of the letter in the prediction - count_dict, abstention, accuracy = {}, 0, 0 - for option in option_list: - option_count = prediction.count(option) - count_dict[option] = 1 if option_count > 0 else 0 # multiple occurrence of the same letter is counted as 1 - - if sum(count_dict.values()) == 0: - abstention = 1 - # if the answer token is the only predicted token, the prediction is correct - elif count_dict[answer_token] == 1 and sum(count_dict.values()) == 1: - accuracy = 1 - return {"score": accuracy, "abstention": abstention} - -""" -compute the rouge score. -hyps and refs are lists of hyposisis and reference strings -empty predictions are replaces with 无内容 -""" - - -def compute_rouge(hyps, refs): - assert(len(hyps) == len(refs)) - hyps = [' '.join(jieba.cut(h)) for h in hyps] - hyps = [h if h.strip() != "" else "无内容" for h in hyps] - refs = [' '.join(jieba.cut(r)) for r in refs] - return Rouge().get_scores(hyps, refs) - -""" -compute the gleu score. 
-hyps and refs are lists of hyposisis and reference strings -empty predictions are replaces with 无内容 -""" -def compute_gleu(hyps, refs): - assert(len(hyps) == len(refs)) - hyps = [' '.join(jieba.cut(h)) for h in hyps] - hyps = [h if h.strip() != "" else "无内容" for h in hyps] - refs = [[' '.join(jieba.cut(r))] for r in refs] - return corpus_gleu(refs, hyps) +from rouge_chinese import Rouge +import jieba +from nltk.translate.gleu_score import corpus_gleu + +def compute_f1_two_sets(pred_set, gt_set): + precision = len(pred_set.intersection(gt_set)) / len(pred_set) if len(pred_set) > 0 else 0 + recall = len(pred_set.intersection(gt_set)) / len(gt_set) if len(gt_set) > 0 else 0 + f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0 + return f1 + +def multi_choice_judge(prediction, option_list, answer_token): + # a dict, key: letters in the option list, value: count of the letter in the prediction + count_dict, abstention, accuracy = {}, 0, 0 + for option in option_list: + option_count = prediction.count(option) + count_dict[option] = 1 if option_count > 0 else 0 # multiple occurrence of the same letter is counted as 1 + + if sum(count_dict.values()) == 0: + abstention = 1 + # if the answer token is the only predicted token, the prediction is correct + elif count_dict[answer_token] == 1 and sum(count_dict.values()) == 1: + accuracy = 1 + return {"score": accuracy, "abstention": abstention} + +""" +compute the rouge score. +hyps and refs are lists of hyposisis and reference strings +empty predictions are replaces with 无内容 +""" + + +def compute_rouge(hyps, refs): + assert(len(hyps) == len(refs)) + hyps = [' '.join(jieba.cut(h)) for h in hyps] + hyps = [h if h.strip() != "" else "无内容" for h in hyps] + refs = [' '.join(jieba.cut(r)) for r in refs] + return Rouge().get_scores(hyps, refs) + +""" +compute the gleu score. 
+hyps and refs are lists of hyposisis and reference strings +empty predictions are replaces with 无内容 +""" +def compute_gleu(hyps, refs): + assert(len(hyps) == len(refs)) + hyps = [' '.join(jieba.cut(h)) for h in hyps] + hyps = [h if h.strip() != "" else "无内容" for h in hyps] + refs = [[' '.join(jieba.cut(r))] for r in refs] + return corpus_gleu(refs, hyps) diff --git a/opencompass/datasets/lawbench/utils/modules/alignment.py b/opencompass/datasets/lawbench/utils/modules/alignment.py index 5549ae2b..cee5124c 100644 --- a/opencompass/datasets/lawbench/utils/modules/alignment.py +++ b/opencompass/datasets/lawbench/utils/modules/alignment.py @@ -1,334 +1,334 @@ -import numpy as np -from typing import List, Tuple, Dict -from modules.tokenizer import Tokenizer -import os -from string import punctuation - -REAL_PATH = os.path.split(os.path.realpath(__file__))[0] -chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏" -english_punct = punctuation -punct = chinese_punct + english_punct - -def check_all_chinese(word): - """ - 判断一个单词是否全部由中文组成 - :param word: - :return: - """ - return all(['\u4e00' <= ch <= '\u9fff' for ch in word]) - -def read_cilin(): - """ - Cilin 詞林 is a thesaurus with semantic information - """ - # TODO -- fix this path - project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path - lines = open(os.path.join(project_dir, "data", "cilin.txt"), "r", encoding="gbk").read().strip().split("\n") - semantic_dict = {} - semantic_classes = {} - for line in lines: - code, *words = line.split(" ") - for word in words: - semantic_dict[word] = code - # make reverse dict - if code in semantic_classes: - semantic_classes[code] += words - else: - semantic_classes[code] = words - return semantic_dict, semantic_classes - - -def read_confusion(): - confusion_dict = {} - project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path - with open(os.path.join(project_dir, "data", "confusion_dict.txt"), "r", encoding="utf-8") as f: - for line in f: - li = line.rstrip('\n').split(" ") - confusion_dict[li[0]] = li[1:] - return confusion_dict - -class Alignment: - """ - 对齐错误句子和正确句子, - 使用编辑距离算法抽取编辑操作 - """ - - def __init__( - self, - semantic_dict: Dict, - confusion_dict: Dict, - granularity: str = "word", - ) -> None: - """ - 构造函数 - :param semantic_dict: 语义词典(大词林) - :param confusion_dict: 字符混淆集 - """ - self.insertion_cost = 1 - self.deletion_cost = 1 - self.semantic_dict = semantic_dict - self.confusion_dict = confusion_dict - # Because we use character level tokenization, this doesn't currently use POS - self._open_pos = {} # 如果是词级别,还可以利用词性是否相同来计算cost - self.granularity = granularity # word-level or character-level - self.align_seqs = [] - - def __call__(self, - src: List[Tuple], - tgt: List[Tuple], - verbose: bool = False): - cost_matrix, oper_matrix = self.align(src, tgt) - align_seq = self.get_cheapest_align_seq(oper_matrix) - - if verbose: - print("========== Seg. 
and POS: ==========") - print(src) - print(tgt) - print("========== Cost Matrix ==========") - print(cost_matrix) - print("========== Oper Matrix ==========") - print(oper_matrix) - print("========== Alignment ==========") - print(align_seq) - print("========== Results ==========") - for a in align_seq: - print(a[0], src[a[1]: a[2]], tgt[a[3]: a[4]]) - return align_seq - - def _get_semantic_class(self, word): - """ - NOTE: Based on the paper: - Improved-Edit-Distance Kernel for Chinese Relation Extraction - 获取每个词语的语义类别(基于大词林,有三个级别) - """ - if word in self.semantic_dict: - code = self.semantic_dict[word] - high, mid, low = code[0], code[1], code[2:4] - return high, mid, low - else: # unknown - return None - - @staticmethod - def _get_class_diff(a_class, b_class): - """ - d == 3 for equivalent semantics - d == 0 for completely different semantics - 根据大词林的信息,计算两个词的语义类别的差距 - """ - d = sum([a == b for a, b in zip(a_class, b_class)]) - return d - - def _get_semantic_cost(self, a, b): - """ - 计算基于语义信息的替换操作cost - :param a: 单词a的语义类别 - :param b: 单词b的语义类别 - :return: 替换编辑代价 - """ - a_class = self._get_semantic_class(a) - b_class = self._get_semantic_class(b) - # unknown class, default to 1 - if a_class is None or b_class is None: - return 4 - elif a_class == b_class: - return 0 - else: - return 2 * (3 - self._get_class_diff(a_class, b_class)) - - def _get_pos_cost(self, a_pos, b_pos): - """ - 计算基于词性信息的编辑距离cost - :param a_pos: 单词a的词性 - :param b_pos: 单词b的词性 - :return: 替换编辑代价 - """ - if a_pos == b_pos: - return 0 - elif a_pos in self._open_pos and b_pos in self._open_pos: - return 0.25 - else: - return 0.499 - - def _get_char_cost(self, a, b, pinyin_a, pinyin_b): - """ - NOTE: This is a replacement of ERRANTS lemma cost for Chinese - 计算基于字符相似度的编辑距离cost - """ - if not (check_all_chinese(a) and check_all_chinese(b)): - return 0.5 - if len(a) > len(b): - a, b = b, a - pinyin_a, pinyin_b = pinyin_b, pinyin_a - if a == b: - return 0 - else: - return self._get_spell_cost(a, b, pinyin_a, pinyin_b) - - def _get_spell_cost(self, a, b, pinyin_a, pinyin_b): - """ - 计算两个单词拼写相似度,分别由字形相似度和字音相似度组成 - :param a: 单词a - :param b: 单词b,且单词a的长度小于等于b - :param pinyin_a: 单词a的拼音 - :param pinyin_b: 单词b的拼音 - :return: 替换操作cost - """ - count = 0 - for i in range(len(a)): - for j in range(len(b)): - if a[i] == b[j] or (set(pinyin_a) & set(pinyin_b)) or (b[j] in self.confusion_dict.keys() and a[i] in self.confusion_dict[b[j]]) or (a[i] in self.confusion_dict.keys() and b[j] in self.confusion_dict[a[i]]): - count += 1 - break - return (len(a) - count) / (len(a) * 2) - - def get_sub_cost(self, a_seg, b_seg): - """ - Calculate the substitution cost between words a and b - 计算两个单词替换操作的编辑cost,最大为2,等于一次删除和一次添加 - """ - if a_seg[0] == b_seg[0]: - return 0 - - if self.granularity == "word": # 词级别可以额外利用词性信息 - semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0 - pos_cost = self._get_pos_cost(a_seg[1], b_seg[1]) - char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2]) - return semantic_cost + pos_cost + char_cost - else: # 字级别只能利用字义信息(从大词林中获取)和字面相似度信息 - semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0 - if a_seg[0] in punct and b_seg[0] in punct: - pos_cost = 0.0 - elif a_seg[0] not in punct and b_seg[0] not in punct: - pos_cost = 0.25 - else: - pos_cost = 0.499 - # pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5 - char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2]) - return semantic_cost + char_cost + 
pos_cost - - def align(self, - src: List[Tuple], - tgt: List[Tuple]): - """ - Based on ERRANT's alignment - 基于改进的动态规划算法,为原句子的每个字打上编辑标签,以便使它能够成功转换为目标句子。 - 编辑操作类别: - 1) M:Match,即KEEP,即当前字保持不变 - 2) D:Delete,删除,即当前字需要被删除 - 3) I:Insert,插入,即当前字需要被插入 - 4) T:Transposition,移位操作,即涉及到词序问题 - """ - cost_matrix = np.zeros((len(src) + 1, len(tgt) + 1)) # 编辑cost矩阵 - oper_matrix = np.full( - (len(src) + 1, len(tgt) + 1), "O", dtype=object - ) # 操作矩阵 - # Fill in the edges - for i in range(1, len(src) + 1): - cost_matrix[i][0] = cost_matrix[i - 1][0] + 1 - oper_matrix[i][0] = ["D"] - for j in range(1, len(tgt) + 1): - cost_matrix[0][j] = cost_matrix[0][j - 1] + 1 - oper_matrix[0][j] = ["I"] - - # Loop through the cost matrix - for i in range(len(src)): - for j in range(len(tgt)): - # Matches - if src[i][0] == tgt[j][0]: # 如果两个字相等,则匹配成功(Match),编辑距离为0 - cost_matrix[i + 1][j + 1] = cost_matrix[i][j] - oper_matrix[i + 1][j + 1] = ["M"] - # Non-matches - else: - del_cost = cost_matrix[i][j + 1] + self.deletion_cost # 由删除动作得到的总cost - ins_cost = cost_matrix[i + 1][j] + self.insertion_cost # 由插入动作得到的总cost - sub_cost = cost_matrix[i][j] + self.get_sub_cost( - src[i], tgt[j] - ) # 由替换动作得到的总cost - # Calculate transposition cost - # 计算移位操作的总cost - trans_cost = float("inf") - k = 1 - while ( - i - k >= 0 - and j - k >= 0 - and cost_matrix[i - k + 1][j - k + 1] - != cost_matrix[i - k][j - k] - ): - p1 = sorted([a[0] for a in src][i - k: i + 1]) - p2 = sorted([b[0] for b in tgt][j - k: j + 1]) - if p1 == p2: - trans_cost = cost_matrix[i - k][j - k] + k - break - k += 1 - - costs = [trans_cost, sub_cost, ins_cost, del_cost] - ind = costs.index(min(costs)) - cost_matrix[i + 1][j + 1] = costs[ind] - # ind = costs.index(costs[ind], ind+1) - for idx, cost in enumerate(costs): - if cost == costs[ind]: - if idx == 0: - if oper_matrix[i + 1][j + 1] == "O": - oper_matrix[i + 1][j + 1] = ["T" + str(k + 1)] - else: - oper_matrix[i + 1][j + 1].append("T" + str(k + 1)) - elif idx == 1: - if oper_matrix[i + 1][j + 1] == "O": - oper_matrix[i + 1][j + 1] = ["S"] - else: - oper_matrix[i + 1][j + 1].append("S") - elif idx == 2: - if oper_matrix[i + 1][j + 1] == "O": - oper_matrix[i + 1][j + 1] = ["I"] - else: - oper_matrix[i + 1][j + 1].append("I") - else: - if oper_matrix[i + 1][j + 1] == "O": - oper_matrix[i + 1][j + 1] = ["D"] - else: - oper_matrix[i + 1][j + 1].append("D") - return cost_matrix, oper_matrix - - def _dfs(self, i, j, align_seq_now, oper_matrix, strategy="all"): - """ - 深度优先遍历,获取最小编辑距离相同的所有序列 - """ - if i + j == 0: - self.align_seqs.append(align_seq_now) - else: - ops = oper_matrix[i][j] # 可以类比成搜索一棵树从根结点到叶子结点的所有路径 - if strategy != "all": ops = ops[:1] - for op in ops: - if op in {"M", "S"}: - self._dfs(i - 1, j - 1, align_seq_now + [(op, i - 1, i, j - 1, j)], oper_matrix, strategy) - elif op == "D": - self._dfs(i - 1, j, align_seq_now + [(op, i - 1, i, j, j)], oper_matrix, strategy) - elif op == "I": - self._dfs(i, j - 1, align_seq_now + [(op, i, i, j - 1, j)], oper_matrix, strategy) - else: - k = int(op[1:]) - self._dfs(i - k, j - k, align_seq_now + [(op, i - k, i, j - k, j)], oper_matrix, strategy) - - def get_cheapest_align_seq(self, oper_matrix): - """ - 回溯获得编辑距离最小的编辑序列 - """ - self.align_seqs = [] - i = oper_matrix.shape[0] - 1 - j = oper_matrix.shape[1] - 1 - if abs(i - j) > 10: - self._dfs(i, j , [], oper_matrix, "first") - else: - self._dfs(i, j , [], oper_matrix, "all") - final_align_seqs = [seq[::-1] for seq in self.align_seqs] - return final_align_seqs - - -if __name__ == "__main__": - tokenizer = 
Tokenizer("word") - semantic_dict, semantic_class = read_cilin() - confusion_dict = read_confusion() - alignment = Alignment(semantic_dict, confusion_dict) - sents = ["首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。".replace(" ", ""), "首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。".replace(" ", "")] - src, tgt = tokenizer(sents) +import numpy as np +from typing import List, Tuple, Dict +from modules.tokenizer import Tokenizer +import os +from string import punctuation + +REAL_PATH = os.path.split(os.path.realpath(__file__))[0] +chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏" +english_punct = punctuation +punct = chinese_punct + english_punct + +def check_all_chinese(word): + """ + 判断一个单词是否全部由中文组成 + :param word: + :return: + """ + return all(['\u4e00' <= ch <= '\u9fff' for ch in word]) + +def read_cilin(): + """ + Cilin 詞林 is a thesaurus with semantic information + """ + # TODO -- fix this path + project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path + lines = open(os.path.join(project_dir, "data", "cilin.txt"), "r", encoding="gbk").read().strip().split("\n") + semantic_dict = {} + semantic_classes = {} + for line in lines: + code, *words = line.split(" ") + for word in words: + semantic_dict[word] = code + # make reverse dict + if code in semantic_classes: + semantic_classes[code] += words + else: + semantic_classes[code] = words + return semantic_dict, semantic_classes + + +def read_confusion(): + confusion_dict = {} + project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path + with open(os.path.join(project_dir, "data", "confusion_dict.txt"), "r", encoding="utf-8") as f: + for line in f: + li = line.rstrip('\n').split(" ") + confusion_dict[li[0]] = li[1:] + return confusion_dict + +class Alignment: + """ + 对齐错误句子和正确句子, + 使用编辑距离算法抽取编辑操作 + """ + + def __init__( + self, + semantic_dict: Dict, + confusion_dict: Dict, + granularity: str = "word", + ) -> None: + """ + 构造函数 + :param semantic_dict: 语义词典(大词林) + :param confusion_dict: 字符混淆集 + """ + self.insertion_cost = 1 + self.deletion_cost = 1 + self.semantic_dict = semantic_dict + self.confusion_dict = confusion_dict + # Because we use character level tokenization, this doesn't currently use POS + self._open_pos = {} # 如果是词级别,还可以利用词性是否相同来计算cost + self.granularity = granularity # word-level or character-level + self.align_seqs = [] + + def __call__(self, + src: List[Tuple], + tgt: List[Tuple], + verbose: bool = False): + cost_matrix, oper_matrix = self.align(src, tgt) + align_seq = self.get_cheapest_align_seq(oper_matrix) + + if verbose: + print("========== Seg. 
and POS: ==========") + print(src) + print(tgt) + print("========== Cost Matrix ==========") + print(cost_matrix) + print("========== Oper Matrix ==========") + print(oper_matrix) + print("========== Alignment ==========") + print(align_seq) + print("========== Results ==========") + for a in align_seq: + print(a[0], src[a[1]: a[2]], tgt[a[3]: a[4]]) + return align_seq + + def _get_semantic_class(self, word): + """ + NOTE: Based on the paper: + Improved-Edit-Distance Kernel for Chinese Relation Extraction + 获取每个词语的语义类别(基于大词林,有三个级别) + """ + if word in self.semantic_dict: + code = self.semantic_dict[word] + high, mid, low = code[0], code[1], code[2:4] + return high, mid, low + else: # unknown + return None + + @staticmethod + def _get_class_diff(a_class, b_class): + """ + d == 3 for equivalent semantics + d == 0 for completely different semantics + 根据大词林的信息,计算两个词的语义类别的差距 + """ + d = sum([a == b for a, b in zip(a_class, b_class)]) + return d + + def _get_semantic_cost(self, a, b): + """ + 计算基于语义信息的替换操作cost + :param a: 单词a的语义类别 + :param b: 单词b的语义类别 + :return: 替换编辑代价 + """ + a_class = self._get_semantic_class(a) + b_class = self._get_semantic_class(b) + # unknown class, default to 1 + if a_class is None or b_class is None: + return 4 + elif a_class == b_class: + return 0 + else: + return 2 * (3 - self._get_class_diff(a_class, b_class)) + + def _get_pos_cost(self, a_pos, b_pos): + """ + 计算基于词性信息的编辑距离cost + :param a_pos: 单词a的词性 + :param b_pos: 单词b的词性 + :return: 替换编辑代价 + """ + if a_pos == b_pos: + return 0 + elif a_pos in self._open_pos and b_pos in self._open_pos: + return 0.25 + else: + return 0.499 + + def _get_char_cost(self, a, b, pinyin_a, pinyin_b): + """ + NOTE: This is a replacement of ERRANTS lemma cost for Chinese + 计算基于字符相似度的编辑距离cost + """ + if not (check_all_chinese(a) and check_all_chinese(b)): + return 0.5 + if len(a) > len(b): + a, b = b, a + pinyin_a, pinyin_b = pinyin_b, pinyin_a + if a == b: + return 0 + else: + return self._get_spell_cost(a, b, pinyin_a, pinyin_b) + + def _get_spell_cost(self, a, b, pinyin_a, pinyin_b): + """ + 计算两个单词拼写相似度,分别由字形相似度和字音相似度组成 + :param a: 单词a + :param b: 单词b,且单词a的长度小于等于b + :param pinyin_a: 单词a的拼音 + :param pinyin_b: 单词b的拼音 + :return: 替换操作cost + """ + count = 0 + for i in range(len(a)): + for j in range(len(b)): + if a[i] == b[j] or (set(pinyin_a) & set(pinyin_b)) or (b[j] in self.confusion_dict.keys() and a[i] in self.confusion_dict[b[j]]) or (a[i] in self.confusion_dict.keys() and b[j] in self.confusion_dict[a[i]]): + count += 1 + break + return (len(a) - count) / (len(a) * 2) + + def get_sub_cost(self, a_seg, b_seg): + """ + Calculate the substitution cost between words a and b + 计算两个单词替换操作的编辑cost,最大为2,等于一次删除和一次添加 + """ + if a_seg[0] == b_seg[0]: + return 0 + + if self.granularity == "word": # 词级别可以额外利用词性信息 + semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0 + pos_cost = self._get_pos_cost(a_seg[1], b_seg[1]) + char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2]) + return semantic_cost + pos_cost + char_cost + else: # 字级别只能利用字义信息(从大词林中获取)和字面相似度信息 + semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0 + if a_seg[0] in punct and b_seg[0] in punct: + pos_cost = 0.0 + elif a_seg[0] not in punct and b_seg[0] not in punct: + pos_cost = 0.25 + else: + pos_cost = 0.499 + # pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5 + char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2]) + return semantic_cost + char_cost + 
pos_cost + + def align(self, + src: List[Tuple], + tgt: List[Tuple]): + """ + Based on ERRANT's alignment + 基于改进的动态规划算法,为原句子的每个字打上编辑标签,以便使它能够成功转换为目标句子。 + 编辑操作类别: + 1) M:Match,即KEEP,即当前字保持不变 + 2) D:Delete,删除,即当前字需要被删除 + 3) I:Insert,插入,即当前字需要被插入 + 4) T:Transposition,移位操作,即涉及到词序问题 + """ + cost_matrix = np.zeros((len(src) + 1, len(tgt) + 1)) # 编辑cost矩阵 + oper_matrix = np.full( + (len(src) + 1, len(tgt) + 1), "O", dtype=object + ) # 操作矩阵 + # Fill in the edges + for i in range(1, len(src) + 1): + cost_matrix[i][0] = cost_matrix[i - 1][0] + 1 + oper_matrix[i][0] = ["D"] + for j in range(1, len(tgt) + 1): + cost_matrix[0][j] = cost_matrix[0][j - 1] + 1 + oper_matrix[0][j] = ["I"] + + # Loop through the cost matrix + for i in range(len(src)): + for j in range(len(tgt)): + # Matches + if src[i][0] == tgt[j][0]: # 如果两个字相等,则匹配成功(Match),编辑距离为0 + cost_matrix[i + 1][j + 1] = cost_matrix[i][j] + oper_matrix[i + 1][j + 1] = ["M"] + # Non-matches + else: + del_cost = cost_matrix[i][j + 1] + self.deletion_cost # 由删除动作得到的总cost + ins_cost = cost_matrix[i + 1][j] + self.insertion_cost # 由插入动作得到的总cost + sub_cost = cost_matrix[i][j] + self.get_sub_cost( + src[i], tgt[j] + ) # 由替换动作得到的总cost + # Calculate transposition cost + # 计算移位操作的总cost + trans_cost = float("inf") + k = 1 + while ( + i - k >= 0 + and j - k >= 0 + and cost_matrix[i - k + 1][j - k + 1] + != cost_matrix[i - k][j - k] + ): + p1 = sorted([a[0] for a in src][i - k: i + 1]) + p2 = sorted([b[0] for b in tgt][j - k: j + 1]) + if p1 == p2: + trans_cost = cost_matrix[i - k][j - k] + k + break + k += 1 + + costs = [trans_cost, sub_cost, ins_cost, del_cost] + ind = costs.index(min(costs)) + cost_matrix[i + 1][j + 1] = costs[ind] + # ind = costs.index(costs[ind], ind+1) + for idx, cost in enumerate(costs): + if cost == costs[ind]: + if idx == 0: + if oper_matrix[i + 1][j + 1] == "O": + oper_matrix[i + 1][j + 1] = ["T" + str(k + 1)] + else: + oper_matrix[i + 1][j + 1].append("T" + str(k + 1)) + elif idx == 1: + if oper_matrix[i + 1][j + 1] == "O": + oper_matrix[i + 1][j + 1] = ["S"] + else: + oper_matrix[i + 1][j + 1].append("S") + elif idx == 2: + if oper_matrix[i + 1][j + 1] == "O": + oper_matrix[i + 1][j + 1] = ["I"] + else: + oper_matrix[i + 1][j + 1].append("I") + else: + if oper_matrix[i + 1][j + 1] == "O": + oper_matrix[i + 1][j + 1] = ["D"] + else: + oper_matrix[i + 1][j + 1].append("D") + return cost_matrix, oper_matrix + + def _dfs(self, i, j, align_seq_now, oper_matrix, strategy="all"): + """ + 深度优先遍历,获取最小编辑距离相同的所有序列 + """ + if i + j == 0: + self.align_seqs.append(align_seq_now) + else: + ops = oper_matrix[i][j] # 可以类比成搜索一棵树从根结点到叶子结点的所有路径 + if strategy != "all": ops = ops[:1] + for op in ops: + if op in {"M", "S"}: + self._dfs(i - 1, j - 1, align_seq_now + [(op, i - 1, i, j - 1, j)], oper_matrix, strategy) + elif op == "D": + self._dfs(i - 1, j, align_seq_now + [(op, i - 1, i, j, j)], oper_matrix, strategy) + elif op == "I": + self._dfs(i, j - 1, align_seq_now + [(op, i, i, j - 1, j)], oper_matrix, strategy) + else: + k = int(op[1:]) + self._dfs(i - k, j - k, align_seq_now + [(op, i - k, i, j - k, j)], oper_matrix, strategy) + + def get_cheapest_align_seq(self, oper_matrix): + """ + 回溯获得编辑距离最小的编辑序列 + """ + self.align_seqs = [] + i = oper_matrix.shape[0] - 1 + j = oper_matrix.shape[1] - 1 + if abs(i - j) > 10: + self._dfs(i, j , [], oper_matrix, "first") + else: + self._dfs(i, j , [], oper_matrix, "all") + final_align_seqs = [seq[::-1] for seq in self.align_seqs] + return final_align_seqs + + +if __name__ == "__main__": + tokenizer = 
Tokenizer("word") + semantic_dict, semantic_class = read_cilin() + confusion_dict = read_confusion() + alignment = Alignment(semantic_dict, confusion_dict) + sents = ["首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。".replace(" ", ""), "首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。".replace(" ", "")] + src, tgt = tokenizer(sents) alignment(src, tgt, verbose=True) \ No newline at end of file diff --git a/opencompass/datasets/lawbench/utils/modules/annotator.py b/opencompass/datasets/lawbench/utils/modules/annotator.py index 4bb40597..d7b00d06 100644 --- a/opencompass/datasets/lawbench/utils/modules/annotator.py +++ b/opencompass/datasets/lawbench/utils/modules/annotator.py @@ -1,76 +1,76 @@ -from typing import List, Tuple -from modules.alignment import read_cilin, read_confusion, Alignment -from modules.merger import Merger -from modules.classifier import Classifier - -class Annotator: - def __init__(self, - align: Alignment, - merger: Merger, - classifier: Classifier, - granularity: str = "word", - strategy: str = "first"): - self.align = align - self.merger = merger - self.classifier = classifier - self.granularity = granularity - self.strategy = strategy - - @classmethod - def create_default(cls, granularity: str = "word", strategy: str = "first"): - """ - Default parameters used in the paper - """ - semantic_dict, semantic_class = read_cilin() - confusion_dict = read_confusion() - align = Alignment(semantic_dict, confusion_dict, granularity) - merger = Merger(granularity) - classifier = Classifier(granularity) - return cls(align, merger, classifier, granularity, strategy) - - def __call__(self, - src: List[Tuple], - tgt: List[Tuple], - annotator_id: int = 0, - verbose: bool = False): - """ - Align sentences and annotate them with error type information - """ - src_tokens = [x[0] for x in src] - tgt_tokens = [x[0] for x in tgt] - src_str = "".join(src_tokens) - tgt_str = "".join(tgt_tokens) - # convert to text form - annotations_out = ["S " + " ".join(src_tokens) + "\n"] - if tgt_str == "没有错误" or src_str == tgt_str: # Error Free Case - annotations_out.append(f"T{annotator_id} 没有错误\n") - cors = [tgt_str] - op, toks, inds = "noop", "-NONE-", (-1, -1) - a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n" - annotations_out.append(a_str) - elif tgt_str == "无法标注": # Not Annotatable Case - annotations_out.append(f"T{annotator_id} 无法标注\n") - cors = [tgt_str] - op, toks, inds = "NA", "-NONE-", (-1, -1) - a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n" - annotations_out.append(a_str) - else: # Other - align_objs = self.align(src, tgt) - edit_objs = [] - align_idx = 0 - if self.strategy == "first": - align_objs = align_objs[:1] - for align_obj in align_objs: - edits = self.merger(align_obj, src, tgt, verbose) - if edits not in edit_objs: - edit_objs.append(edits) - annotations_out.append(f"T{annotator_id}-A{align_idx} " + " ".join(tgt_tokens) + "\n") - align_idx += 1 - cors = self.classifier(src, tgt, edits, verbose) - # annotations_out = [] - for cor in cors: - op, toks, inds = cor.op, cor.toks, cor.inds - a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n" - annotations_out.append(a_str) - annotations_out.append("\n") - return annotations_out, cors +from typing import List, Tuple +from modules.alignment import read_cilin, 
read_confusion, Alignment +from modules.merger import Merger +from modules.classifier import Classifier + +class Annotator: + def __init__(self, + align: Alignment, + merger: Merger, + classifier: Classifier, + granularity: str = "word", + strategy: str = "first"): + self.align = align + self.merger = merger + self.classifier = classifier + self.granularity = granularity + self.strategy = strategy + + @classmethod + def create_default(cls, granularity: str = "word", strategy: str = "first"): + """ + Default parameters used in the paper + """ + semantic_dict, semantic_class = read_cilin() + confusion_dict = read_confusion() + align = Alignment(semantic_dict, confusion_dict, granularity) + merger = Merger(granularity) + classifier = Classifier(granularity) + return cls(align, merger, classifier, granularity, strategy) + + def __call__(self, + src: List[Tuple], + tgt: List[Tuple], + annotator_id: int = 0, + verbose: bool = False): + """ + Align sentences and annotate them with error type information + """ + src_tokens = [x[0] for x in src] + tgt_tokens = [x[0] for x in tgt] + src_str = "".join(src_tokens) + tgt_str = "".join(tgt_tokens) + # convert to text form + annotations_out = ["S " + " ".join(src_tokens) + "\n"] + if tgt_str == "没有错误" or src_str == tgt_str: # Error Free Case + annotations_out.append(f"T{annotator_id} 没有错误\n") + cors = [tgt_str] + op, toks, inds = "noop", "-NONE-", (-1, -1) + a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n" + annotations_out.append(a_str) + elif tgt_str == "无法标注": # Not Annotatable Case + annotations_out.append(f"T{annotator_id} 无法标注\n") + cors = [tgt_str] + op, toks, inds = "NA", "-NONE-", (-1, -1) + a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n" + annotations_out.append(a_str) + else: # Other + align_objs = self.align(src, tgt) + edit_objs = [] + align_idx = 0 + if self.strategy == "first": + align_objs = align_objs[:1] + for align_obj in align_objs: + edits = self.merger(align_obj, src, tgt, verbose) + if edits not in edit_objs: + edit_objs.append(edits) + annotations_out.append(f"T{annotator_id}-A{align_idx} " + " ".join(tgt_tokens) + "\n") + align_idx += 1 + cors = self.classifier(src, tgt, edits, verbose) + # annotations_out = [] + for cor in cors: + op, toks, inds = cor.op, cor.toks, cor.inds + a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n" + annotations_out.append(a_str) + annotations_out.append("\n") + return annotations_out, cors diff --git a/opencompass/datasets/lawbench/utils/modules/classifier.py b/opencompass/datasets/lawbench/utils/modules/classifier.py index f62d90bb..66c225d4 100644 --- a/opencompass/datasets/lawbench/utils/modules/classifier.py +++ b/opencompass/datasets/lawbench/utils/modules/classifier.py @@ -1,151 +1,151 @@ -from char_smi import CharFuncs -from collections import namedtuple -from pypinyin import pinyin, Style -import os -Correction = namedtuple( - "Correction", - [ - "op", - "toks", - "inds", - ], -) -file_path = os.path.dirname(os.path.abspath(__file__)) -char_smi = CharFuncs(os.path.join(file_path.replace("modules", ""), 'data/char_meta.txt')) - -def check_spell_error(src_span: str, - tgt_span: str, - threshold: float = 0.8) -> bool: - if len(src_span) != len(tgt_span): - return False - src_chars = [ch for ch in src_span] - tgt_chars = [ch for ch in tgt_span] - if sorted(src_chars) == sorted(tgt_chars): # 词内部字符异位 - return True - for src_char, tgt_char in zip(src_chars, tgt_chars): - if src_char != 
tgt_char: - if src_char not in char_smi.data or tgt_char not in char_smi.data: - return False - v_sim = char_smi.shape_similarity(src_char, tgt_char) - p_sim = char_smi.pronunciation_similarity(src_char, tgt_char) - if v_sim + p_sim < threshold and not ( - set(pinyin(src_char, style=Style.NORMAL, heteronym=True)[0]) & set(pinyin(tgt_char, style=Style.NORMAL, heteronym=True)[0])): - return False - return True - -class Classifier: - """ - 错误类型分类器 - """ - def __init__(self, - granularity: str = "word"): - - self.granularity = granularity - - @staticmethod - def get_pos_type(pos): - if pos in {"n", "nd"}: - return "NOUN" - if pos in {"nh", "ni", "nl", "ns", "nt", "nz"}: - return "NOUN-NE" - if pos in {"v"}: - return "VERB" - if pos in {"a", "b"}: - return "ADJ" - if pos in {"c"}: - return "CONJ" - if pos in {"r"}: - return "PRON" - if pos in {"d"}: - return "ADV" - if pos in {"u"}: - return "AUX" - # if pos in {"k"}: # TODO 后缀词比例太少,暂且分入其它 - # return "SUFFIX" - if pos in {"m"}: - return "NUM" - if pos in {"p"}: - return "PREP" - if pos in {"q"}: - return "QUAN" - if pos in {"wp"}: - return "PUNCT" - return "OTHER" - - def __call__(self, - src, - tgt, - edits, - verbose: bool = False): - """ - 为编辑操作划分错误类型 - :param src: 错误句子信息 - :param tgt: 正确句子信息 - :param edits: 编辑操作 - :param verbose: 是否打印信息 - :return: 划分完错误类型后的编辑操作 - """ - results = [] - src_tokens = [x[0] for x in src] - tgt_tokens = [x[0] for x in tgt] - for edit in edits: - error_type = edit[0] - src_span = " ".join(src_tokens[edit[1]: edit[2]]) - tgt_span = " ".join(tgt_tokens[edit[3]: edit[4]]) - # print(tgt_span) - cor = None - if error_type[0] == "T": - cor = Correction("W", tgt_span, (edit[1], edit[2])) - elif error_type[0] == "D": - if self.granularity == "word": # 词级别可以细分错误类型 - if edit[2] - edit[1] > 1: # 词组冗余暂时分为OTHER - cor = Correction("R:OTHER", "-NONE-", (edit[1], edit[2])) - else: - pos = self.get_pos_type(src[edit[1]][1]) - pos = "NOUN" if pos == "NOUN-NE" else pos - pos = "MC" if tgt_span == "[缺失成分]" else pos - cor = Correction("R:{:s}".format(pos), "-NONE-", (edit[1], edit[2])) - else: # 字级别可以只需要根据操作划分类型即可 - cor = Correction("R", "-NONE-", (edit[1], edit[2])) - elif error_type[0] == "I": - if self.granularity == "word": # 词级别可以细分错误类型 - if edit[4] - edit[3] > 1: # 词组丢失暂时分为OTHER - cor = Correction("M:OTHER", tgt_span, (edit[1], edit[2])) - else: - pos = self.get_pos_type(tgt[edit[3]][1]) - pos = "NOUN" if pos == "NOUN-NE" else pos - pos = "MC" if tgt_span == "[缺失成分]" else pos - cor = Correction("M:{:s}".format(pos), tgt_span, (edit[1], edit[2])) - else: # 字级别可以只需要根据操作划分类型即可 - cor = Correction("M", tgt_span, (edit[1], edit[2])) - elif error_type[0] == "S": - if self.granularity == "word": # 词级别可以细分错误类型 - if check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")): - cor = Correction("S:SPELL", tgt_span, (edit[1], edit[2])) - # Todo 暂且不单独区分命名实体拼写错误 - # if edit[4] - edit[3] > 1: - # cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2])) - # else: - # pos = self.get_pos_type(tgt[edit[3]][1]) - # if pos == "NOUN-NE": # 命名实体拼写有误 - # cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2])) - # else: # 普通词语拼写有误 - # cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2])) - else: - if edit[4] - edit[3] > 1: # 词组被替换暂时分为OTHER - cor = Correction("S:OTHER", tgt_span, (edit[1], edit[2])) - else: - pos = self.get_pos_type(tgt[edit[3]][1]) - pos = "NOUN" if pos == "NOUN-NE" else pos - pos = "MC" if tgt_span == "[缺失成分]" else pos - cor = Correction("S:{:s}".format(pos), tgt_span, (edit[1], edit[2])) - else: # 
字级别可以只需要根据操作划分类型即可 - cor = Correction("S", tgt_span, (edit[1], edit[2])) - results.append(cor) - if verbose: - print("========== Corrections ==========") - for cor in results: - print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks)) - return results - -# print(pinyin("朝", style=Style.NORMAL)) +from char_smi import CharFuncs +from collections import namedtuple +from pypinyin import pinyin, Style +import os +Correction = namedtuple( + "Correction", + [ + "op", + "toks", + "inds", + ], +) +file_path = os.path.dirname(os.path.abspath(__file__)) +char_smi = CharFuncs(os.path.join(file_path.replace("modules", ""), 'data/char_meta.txt')) + +def check_spell_error(src_span: str, + tgt_span: str, + threshold: float = 0.8) -> bool: + if len(src_span) != len(tgt_span): + return False + src_chars = [ch for ch in src_span] + tgt_chars = [ch for ch in tgt_span] + if sorted(src_chars) == sorted(tgt_chars): # 词内部字符异位 + return True + for src_char, tgt_char in zip(src_chars, tgt_chars): + if src_char != tgt_char: + if src_char not in char_smi.data or tgt_char not in char_smi.data: + return False + v_sim = char_smi.shape_similarity(src_char, tgt_char) + p_sim = char_smi.pronunciation_similarity(src_char, tgt_char) + if v_sim + p_sim < threshold and not ( + set(pinyin(src_char, style=Style.NORMAL, heteronym=True)[0]) & set(pinyin(tgt_char, style=Style.NORMAL, heteronym=True)[0])): + return False + return True + +class Classifier: + """ + 错误类型分类器 + """ + def __init__(self, + granularity: str = "word"): + + self.granularity = granularity + + @staticmethod + def get_pos_type(pos): + if pos in {"n", "nd"}: + return "NOUN" + if pos in {"nh", "ni", "nl", "ns", "nt", "nz"}: + return "NOUN-NE" + if pos in {"v"}: + return "VERB" + if pos in {"a", "b"}: + return "ADJ" + if pos in {"c"}: + return "CONJ" + if pos in {"r"}: + return "PRON" + if pos in {"d"}: + return "ADV" + if pos in {"u"}: + return "AUX" + # if pos in {"k"}: # TODO 后缀词比例太少,暂且分入其它 + # return "SUFFIX" + if pos in {"m"}: + return "NUM" + if pos in {"p"}: + return "PREP" + if pos in {"q"}: + return "QUAN" + if pos in {"wp"}: + return "PUNCT" + return "OTHER" + + def __call__(self, + src, + tgt, + edits, + verbose: bool = False): + """ + 为编辑操作划分错误类型 + :param src: 错误句子信息 + :param tgt: 正确句子信息 + :param edits: 编辑操作 + :param verbose: 是否打印信息 + :return: 划分完错误类型后的编辑操作 + """ + results = [] + src_tokens = [x[0] for x in src] + tgt_tokens = [x[0] for x in tgt] + for edit in edits: + error_type = edit[0] + src_span = " ".join(src_tokens[edit[1]: edit[2]]) + tgt_span = " ".join(tgt_tokens[edit[3]: edit[4]]) + # print(tgt_span) + cor = None + if error_type[0] == "T": + cor = Correction("W", tgt_span, (edit[1], edit[2])) + elif error_type[0] == "D": + if self.granularity == "word": # 词级别可以细分错误类型 + if edit[2] - edit[1] > 1: # 词组冗余暂时分为OTHER + cor = Correction("R:OTHER", "-NONE-", (edit[1], edit[2])) + else: + pos = self.get_pos_type(src[edit[1]][1]) + pos = "NOUN" if pos == "NOUN-NE" else pos + pos = "MC" if tgt_span == "[缺失成分]" else pos + cor = Correction("R:{:s}".format(pos), "-NONE-", (edit[1], edit[2])) + else: # 字级别可以只需要根据操作划分类型即可 + cor = Correction("R", "-NONE-", (edit[1], edit[2])) + elif error_type[0] == "I": + if self.granularity == "word": # 词级别可以细分错误类型 + if edit[4] - edit[3] > 1: # 词组丢失暂时分为OTHER + cor = Correction("M:OTHER", tgt_span, (edit[1], edit[2])) + else: + pos = self.get_pos_type(tgt[edit[3]][1]) + pos = "NOUN" if pos == "NOUN-NE" else pos + pos = "MC" if tgt_span == "[缺失成分]" else pos + cor = 
Correction("M:{:s}".format(pos), tgt_span, (edit[1], edit[2])) + else: # 字级别可以只需要根据操作划分类型即可 + cor = Correction("M", tgt_span, (edit[1], edit[2])) + elif error_type[0] == "S": + if self.granularity == "word": # 词级别可以细分错误类型 + if check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")): + cor = Correction("S:SPELL", tgt_span, (edit[1], edit[2])) + # Todo 暂且不单独区分命名实体拼写错误 + # if edit[4] - edit[3] > 1: + # cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2])) + # else: + # pos = self.get_pos_type(tgt[edit[3]][1]) + # if pos == "NOUN-NE": # 命名实体拼写有误 + # cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2])) + # else: # 普通词语拼写有误 + # cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2])) + else: + if edit[4] - edit[3] > 1: # 词组被替换暂时分为OTHER + cor = Correction("S:OTHER", tgt_span, (edit[1], edit[2])) + else: + pos = self.get_pos_type(tgt[edit[3]][1]) + pos = "NOUN" if pos == "NOUN-NE" else pos + pos = "MC" if tgt_span == "[缺失成分]" else pos + cor = Correction("S:{:s}".format(pos), tgt_span, (edit[1], edit[2])) + else: # 字级别可以只需要根据操作划分类型即可 + cor = Correction("S", tgt_span, (edit[1], edit[2])) + results.append(cor) + if verbose: + print("========== Corrections ==========") + for cor in results: + print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks)) + return results + +# print(pinyin("朝", style=Style.NORMAL)) diff --git a/opencompass/datasets/lawbench/utils/modules/merger.py b/opencompass/datasets/lawbench/utils/modules/merger.py index 8c0f6db6..26e7039b 100644 --- a/opencompass/datasets/lawbench/utils/modules/merger.py +++ b/opencompass/datasets/lawbench/utils/modules/merger.py @@ -1,273 +1,273 @@ -from itertools import groupby -from string import punctuation -from typing import List -from modules.tokenizer import Tokenizer -from modules.alignment import Alignment, read_cilin, read_confusion -import Levenshtein - -class Merger: - """ - 合并编辑操作,从Token-Level转换为Span-Level - """ - - def __init__(self, - granularity: str = "word", - merge: bool = False): - chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧." - self.punctuation = punctuation + chinese_punct - self.not_merge_token = [punct for punct in self.punctuation] - self.granularity = granularity - self.merge = merge - - @staticmethod - def _merge_edits(seq, tag="X"): - if seq: - return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])] - else: - return seq - - @staticmethod - def _check_revolve(span_a, span_b): - span_a = span_a + span_a - return span_b in span_a - - def _process_seq(self, seq, src_tokens, tgt_tokens): - if len(seq) <= 1: - return seq - - ops = [op[0] for op in seq] - if set(ops) == {"D"} or set(ops) == {"I"}: - return self._merge_edits(seq, set(ops).pop()) - - if set(ops) == {"D", "I"} or set(ops) == {"I", "D"}: - # do not merge this pattern_from_qua.txt - return seq - - if set(ops) == {"S"}: - if self.granularity == "word": - return seq - else: - return self._merge_edits(seq, "S") - - if set(ops) == {"M"}: - return self._merge_edits(seq, "M") - - return self._merge_edits(seq, "S") - - def __call__(self, - align_obj, - src: List, - tgt: List, - verbose: bool = False): - """ - Based on ERRANT's merge, adapted for Chinese - """ - src_tokens = [x[0] for x in src] - tgt_tokens = [x[0] for x in tgt] - edits = [] - # Split alignment into groups of M, T and rest. 
(T has a number after it) - # Todo 一旦插入、删除、替换的对象中含有标点,那么不与其它编辑合并 - # Todo 缺失成分标签也不与其它编辑合并 - for op, group in groupby( - align_obj, - lambda x: x[0][0] if x[0][0] in {"M", "T"} else False, - ): - group = list(group) - # T is always split TODO: Evaluate this - if op == "T": - for seq in group: - edits.append(seq) - # Process D, I and S subsequence - else: - # Turn the processed sequence into edits - processed = self._process_seq(group, src_tokens, tgt_tokens) - for seq in processed: - edits.append(seq) - - filtered_edits = [] - i = 0 - while i < len(edits): - e1 = edits[i][0][0] - - if i < len(edits) - 2: - e2 = edits[i + 1][0][0] - e3 = edits[i + 2][0][0] - - # Find "S M S" patterns - # Ex: - # S M S - # 冬阴功 对 外国人 - # 外国人 对 冬阴功 - if e1 == "S" and e2 == "M" and e3 == "S": - w1 = "".join(src_tokens[edits[i][1]: edits[i][2]]) - w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]]) - w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]]) - w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]) - if min([len(w1), len(w2), len(w3), len(w4)]) == 1: - if w1 == w4 and w2 == w3: - group = [edits[i], edits[i + 1], edits[i + 2]] - processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1])) - for seq in processed: - filtered_edits.append(seq) - i += 3 - else: - filtered_edits.append(edits[i]) - i += 1 - else: - if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1: - group = [edits[i], edits[i + 1], edits[i + 2]] - processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1])) - for seq in processed: - filtered_edits.append(seq) - i += 3 - else: - filtered_edits.append(edits[i]) - i += 1 - # Find "D M I" or "I M D" patterns - # Ex: - # D M I - # 旅游 去 陌生 的 地方 - # 去 陌生 的 地方 旅游 - elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"): - if e1 == "D": - delete_token = src_tokens[edits[i][1]: edits[i][2]] - insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]] - else: - delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]] - insert_token = tgt_tokens[edits[i][3]: edits[i][4]] - a, b = "".join(delete_token), "".join(insert_token) - if len(a) < len(b): - a, b = b, a - if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1: - if len(b) == 1: - if a == b: - group = [edits[i], edits[i + 1], edits[i + 2]] - processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1])) - for seq in processed: - filtered_edits.append(seq) - i += 3 - else: - filtered_edits.append(edits[i]) - i += 1 - else: - if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)): - group = [edits[i], edits[i + 1], edits[i + 2]] - processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1])) - for seq in processed: - filtered_edits.append(seq) - i += 3 - else: - filtered_edits.append(edits[i]) - i += 1 - else: - filtered_edits.append(edits[i]) - i += 1 - else: - if e1 != "M": - filtered_edits.append(edits[i]) - i += 1 - else: - if e1 != "M": - filtered_edits.append(edits[i]) - i += 1 - # In rare cases with word-level tokenization, the following error can occur: - # M D S M - # 有 時 住 上層 - # 有 時住 上層 - # Which results in S: 時住 --> 時住 - # We need to filter this case out - second_filter = [] - for edit in filtered_edits: # 避免因为分词错误导致的mismatch现象 - span1 = "".join(src_tokens[edit[1] : edit[2]]) - span2 = "".join(tgt_tokens[edit[3] : edit[4]]) - - if span1 != span2: - if edit[0] == "S": - b = True - # In rare cases 
with word-level tokenization, the following error can occur: - # S I I M - # 负责任 老师 - # 负 责任 的 老师 - # Which results in S: 负责任 --> 负 责任 的 - # We need to convert this edit to I: --> 的 - - # 首部有重叠 - common_str = "" - tmp_new_start_1 = edit[1] - for i in range(edit[1], edit[2]): - if not span2.startswith(common_str + src_tokens[i]): - break - common_str += src_tokens[i] - tmp_new_start_1 = i + 1 - new_start_1, new_start_2 = edit[1], edit[3] - if common_str: - tmp_str = "" - for i in range(edit[3], edit[4]): - tmp_str += tgt_tokens[i] - if tmp_str == common_str: - new_start_1, new_start_2 = tmp_new_start_1, i + 1 - # second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4])) - b = False - break - elif len(tmp_str) > len(common_str): - break - # 尾部有重叠 - common_str = "" - new_end_1, new_end_2 = edit[2], edit[4] - tmp_new_end_1 = edit[2] - for i in reversed(range(new_start_1, edit[2])): - if not span2.endswith(src_tokens[i] + common_str): - break - common_str = src_tokens[i] + common_str - tmp_new_end_1 = i - if common_str: - tmp_str = "" - for i in reversed(range(new_start_2, edit[4])): - tmp_str = tgt_tokens[i] + tmp_str - if tmp_str == common_str: - new_end_1, new_end_2 = tmp_new_end_1, i - b = False - break - elif len(tmp_str) > len(common_str): - break - if b: - second_filter.append(edit) - else: - if new_start_1 == new_end_1: - new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2) - elif new_start_2 == new_end_2: - new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2) - else: - new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2) - second_filter.append(new_edit) - else: - second_filter.append(edit) - if verbose: - print("========== Parallels ==========") - print("".join(src_tokens)) - print("".join(tgt_tokens)) - print("========== Results ==========") - for edit in second_filter: - op = edit[0] - s = " ".join(src_tokens[edit[1]: edit[2]]) - t = " ".join(tgt_tokens[edit[3]: edit[4]]) - print(f"{op}:\t{s}\t-->\t{t}") - print("========== Infos ==========") - print(str(src)) - print(str(tgt)) - return second_filter - -if __name__ == "__main__": - tokenizer = Tokenizer("char") - semantic_dict, semantic_class = read_cilin() - confusion_dict = read_confusion() - alignment = Alignment(semantic_dict, confusion_dict) - sents = [ - "所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace( - " ", ""), - "所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace( - " ", "")] - src, tgt = tokenizer(sents) - align_obj = alignment(src, tgt) - m = Merger() +from itertools import groupby +from string import punctuation +from typing import List +from modules.tokenizer import Tokenizer +from modules.alignment import Alignment, read_cilin, read_confusion +import Levenshtein + +class Merger: + """ + 合并编辑操作,从Token-Level转换为Span-Level + """ + + def __init__(self, + granularity: str = "word", + merge: bool = False): + chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧." 
+ self.punctuation = punctuation + chinese_punct + self.not_merge_token = [punct for punct in self.punctuation] + self.granularity = granularity + self.merge = merge + + @staticmethod + def _merge_edits(seq, tag="X"): + if seq: + return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])] + else: + return seq + + @staticmethod + def _check_revolve(span_a, span_b): + span_a = span_a + span_a + return span_b in span_a + + def _process_seq(self, seq, src_tokens, tgt_tokens): + if len(seq) <= 1: + return seq + + ops = [op[0] for op in seq] + if set(ops) == {"D"} or set(ops) == {"I"}: + return self._merge_edits(seq, set(ops).pop()) + + if set(ops) == {"D", "I"} or set(ops) == {"I", "D"}: + # do not merge this pattern_from_qua.txt + return seq + + if set(ops) == {"S"}: + if self.granularity == "word": + return seq + else: + return self._merge_edits(seq, "S") + + if set(ops) == {"M"}: + return self._merge_edits(seq, "M") + + return self._merge_edits(seq, "S") + + def __call__(self, + align_obj, + src: List, + tgt: List, + verbose: bool = False): + """ + Based on ERRANT's merge, adapted for Chinese + """ + src_tokens = [x[0] for x in src] + tgt_tokens = [x[0] for x in tgt] + edits = [] + # Split alignment into groups of M, T and rest. (T has a number after it) + # Todo 一旦插入、删除、替换的对象中含有标点,那么不与其它编辑合并 + # Todo 缺失成分标签也不与其它编辑合并 + for op, group in groupby( + align_obj, + lambda x: x[0][0] if x[0][0] in {"M", "T"} else False, + ): + group = list(group) + # T is always split TODO: Evaluate this + if op == "T": + for seq in group: + edits.append(seq) + # Process D, I and S subsequence + else: + # Turn the processed sequence into edits + processed = self._process_seq(group, src_tokens, tgt_tokens) + for seq in processed: + edits.append(seq) + + filtered_edits = [] + i = 0 + while i < len(edits): + e1 = edits[i][0][0] + + if i < len(edits) - 2: + e2 = edits[i + 1][0][0] + e3 = edits[i + 2][0][0] + + # Find "S M S" patterns + # Ex: + # S M S + # 冬阴功 对 外国人 + # 外国人 对 冬阴功 + if e1 == "S" and e2 == "M" and e3 == "S": + w1 = "".join(src_tokens[edits[i][1]: edits[i][2]]) + w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]]) + w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]]) + w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]) + if min([len(w1), len(w2), len(w3), len(w4)]) == 1: + if w1 == w4 and w2 == w3: + group = [edits[i], edits[i + 1], edits[i + 2]] + processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1])) + for seq in processed: + filtered_edits.append(seq) + i += 3 + else: + filtered_edits.append(edits[i]) + i += 1 + else: + if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1: + group = [edits[i], edits[i + 1], edits[i + 2]] + processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1])) + for seq in processed: + filtered_edits.append(seq) + i += 3 + else: + filtered_edits.append(edits[i]) + i += 1 + # Find "D M I" or "I M D" patterns + # Ex: + # D M I + # 旅游 去 陌生 的 地方 + # 去 陌生 的 地方 旅游 + elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"): + if e1 == "D": + delete_token = src_tokens[edits[i][1]: edits[i][2]] + insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]] + else: + delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]] + insert_token = tgt_tokens[edits[i][3]: edits[i][4]] + a, b = "".join(delete_token), "".join(insert_token) + if len(a) < len(b): + a, b = b, a + if a not in self.punctuation and b not in self.punctuation and len(a) 
- len(b) <= 1: + if len(b) == 1: + if a == b: + group = [edits[i], edits[i + 1], edits[i + 2]] + processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1])) + for seq in processed: + filtered_edits.append(seq) + i += 3 + else: + filtered_edits.append(edits[i]) + i += 1 + else: + if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)): + group = [edits[i], edits[i + 1], edits[i + 2]] + processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1])) + for seq in processed: + filtered_edits.append(seq) + i += 3 + else: + filtered_edits.append(edits[i]) + i += 1 + else: + filtered_edits.append(edits[i]) + i += 1 + else: + if e1 != "M": + filtered_edits.append(edits[i]) + i += 1 + else: + if e1 != "M": + filtered_edits.append(edits[i]) + i += 1 + # In rare cases with word-level tokenization, the following error can occur: + # M D S M + # 有 時 住 上層 + # 有 時住 上層 + # Which results in S: 時住 --> 時住 + # We need to filter this case out + second_filter = [] + for edit in filtered_edits: # 避免因为分词错误导致的mismatch现象 + span1 = "".join(src_tokens[edit[1] : edit[2]]) + span2 = "".join(tgt_tokens[edit[3] : edit[4]]) + + if span1 != span2: + if edit[0] == "S": + b = True + # In rare cases with word-level tokenization, the following error can occur: + # S I I M + # 负责任 老师 + # 负 责任 的 老师 + # Which results in S: 负责任 --> 负 责任 的 + # We need to convert this edit to I: --> 的 + + # 首部有重叠 + common_str = "" + tmp_new_start_1 = edit[1] + for i in range(edit[1], edit[2]): + if not span2.startswith(common_str + src_tokens[i]): + break + common_str += src_tokens[i] + tmp_new_start_1 = i + 1 + new_start_1, new_start_2 = edit[1], edit[3] + if common_str: + tmp_str = "" + for i in range(edit[3], edit[4]): + tmp_str += tgt_tokens[i] + if tmp_str == common_str: + new_start_1, new_start_2 = tmp_new_start_1, i + 1 + # second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4])) + b = False + break + elif len(tmp_str) > len(common_str): + break + # 尾部有重叠 + common_str = "" + new_end_1, new_end_2 = edit[2], edit[4] + tmp_new_end_1 = edit[2] + for i in reversed(range(new_start_1, edit[2])): + if not span2.endswith(src_tokens[i] + common_str): + break + common_str = src_tokens[i] + common_str + tmp_new_end_1 = i + if common_str: + tmp_str = "" + for i in reversed(range(new_start_2, edit[4])): + tmp_str = tgt_tokens[i] + tmp_str + if tmp_str == common_str: + new_end_1, new_end_2 = tmp_new_end_1, i + b = False + break + elif len(tmp_str) > len(common_str): + break + if b: + second_filter.append(edit) + else: + if new_start_1 == new_end_1: + new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2) + elif new_start_2 == new_end_2: + new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2) + else: + new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2) + second_filter.append(new_edit) + else: + second_filter.append(edit) + if verbose: + print("========== Parallels ==========") + print("".join(src_tokens)) + print("".join(tgt_tokens)) + print("========== Results ==========") + for edit in second_filter: + op = edit[0] + s = " ".join(src_tokens[edit[1]: edit[2]]) + t = " ".join(tgt_tokens[edit[3]: edit[4]]) + print(f"{op}:\t{s}\t-->\t{t}") + print("========== Infos ==========") + print(str(src)) + print(str(tgt)) + return second_filter + +if __name__ == "__main__": + tokenizer = Tokenizer("char") + semantic_dict, semantic_class = read_cilin() + confusion_dict = read_confusion() + alignment = Alignment(semantic_dict, confusion_dict) + sents = [ + "所 以 印 度 
对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace( + " ", ""), + "所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace( + " ", "")] + src, tgt = tokenizer(sents) + align_obj = alignment(src, tgt) + m = Merger() m(align_obj, src, tgt, verbose=True) \ No newline at end of file diff --git a/opencompass/datasets/lawbench/utils/modules/tokenizer.py b/opencompass/datasets/lawbench/utils/modules/tokenizer.py index c9653e44..aa64cb97 100644 --- a/opencompass/datasets/lawbench/utils/modules/tokenizer.py +++ b/opencompass/datasets/lawbench/utils/modules/tokenizer.py @@ -1,92 +1,92 @@ -from ltp import LTP -from typing import List -from pypinyin import pinyin, Style, lazy_pinyin -import torch -import os -import functools - -class Tokenizer: - """ - 分词器 - """ - - def __init__(self, - granularity: str = "word", - device: str = "cpu", - segmented: bool = False, - bpe: bool = False, - ) -> None: - """ - 构造函数 - :param mode: 分词模式,可选级别:字级别(char)、词级别(word) - """ - self.ltp = None - if granularity == "word": - self.ltp = LTP(device=torch.device(device) if torch.cuda.is_available() else torch.device("cpu")) - self.ltp.add_words(words=["[缺失成分]"], max_window=6) - self.segmented = segmented - self.granularity = granularity - if self.granularity == "word": - self.tokenizer = self.split_word - elif self.granularity == "char": - self.tokenizer = functools.partial(self.split_char, bpe=bpe) - else: - raise NotImplementedError - - def __repr__(self) -> str: - return "{:s}\nMode:{:s}\n}".format(str(self.__class__.__name__), self.mode) - - def __call__(self, - input_strings: List[str] - ) -> List: - """ - 分词函数 - :param input_strings: 需要分词的字符串列表 - :return: 分词后的结果列表,由元组组成,元组为(token,pos_tag,pinyin)的形式 - """ - if not self.segmented: - input_strings = ["".join(s.split(" ")) for s in input_strings] - results = self.tokenizer(input_strings) - return results - - def split_char(self, input_strings: List[str], bpe=False) -> List: - """ - 分字函数 - :param input_strings: 需要分字的字符串 - :return: 分字结果 - """ - if bpe: - from . 
import tokenization - project_dir = os.path.dirname(os.path.dirname(__file__)) - tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(project_dir,"data","chinese_vocab.txt"), do_lower_case=False) - results = [] - for input_string in input_strings: - if not self.segmented: # 如果没有被分字,就按照每个字符隔开(不考虑英文标点的特殊处理,也不考虑BPE),否则遵循原分字结果 - segment_string = " ".join([char for char in input_string] if not bpe else tokenizer.tokenize(input_string)) - else: - segment_string = input_string - # print(segment_string) - segment_string = segment_string.replace("[ 缺 失 成 分 ]", "[缺失成分]").split(" ") # 缺失成分当成一个单独的token - results.append([(char, "unk", pinyin(char, style=Style.NORMAL, heteronym=True)[0]) for char in segment_string]) - return results - - def split_word(self, input_strings: List[str]) -> List: - """ - 分词函数 - :param input_strings: 需要分词的字符串 - :return: 分词结果 - """ - if self.segmented: - seg, hidden = self.ltp.seg([input_string.split(" ") for input_string in input_strings], is_preseged=True) - else: - seg, hidden = self.ltp.seg(input_strings) - pos = self.ltp.pos(hidden) - result = [] - for s, p in zip(seg, pos): - pinyin = [lazy_pinyin(word) for word in s] - result.append(list(zip(s, p, pinyin))) - return result - -if __name__ == "__main__": - tokenizer = Tokenizer("word") - print(tokenizer(["LAC是个优秀的分词工具", "百度是一家高科技公司"])) +from ltp import LTP +from typing import List +from pypinyin import pinyin, Style, lazy_pinyin +import torch +import os +import functools + +class Tokenizer: + """ + 分词器 + """ + + def __init__(self, + granularity: str = "word", + device: str = "cpu", + segmented: bool = False, + bpe: bool = False, + ) -> None: + """ + 构造函数 + :param mode: 分词模式,可选级别:字级别(char)、词级别(word) + """ + self.ltp = None + if granularity == "word": + self.ltp = LTP(device=torch.device(device) if torch.cuda.is_available() else torch.device("cpu")) + self.ltp.add_words(words=["[缺失成分]"], max_window=6) + self.segmented = segmented + self.granularity = granularity + if self.granularity == "word": + self.tokenizer = self.split_word + elif self.granularity == "char": + self.tokenizer = functools.partial(self.split_char, bpe=bpe) + else: + raise NotImplementedError + + def __repr__(self) -> str: + return "{:s}\nMode:{:s}\n}".format(str(self.__class__.__name__), self.mode) + + def __call__(self, + input_strings: List[str] + ) -> List: + """ + 分词函数 + :param input_strings: 需要分词的字符串列表 + :return: 分词后的结果列表,由元组组成,元组为(token,pos_tag,pinyin)的形式 + """ + if not self.segmented: + input_strings = ["".join(s.split(" ")) for s in input_strings] + results = self.tokenizer(input_strings) + return results + + def split_char(self, input_strings: List[str], bpe=False) -> List: + """ + 分字函数 + :param input_strings: 需要分字的字符串 + :return: 分字结果 + """ + if bpe: + from . 
import tokenization + project_dir = os.path.dirname(os.path.dirname(__file__)) + tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(project_dir,"data","chinese_vocab.txt"), do_lower_case=False) + results = [] + for input_string in input_strings: + if not self.segmented: # 如果没有被分字,就按照每个字符隔开(不考虑英文标点的特殊处理,也不考虑BPE),否则遵循原分字结果 + segment_string = " ".join([char for char in input_string] if not bpe else tokenizer.tokenize(input_string)) + else: + segment_string = input_string + # print(segment_string) + segment_string = segment_string.replace("[ 缺 失 成 分 ]", "[缺失成分]").split(" ") # 缺失成分当成一个单独的token + results.append([(char, "unk", pinyin(char, style=Style.NORMAL, heteronym=True)[0]) for char in segment_string]) + return results + + def split_word(self, input_strings: List[str]) -> List: + """ + 分词函数 + :param input_strings: 需要分词的字符串 + :return: 分词结果 + """ + if self.segmented: + seg, hidden = self.ltp.seg([input_string.split(" ") for input_string in input_strings], is_preseged=True) + else: + seg, hidden = self.ltp.seg(input_strings) + pos = self.ltp.pos(hidden) + result = [] + for s, p in zip(seg, pos): + pinyin = [lazy_pinyin(word) for word in s] + result.append(list(zip(s, p, pinyin))) + return result + +if __name__ == "__main__": + tokenizer = Tokenizer("word") + print(tokenizer(["LAC是个优秀的分词工具", "百度是一家高科技公司"])) diff --git a/opencompass/datasets/lawbench/utils/parallel_to_m2.py b/opencompass/datasets/lawbench/utils/parallel_to_m2.py index 30dbb2f1..6b2c035b 100644 --- a/opencompass/datasets/lawbench/utils/parallel_to_m2.py +++ b/opencompass/datasets/lawbench/utils/parallel_to_m2.py @@ -1,221 +1,221 @@ -import os -from modules.annotator import Annotator -from modules.tokenizer import Tokenizer -import argparse -from collections import Counter -from tqdm import tqdm -import torch -from collections import defaultdict -from multiprocessing import Pool -from opencc import OpenCC -import timeout_decorator - -os.environ["TOKENIZERS_PARALLELISM"] = "false" - -annotator, sentence_to_tokenized = None, None -cc = OpenCC("t2s") - -@timeout_decorator.timeout(10) -def annotate_with_time_out(line): - """ - :param line: - :return: - """ - sent_list = line.split("\t")[1:] - source = sent_list[0] - if args.segmented: - source = source.strip() - else: - source = "".join(source.strip().split()) - output_str = "" - for idx, target in enumerate(sent_list[1:]): - try: - if args.segmented: - target = target.strip() - else: - target = "".join(target.strip().split()) - if not args.no_simplified: - target = cc.convert(target) - source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target] - out, cors = annotator(source_tokenized, target_tokenized, idx) - if idx == 0: - output_str += "".join(out[:-1]) - else: - output_str += "".join(out[1:-1]) - except Exception: - raise Exception - return output_str - - -def annotate(line): - """ - :param line: - :return: - """ - sent_list = line.split("\t")[1:] - source = sent_list[0] - if args.segmented: - source = source.strip() - else: - source = "".join(source.strip().split()) - output_str = "" - for idx, target in enumerate(sent_list[1:]): - try: - if args.segmented: - target = target.strip() - else: - target = "".join(target.strip().split()) - if not args.no_simplified: - target = cc.convert(target) - source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target] - out, cors = annotator(source_tokenized, target_tokenized, idx) - if idx == 0: - output_str += "".join(out[:-1]) - else: - output_str += 
"".join(out[1:-1]) - except Exception: - raise Exception - return output_str - - - - - -def firsttime_process(args): - tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe) - global annotator, sentence_to_tokenized - annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy) - lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n") # format: id src tgt1 tgt2... - # error_types = [] - - with open(args.output, "w", encoding="utf-8") as f: - count = 0 - sentence_set = set() - sentence_to_tokenized = {} - for line in lines: - sent_list = line.split("\t")[1:] - for idx, sent in enumerate(sent_list): - if args.segmented: - # print(sent) - sent = sent.strip() - else: - sent = "".join(sent.split()).strip() - if idx >= 1: - if not args.no_simplified: - sentence_set.add(cc.convert(sent)) - else: - sentence_set.add(sent) - else: - sentence_set.add(sent) - batch = [] - for sent in tqdm(sentence_set): - count += 1 - if sent: - batch.append(sent) - if count % args.batch_size == 0: - results = tokenizer(batch) - for s, r in zip(batch, results): - sentence_to_tokenized[s] = r # Get tokenization map. - batch = [] - if batch: - results = tokenizer(batch) - for s, r in zip(batch, results): - sentence_to_tokenized[s] = r # Get tokenization map. - - timeout_indices = [] - - # 单进程模式 - for idx, line in enumerate(tqdm(lines)): - try: - ret = annotate_with_time_out(line) - except Exception: - timeout_indices.append(idx) - return timeout_indices - - - -def main(args): - timeout_indices = firsttime_process(args) - tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe) - global annotator, sentence_to_tokenized - annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy) - lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n") - new_lines = []# format: id src tgt1 tgt2... - - with open(args.output, "w", encoding="utf-8") as f: - count = 0 - sentence_set = set() - sentence_to_tokenized = {} - for line_idx, line in enumerate(lines): - - if line_idx in timeout_indices: - # print(f"line before split: {line}") - line_split = line.split("\t") - line_number, sent_list = line_split[0], line_split[1:] - assert len(sent_list) == 2 - sent_list[-1] = " 无" - line = line_number + "\t" + "\t".join(sent_list) - # print(f"line time out: {line}") - new_lines.append(line) - else: - new_lines.append(line) - - sent_list = line.split("\t")[1:] - for idx, sent in enumerate(sent_list): - if args.segmented: - # print(sent) - sent = sent.strip() - else: - sent = "".join(sent.split()).strip() - if idx >= 1: - if not args.no_simplified: - sentence_set.add(cc.convert(sent)) - else: - sentence_set.add(sent) - else: - sentence_set.add(sent) - batch = [] - for sent in tqdm(sentence_set): - count += 1 - if sent: - batch.append(sent) - if count % args.batch_size == 0: - results = tokenizer(batch) - for s, r in zip(batch, results): - sentence_to_tokenized[s] = r # Get tokenization map. - batch = [] - if batch: - results = tokenizer(batch) - for s, r in zip(batch, results): - sentence_to_tokenized[s] = r # Get tokenization map. 
- - # 单进程模式 - lines = new_lines - for idx, line in enumerate(tqdm(lines)): - ret = annotate(line) - f.write(ret) - f.write("\n") - - # 多进程模式:仅在Linux环境下测试,建议在linux服务器上使用 - # with Pool(args.worker_num) as pool: - # for ret in pool.imap(annotate, tqdm(lines), chunksize=8): - # if ret: - # f.write(ret) - # f.write("\n") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Choose input file to annotate") - parser.add_argument("-f", "--file", type=str, required=True, help="Input parallel file") - parser.add_argument("-o", "--output", type=str, help="Output file", required=True) - parser.add_argument("-b", "--batch_size", type=int, help="The size of batch", default=128) - parser.add_argument("-d", "--device", type=int, help="The ID of GPU", default=0) - parser.add_argument("-w", "--worker_num", type=int, help="The number of workers", default=16) - parser.add_argument("-g", "--granularity", type=str, help="Choose char-level or word-level evaluation", default="char") - parser.add_argument("-m", "--merge", help="Whether merge continuous replacement/deletion/insertion", action="store_true") - parser.add_argument("-s", "--multi_cheapest_strategy", type=str, choices=["first", "all"], default="all") - parser.add_argument("--segmented", help="Whether tokens have been segmented", action="store_true") # 支持提前token化,用空格隔开 - parser.add_argument("--no_simplified", help="Whether simplifying chinese", action="store_true") # 将所有corrections转换为简体中文 - parser.add_argument("--bpe", help="Whether to use bpe", action="store_true") # 支持 bpe 切分英文单词 - args = parser.parse_args() - main(args) +import os +from modules.annotator import Annotator +from modules.tokenizer import Tokenizer +import argparse +from collections import Counter +from tqdm import tqdm +import torch +from collections import defaultdict +from multiprocessing import Pool +from opencc import OpenCC +import timeout_decorator + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +annotator, sentence_to_tokenized = None, None +cc = OpenCC("t2s") + +@timeout_decorator.timeout(10) +def annotate_with_time_out(line): + """ + :param line: + :return: + """ + sent_list = line.split("\t")[1:] + source = sent_list[0] + if args.segmented: + source = source.strip() + else: + source = "".join(source.strip().split()) + output_str = "" + for idx, target in enumerate(sent_list[1:]): + try: + if args.segmented: + target = target.strip() + else: + target = "".join(target.strip().split()) + if not args.no_simplified: + target = cc.convert(target) + source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target] + out, cors = annotator(source_tokenized, target_tokenized, idx) + if idx == 0: + output_str += "".join(out[:-1]) + else: + output_str += "".join(out[1:-1]) + except Exception: + raise Exception + return output_str + + +def annotate(line): + """ + :param line: + :return: + """ + sent_list = line.split("\t")[1:] + source = sent_list[0] + if args.segmented: + source = source.strip() + else: + source = "".join(source.strip().split()) + output_str = "" + for idx, target in enumerate(sent_list[1:]): + try: + if args.segmented: + target = target.strip() + else: + target = "".join(target.strip().split()) + if not args.no_simplified: + target = cc.convert(target) + source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target] + out, cors = annotator(source_tokenized, target_tokenized, idx) + if idx == 0: + output_str += "".join(out[:-1]) + else: + output_str += "".join(out[1:-1]) + 
except Exception: + raise Exception + return output_str + + + + + +def firsttime_process(args): + tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe) + global annotator, sentence_to_tokenized + annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy) + lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n") # format: id src tgt1 tgt2... + # error_types = [] + + with open(args.output, "w", encoding="utf-8") as f: + count = 0 + sentence_set = set() + sentence_to_tokenized = {} + for line in lines: + sent_list = line.split("\t")[1:] + for idx, sent in enumerate(sent_list): + if args.segmented: + # print(sent) + sent = sent.strip() + else: + sent = "".join(sent.split()).strip() + if idx >= 1: + if not args.no_simplified: + sentence_set.add(cc.convert(sent)) + else: + sentence_set.add(sent) + else: + sentence_set.add(sent) + batch = [] + for sent in tqdm(sentence_set): + count += 1 + if sent: + batch.append(sent) + if count % args.batch_size == 0: + results = tokenizer(batch) + for s, r in zip(batch, results): + sentence_to_tokenized[s] = r # Get tokenization map. + batch = [] + if batch: + results = tokenizer(batch) + for s, r in zip(batch, results): + sentence_to_tokenized[s] = r # Get tokenization map. + + timeout_indices = [] + + # 单进程模式 + for idx, line in enumerate(tqdm(lines)): + try: + ret = annotate_with_time_out(line) + except Exception: + timeout_indices.append(idx) + return timeout_indices + + + +def main(args): + timeout_indices = firsttime_process(args) + tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe) + global annotator, sentence_to_tokenized + annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy) + lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n") + new_lines = []# format: id src tgt1 tgt2... + + with open(args.output, "w", encoding="utf-8") as f: + count = 0 + sentence_set = set() + sentence_to_tokenized = {} + for line_idx, line in enumerate(lines): + + if line_idx in timeout_indices: + # print(f"line before split: {line}") + line_split = line.split("\t") + line_number, sent_list = line_split[0], line_split[1:] + assert len(sent_list) == 2 + sent_list[-1] = " 无" + line = line_number + "\t" + "\t".join(sent_list) + # print(f"line time out: {line}") + new_lines.append(line) + else: + new_lines.append(line) + + sent_list = line.split("\t")[1:] + for idx, sent in enumerate(sent_list): + if args.segmented: + # print(sent) + sent = sent.strip() + else: + sent = "".join(sent.split()).strip() + if idx >= 1: + if not args.no_simplified: + sentence_set.add(cc.convert(sent)) + else: + sentence_set.add(sent) + else: + sentence_set.add(sent) + batch = [] + for sent in tqdm(sentence_set): + count += 1 + if sent: + batch.append(sent) + if count % args.batch_size == 0: + results = tokenizer(batch) + for s, r in zip(batch, results): + sentence_to_tokenized[s] = r # Get tokenization map. + batch = [] + if batch: + results = tokenizer(batch) + for s, r in zip(batch, results): + sentence_to_tokenized[s] = r # Get tokenization map. 
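[Editor's note] The script also guards against pathological inputs in two passes: `firsttime_process` runs `annotate_with_time_out` (wrapped in `timeout_decorator.timeout(10)`) once to collect the indices of lines that exceed the time budget, and `main` then rewrites those lines' target side to the placeholder `" 无"` before annotating for real. A hedged sketch of that two-pass strategy, with hypothetical helper names standing in for the functions above:

```python
# Hedged sketch of the two-pass timeout strategy; helper names are
# illustrative (the real code uses annotate_with_time_out / firsttime_process).
import timeout_decorator

@timeout_decorator.timeout(10)           # abort any single line after 10 s
def slow_annotate(line: str) -> str:
    ...                                  # expensive alignment / annotation work

def find_timeouts(lines):
    """First pass: indices of lines whose annotation exceeds the budget."""
    bad = []
    for idx, line in enumerate(lines):
        try:
            slow_annotate(line)
        except Exception:                # timeout_decorator raises on expiry
            bad.append(idx)
    return bad

def neutralise_timeouts(lines, bad):
    """Second pass: replace the target side of timed-out lines with ' 无'."""
    for idx in bad:
        number, source = lines[idx].split("\t")[:2]
        lines[idx] = "\t".join([number, source, " 无"])
    return lines
```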
+ + # 单进程模式 + lines = new_lines + for idx, line in enumerate(tqdm(lines)): + ret = annotate(line) + f.write(ret) + f.write("\n") + + # 多进程模式:仅在Linux环境下测试,建议在linux服务器上使用 + # with Pool(args.worker_num) as pool: + # for ret in pool.imap(annotate, tqdm(lines), chunksize=8): + # if ret: + # f.write(ret) + # f.write("\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Choose input file to annotate") + parser.add_argument("-f", "--file", type=str, required=True, help="Input parallel file") + parser.add_argument("-o", "--output", type=str, help="Output file", required=True) + parser.add_argument("-b", "--batch_size", type=int, help="The size of batch", default=128) + parser.add_argument("-d", "--device", type=int, help="The ID of GPU", default=0) + parser.add_argument("-w", "--worker_num", type=int, help="The number of workers", default=16) + parser.add_argument("-g", "--granularity", type=str, help="Choose char-level or word-level evaluation", default="char") + parser.add_argument("-m", "--merge", help="Whether merge continuous replacement/deletion/insertion", action="store_true") + parser.add_argument("-s", "--multi_cheapest_strategy", type=str, choices=["first", "all"], default="all") + parser.add_argument("--segmented", help="Whether tokens have been segmented", action="store_true") # 支持提前token化,用空格隔开 + parser.add_argument("--no_simplified", help="Whether simplifying chinese", action="store_true") # 将所有corrections转换为简体中文 + parser.add_argument("--bpe", help="Whether to use bpe", action="store_true") # 支持 bpe 切分英文单词 + args = parser.parse_args() + main(args) diff --git a/opencompass/runners/dlc.py b/opencompass/runners/dlc.py index d179a6f4..3b305771 100644 --- a/opencompass/runners/dlc.py +++ b/opencompass/runners/dlc.py @@ -1,10 +1,13 @@ +import datetime import os import os.path as osp import random +import re import subprocess +import sys import time from functools import partial -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import mmengine from mmengine.config import ConfigDict @@ -43,6 +46,11 @@ class DLCRunner(BaseRunner): self.max_num_workers = max_num_workers self.retry = retry + logger = get_logger() + logger.warning( + 'To ensure the integrity of the log results, the log displayed ' + f'by {self.__class__.__name__} has a 10-second delay.') + def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]: """Launch multiple tasks. @@ -63,18 +71,23 @@ class DLCRunner(BaseRunner): status = [self._launch(task, random_sleep=False) for task in tasks] return status - def _launch(self, cfg: ConfigDict, random_sleep: bool = True): + def _launch(self, cfg: ConfigDict, random_sleep: Optional[bool] = None): """Launch a single task. Args: cfg (ConfigDict): Task config. random_sleep (bool): Whether to sleep for a random time before - running the command. This avoids cluster error when launching - multiple tasks at the same time. Default: True. + running the command. When Aliyun has many tasks to schedule, + its stability decreases. Therefore, when we need to submit a + large number of tasks at once, we adopt the "random_sleep" + strategy. Tasks that would have been submitted all at once are + now evenly spread out over a 10-second period. Default: None. Returns: tuple[str, int]: Task name and exit code. 
""" + if random_sleep is None: + random_sleep = (self.max_num_workers > 32) task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type'])) num_gpus = task.num_gpus @@ -116,7 +129,7 @@ class DLCRunner(BaseRunner): # Run command with retry if self.debug: - stdout = None + stdout = sys.stdout else: out_path = task.get_log_path(file_extension='out') mmengine.mkdir_or_exist(osp.split(out_path)[0]) @@ -124,30 +137,92 @@ class DLCRunner(BaseRunner): if random_sleep: time.sleep(random.randint(0, 10)) - result = subprocess.run(cmd, - shell=True, - text=True, - stdout=stdout, - stderr=stdout) + def _run_within_retry(): + try: + process = subprocess.Popen(cmd, + shell=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + job_id = None + job_allocated = False + job_finished = False + last_end_time = datetime.datetime.now().strftime( + '%Y-%m-%dT%H:%M:%SZ') + while True: + if not job_allocated: + line = process.stdout.readline() + if not line: + break + match = re.search(r'(dlc[0-9a-z]+)', line) + if match and job_id is None: + job_id = match.group(1) + stdout.write(line) + match = re.search(r'Job .* is \[Running\]', line) + if match: + job_allocated = True + else: + try: + process.wait(10) + except subprocess.TimeoutExpired: + pass + else: + job_finished = True + if job_finished: + this_end_time = datetime.datetime.now( + ).strftime('%Y-%m-%dT%H:%M:%SZ') + else: + this_end_time = ( + datetime.datetime.now() - + datetime.timedelta(seconds=10) + ).strftime('%Y-%m-%dT%H:%M:%SZ') + logs_cmd = ( + 'dlc logs' + f' {job_id} {job_id}-worker-0' + f' --start_time {last_end_time}' + f' --end_time {this_end_time}' + f" -c {self.aliyun_cfg['dlc_config_path']}") + log_process = subprocess.Popen( + logs_cmd, + shell=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + log_output, log_err = log_process.communicate() + log_output = '\n'.join(log_output.split('\n')[2:]) + stdout.write(log_output) + last_end_time = this_end_time + stdout.flush() + if job_finished: + break + process.wait() + return process.returncode + finally: + if job_id is not None: + cancel_cmd = ( + 'dlc stop job' + f' {job_id}' + f" -c {self.aliyun_cfg['dlc_config_path']}" + ' -f') + subprocess.run(cancel_cmd, + shell=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + return_code = _run_within_retry() retry = self.retry output_paths = task.get_output_paths() - while self._job_failed(result.returncode, - output_paths) and retry > 0: + while self._job_failed(return_code, output_paths) and retry > 0: retry -= 1 - if random_sleep: - time.sleep(random.randint(0, 10)) - # Re-generate command to refresh ports. cmd = get_cmd() - result = subprocess.run(cmd, - shell=True, - text=True, - stdout=stdout, - stderr=stdout) + return_code = _run_within_retry() finally: # Clean up os.remove(param_file) - return task_name, result.returncode + + return task_name, return_code def _job_failed(self, return_code: int, output_paths: List[str]) -> bool: return return_code != 0 or not all(