Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

[Sync] sync with internal codes 20231019 (#488)
parent 2737249f31
commit 4dd9a3fc10
@@ -3,7 +3,9 @@ exclude: |
      tests/data/|
      opencompass/models/internal/|
      opencompass/utils/internal/|
      opencompass/openicl/icl_evaluator/hf_metrics/|
      opencompass/datasets/lawbench/utils|
      opencompass/datasets/lawbench/evaluation_functions/
  )
repos:
  - repo: https://gitee.com/openmmlab/mirrors-flake8
@@ -1,19 +1,19 @@
from ..utils.function_utils import compute_rouge


# 情景法条识别 (situational statute identification)

def compute_cjft(data_dict):
    """
    Compute the ROUGE-L score between the prediction and the reference.
    """
    references, predictions = [], []
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        predictions.append(prediction)
        references.append(answer)

    # average the ROUGE-L F1 scores over all examples
    rouge_scores = compute_rouge(predictions, references)
    rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores]
    average_rouge_l = sum(rouge_ls) / len(rouge_ls)
    return {"score": average_rouge_l}
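For reference (not part of this commit): each entry of data_dict is expected to carry origin_prompt, prediction, and refr fields, and compute_rouge is assumed to return one ROUGE score dict per prediction/reference pair. A minimal, hypothetical usage sketch of the scorer defined above, with made-up sample text:

    # Illustration only; the sample below is invented.
    sample = [
        {"origin_prompt": "案情描述……请给出适用的法条。",
         "prediction": "根据《合同法》第五十二条,该合同应认定无效。",
         "refr": "该合同依据《合同法》第五十二条应认定无效。"},
    ]
    print(compute_cjft(sample))  # -> {"score": <average ROUGE-L F1 over the samples>}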
@@ -1,18 +1,18 @@
from ..utils.function_utils import compute_rouge


# 法律咨询 (legal consultation)
def compute_flzx(data_dict):
    """
    Compute the ROUGE-L score between the prediction and the reference.
    """
    references, predictions = [], []
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        predictions.append(prediction)
        references.append(answer)

    # average the ROUGE-L F1 scores over all examples
    rouge_scores = compute_rouge(predictions, references)
    rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores]
    average_rouge_l = sum(rouge_ls) / len(rouge_ls)
    return {"score": average_rouge_l}
@@ -1,19 +1,19 @@
from ..utils.function_utils import compute_rouge


# 法条记忆问答 (statute recall question answering)
def compute_ftcs(data_dict):
    """
    Compute the ROUGE-L score between the prediction and the reference.
    """
    references, predictions = [], []
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        answer = answer.replace("答案:", "")
        predictions.append(prediction)
        references.append(answer)

    # average the ROUGE-L F1 scores over all examples
    rouge_scores = compute_rouge(predictions, references)
    rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores]
    average_rouge_l = sum(rouge_ls) / len(rouge_ls)
    return {"score": average_rouge_l}
@@ -1,36 +1,36 @@
from ..utils.function_utils import multi_choice_judge

"""
multi-choice single-label selection
metric: accuracy
争议焦点 (dispute focus): identify the dispute focus involved in the case
"""

def compute_jdzy(data_dict):
    """
    Compute the accuracy.
    The dataset has 16 possible answers for each question, stored in option_list.
    A prediction is correct if
    1. The correct answer appears in the prediction, and
    2. Options other than the answer do not appear in the prediction.
    """

    score_list, abstentions = [], 0
    option_list = ["诉讼主体", "租金情况", "利息", "本金争议", "责任认定", "责任划分", "损失认定及处理",
                   "原审判决是否适当", "合同效力", "财产分割", "责任承担", "鉴定结论采信问题", "诉讼时效", "违约", "合同解除", "肇事逃逸"]
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        if answer[7:-1] == "赔偿":
            # TODO: dataset imperfection
            continue
        assert answer.startswith("争议焦点类别:") and answer[7:-1] in option_list, \
            f"answer: {answer} \n question: {question}"

        answer_letter = answer[7:-1]
        judge = multi_choice_judge(prediction, option_list, answer_letter)
        score_list.append(judge["score"])
        abstentions += judge["abstention"]

    # compute the accuracy of score_list
    accuracy = sum(score_list) / len(score_list)
    return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)}
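For reference (not part of this commit): multi_choice_judge lives in ..utils.function_utils and is not shown in this diff. Based on the two correctness conditions stated in the docstring above, a plausible sketch of its contract is the following; this is an assumption, not the actual implementation:

    def multi_choice_judge_sketch(prediction, option_list, answer):
        # Assumed contract: score 1 only if the gold option appears in the prediction
        # and no other option does; abstention 1 if no option appears at all.
        hits = [opt for opt in option_list if opt in prediction]
        if not hits:
            return {"score": 0, "abstention": 1}
        return {"score": int(hits == [answer]), "abstention": 0}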
@@ -1,29 +1,29 @@
from ..utils.function_utils import multi_choice_judge

"""
Task: multi-choice selection
Metric: Accuracy
司法考试-案例分析 (judicial examination - case analysis)
"""
def compute_jec_ac(data_dict):
    """
    Compute the accuracy.
    The JEC dataset has 4 options for each question: A, B, C, D.
    A prediction is correct if
    1. The correct answer appears in the prediction, and
    2. Options other than the answer do not appear in the prediction.
    """
    score_list, abstentions = [], 0
    option_list = ["A", "B", "C", "D"]
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        assert answer.startswith("正确答案:") and answer[5] in option_list, f"answer[5]: {answer}, question: {question}"

        answer_letter = answer[5]
        judge = multi_choice_judge(prediction, option_list, answer_letter)
        score_list.append(judge["score"])
        abstentions += judge["abstention"]

    # compute the accuracy of score_list
    accuracy = sum(score_list) / len(score_list)
    return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)}
@@ -1,29 +1,29 @@
from ..utils.function_utils import multi_choice_judge

"""
Task: multi-choice selection
Metric: Accuracy
司法考试 (judicial examination)
"""
def compute_jec_kd(data_dict):
    """
    Compute the accuracy.
    The JEC_KD dataset has 4 options for each question: A, B, C, D.
    A prediction is correct if
    1. The correct answer appears in the prediction, and
    2. Options other than the answer do not appear in the prediction.
    """
    score_list, abstentions = [], 0
    option_list = ["A", "B", "C", "D"]
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        assert answer.startswith("正确答案:") and answer[5] in option_list, f"answer[5]: {answer}, question: {question}"

        answer_letter = answer[5]
        judge = multi_choice_judge(prediction, option_list, answer_letter)
        score_list.append(judge["score"])
        abstentions += judge["abstention"]

    # compute the accuracy of score_list
    accuracy = sum(score_list) / len(score_list)
    return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)}
@@ -1,43 +1,43 @@
import re

"""
number prediction
metric: accuracy
金额提取 (monetary amount extraction)
"""
def compute_jetq(data_dict):
    """
    Compute the accuracy.
    We extract the total amount of money involved in the crime from the prediction and compare it with the reference.
    The prediction is correct if the total amount given in the reference appears in the prediction.
    """
    score_list, abstentions = [], 0

    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        assert answer.startswith("上文涉及到的犯罪金额:"), f"answer: {answer}, question: {question}"
        assert answer.endswith("元。"), f"answer: {answer}, question: {question}"
        answer = answer.replace("上文涉及到的犯罪金额:", "")

        assert "千元" not in answer, f"answer: {answer}, question: {question}"
        assert "万" not in answer, f"answer: {answer}, question: {question}"

        # remove "元"
        answer = answer.replace("元。", "")
        answer = float(answer)

        prediction_digits = re.findall(r"\d+\.?\d*", prediction)
        prediction_digits = [float(digit) for digit in prediction_digits]

        if len(prediction_digits) == 0:
            abstentions += 1
        if answer in prediction_digits:
            score_list.append(1)
        else:
            score_list.append(0)

    # compute the accuracy of score_list
    accuracy = sum(score_list) / len(score_list)
    return {"score": accuracy, "abstention_rate": abstentions/len(data_dict)}
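For reference (not part of this commit): the scoring hinges on re.findall(r"\d+\.?\d*", prediction), which pulls every integer or decimal out of the prediction text. A small worked example with made-up values:

    import re

    prediction = "被告人共计骗取人民币12000.5元,另有300元未遂。"  # hypothetical model output
    digits = [float(d) for d in re.findall(r"\d+\.?\d*", prediction)]
    print(digits)             # [12000.5, 300.0]
    print(12000.5 in digits)  # True -> this example would be scored 1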
@@ -1,29 +1,29 @@
from ..utils.function_utils import multi_choice_judge

"""
Task: multi-choice selection
Metric: Accuracy
论辩挖掘 (argument mining)
"""
def compute_lblj(data_dict):
    """
    Compute the accuracy.
    The LBLJ dataset has 5 options for each question: A, B, C, D, E.
    A prediction is correct if
    1. The correct answer appears in the prediction, and
    2. Options other than the answer do not appear in the prediction.
    """
    score_list, abstentions = [], 0
    option_list = ["A", "B", "C", "D", "E"]
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        assert answer.startswith("[正确答案]") and answer[6] in option_list, f"answer[6]: {answer}, question: {question}"

        answer_letter = answer[6]
        judge = multi_choice_judge(prediction, option_list, answer_letter)
        score_list.append(judge["score"])
        abstentions += judge["abstention"]

    # compute the accuracy of score_list
    accuracy = sum(score_list) / len(score_list)
    return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)}
@@ -1,76 +1,76 @@
from ..utils.function_utils import compute_f1_two_sets

"""
task: legal accusation prediction
metric: F1 score
法律判决预测-罪名预测 (legal judgment prediction - charge prediction)
"""

option_list = ["侮辱", "违法发放贷款", "失火", "票据诈骗", "帮助犯罪分子逃避处罚", "重大责任事故", "对非国家工作人员行贿",
               "非法制造、销售非法制造的注册商标标识", "非法制造、买卖、运输、邮寄、储存枪支、弹药、爆炸物", "非法获取公民个人信息",
               "扰乱无线电通讯管理秩序", "非法持有、私藏枪支、弹药", "拒不执行判决、裁定", "虚开发票", "巨额财产来源不明",
               "组织、领导、参加黑社会性质组织", "非法获取国家秘密", "以危险方法危害公共安全", "非法持有毒品",
               "聚众扰乱公共场所秩序、交通秩序", "包庇毒品犯罪分子", "滥伐林木", "伪造公司、企业、事业单位、人民团体印章",
               "非法占用农用地", "走私废物", "串通投标", "非法采伐、毁坏国家重点保护植物", "冒充军人招摇撞骗", "玩忽职守",
               "重婚", "招收公务员、学生徇私舞弊", "组织、领导传销活动", "非法猎捕、杀害珍贵、濒危野生动物", "侵犯著作权",
               "非法种植毒品原植物", "伪造、变造、买卖武装部队公文、证件、印章", "倒卖文物", "伪造、变造居民身份证", "滥用职权",
               "诽谤", "猥亵儿童", "非法转让、倒卖土地使用权", "挪用公款", "污染环境", "出售、购买、运输假币", "敲诈勒索",
               "高利转贷", "故意伤害", "持有、使用假币", "单位受贿", "强奸", "引诱、容留、介绍卖淫", "虐待",
               "生产、销售伪劣农药、兽药、化肥、种子", "妨害公务", "容留他人吸毒", "拐骗儿童", "强制猥亵、侮辱妇女",
               "非法处置查封、扣押、冻结的财产", "骗取贷款、票据承兑、金融票证", "强迫他人吸毒", "非法拘禁",
               "非法携带枪支、弹药、管制刀具、危险物品危及公共安全", "绑架", "聚众斗殴", "破坏计算机信息系统",
               "制造、贩卖、传播淫秽物品", "虐待被监管人", "贷款诈骗", "赌博", "徇私舞弊不征、少征税款",
               "盗窃、抢夺枪支、弹药、爆炸物、危险物质", "故意杀人", "介绍贿赂", "提供侵入、非法控制计算机信息系统程序、工具",
               "编造、故意传播虚假恐怖信息", "妨害作证", "强迫卖淫", "走私、贩卖、运输、制造毒品", "伪证", "拐卖妇女、儿童",
               "过失损坏武器装备、军事设施、军事通信", "破坏广播电视设施、公用电信设施", "洗钱", "职务侵占", "倒卖车票、船票",
               "抢劫", "侵占", "掩饰、隐瞒犯罪所得、犯罪所得收益", "徇私舞弊不移交刑事案件", "引诱、教唆、欺骗他人吸毒", "遗弃",
               "生产、销售伪劣产品", "放火", "非法采矿", "对单位行贿", "盗窃、抢夺枪支、弹药、爆炸物", "破坏易燃易爆设备",
               "妨害信用卡管理", "制作、复制、出版、贩卖、传播淫秽物品牟利", "金融凭证诈骗", "私分国有资产",
               "走私国家禁止进出口的货物、物品", "假冒注册商标", "危险物品肇事", "走私普通货物、物品", "经济犯", "虚报注册资本",
               "盗掘古文化遗址、古墓葬", "传播淫秽物品", "窝藏、包庇", "拒不支付劳动报酬", "行贿", "开设赌场", "传授犯罪方法",
               "协助组织卖淫", "保险诈骗", "破坏生产经营", "破坏交通设施", "打击报复证人", "非法侵入住宅", "非国家工作人员受贿",
               "过失致人重伤", "伪造、变造金融票证", "窝藏、转移、隐瞒毒品、毒赃", "帮助毁灭、伪造证据", "走私珍贵动物、珍贵动物制品",
               "生产、销售假药", "逃税", "挪用特定款物", "聚众扰乱社会秩序", "组织、强迫、引诱、容留、介绍卖淫", "合同诈骗",
               "非法生产、销售间谍专用器材", "破坏交通工具", "传播性病", "强迫交易", "隐匿、故意销毁会计凭证、会计帐簿、财务会计报告",
               "非法组织卖血", "强迫劳动", "破坏电力设备", "销售假冒注册商标的商品", "收买被拐卖的妇女、儿童", "诬告陷害", "脱逃",
               "非法经营", "徇私枉法", "信用卡诈骗", "生产、销售不符合安全标准的食品", "非法行医", "伪造货币", "动植物检疫徇私舞弊",
               "单位行贿", "破坏监管秩序", "盗窃", "盗伐林木", "重大劳动安全事故", "非法吸收公众存款",
               "非法制造、出售非法制造的发票", "非法狩猎", "组织卖淫", "非法买卖、运输、携带、持有毒品原植物种子、幼苗", "挪用资金",
               "诈骗", "伪造、变造、买卖国家机关公文、证件、印章", "持有伪造的发票", "贪污", "非法生产、买卖警用装备",
               "投放危险物质", "伪造、倒卖伪造的有价票证", "集资诈骗", "抢夺", "生产、销售有毒、有害食品", "非法捕捞水产品",
               "过失致人死亡", "非法买卖制毒物品", "虚开增值税专用发票、用于骗取出口退税、抵扣税款发票", "寻衅滋事", "危险驾驶",
               "故意毁坏财物", "招摇撞骗", "盗窃、侮辱尸体", "走私武器、弹药",
               "非法收购、运输、加工、出售国家重点保护植物、国家重点保护植物制品", "非法出售发票", "劫持船只、汽车",
               "受贿", "聚众哄抢", "交通肇事"]


def compute_ljp_accusation(data_dict):
    """
    Compute the F1-score.
    The LJP_Accusation dataset defines a set of 189 different accusation types.
    A question may involve one or more accusation types.
    Given the list of accusation types from the ground truth and the list extracted from the prediction,
    we compute the F1-score between these two lists.
    """
    score_list, abstentions = [], 0

    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]

        assert answer.startswith("罪名:"), f"answer: {answer} \n question: {question}"
        answer = answer.replace("罪名:", "")
        answers = answer.split(";")

        prediction_list = []
        for option in option_list:
            if option in prediction:
                prediction_list.append(option)

        if len(prediction_list) == 0:
            abstentions += 1
        gt_set = set(answers)
        pred_set = set(prediction_list)
        score = compute_f1_two_sets(gt_set, pred_set)
        score_list.append(score)

    f1_score_average = sum(score_list) / len(score_list)
    return {"score": f1_score_average, "abstention_rate": abstentions/len(data_dict)}
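For reference (not part of this commit): compute_f1_two_sets comes from ..utils.function_utils and is not shown here. Under the usual set-based definition (an assumption), it reduces to something like:

    def f1_two_sets_sketch(gt_set, pred_set):
        # Assumed set F1: precision = |G ∩ P| / |P|, recall = |G ∩ P| / |G|.
        overlap = len(gt_set & pred_set)
        if overlap == 0:
            return 0.0
        precision = overlap / len(pred_set)
        recall = overlap / len(gt_set)
        return 2 * precision * recall / (precision + recall)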
@@ -1,70 +1,70 @@
import re
import cn2an

"""
task: law article prediction
metric: F1 score
法律判决预测-法条预测 (legal judgment prediction - article prediction)
"""
def replace_match(match):
    return match.group(1)

def compute_ljp_article(data_dict):
    """
    Compute the F1-score.
    A reference contains a list of articles of the Criminal Law of the People's Republic of China.
    We compute the F1-score between the prediction and the reference.
    """

    score_list, abstentions = [], 0

    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        assert answer.startswith("法条:刑法第"), f"answer: {answer}"
        assert answer.endswith("条"), f"answer: {answer}"

        answer = answer.replace("法条:刑法第", "")
        answer = answer.replace("条", "")

        answer_law_indices = answer.split("、")
        answer_law_index_digit_list = []
        for answer_law_index in answer_law_indices:
            assert answer_law_index.isdigit(), f"answer_law_index: {answer_law_index}"
            answer_law_index_digit = int(answer_law_index)
            assert answer_law_index_digit <= 490, "刑法总共只有490条"
            answer_law_index_digit_list.append(answer_law_index_digit)

        prediction_law_chunks = prediction.split("、")
        prediction_law_index_digit_list = []

        for prediction_law_chunk in prediction_law_chunks:
            prediction_law_chunk = prediction_law_chunk.replace("万元", "元")

            # delete phrases that start with "第" and end with "款"; they are not part of the answer
            prediction_law_chunk = re.sub(r'第(.*?)款', "", prediction_law_chunk)
            # keep only the body of phrases that start with "第" and end with "条", otherwise cn2an may fail to convert
            prediction_law_chunk = re.sub(r'第(.*?)条', replace_match, prediction_law_chunk)
            prediction_law_chunk = cn2an.transform(prediction_law_chunk, "cn2an")
            # find digits in prediction_law_chunk
            prediction_law_section_numbers = re.findall(r"\d+", prediction_law_chunk)
            if len(prediction_law_section_numbers) == 0:
                continue
            if len(prediction_law_section_numbers) != 1:
                # in this case, we only take the first number and reject the others
                pass

            prediction_law_index_digit = int(prediction_law_section_numbers[0])
            prediction_law_index_digit_list.append(prediction_law_index_digit)

        gt_set = set(answer_law_index_digit_list)
        pred_set = set(prediction_law_index_digit_list)
        if len(pred_set) == 0:
            abstentions += 1
        precision = len(gt_set.intersection(pred_set)) / len(pred_set) if len(pred_set) != 0 else 0
        recall = len(gt_set.intersection(pred_set)) / len(gt_set) if len(gt_set) != 0 else 0
        f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0
        score_list.append(f1_score)

    # average the per-example F1 scores
    average_f1 = sum(score_list) / len(score_list)
    return {'score': average_f1, 'abstention_rate': abstentions/len(data_dict)}
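For reference (not part of this commit): per chunk, the code first drops any "第…款" clause, keeps only the body of "第…条", and then lets cn2an.transform(text, "cn2an") rewrite Chinese numerals as Arabic digits before the final \d+ search. A rough trace on a made-up chunk; the commented outputs are what the pipeline is expected to produce:

    import re
    import cn2an

    chunk = "刑法第二百六十四条"                                # one "、"-separated chunk (invented)
    chunk = re.sub(r'第(.*?)款', "", chunk)                     # no "款" clause here, unchanged
    chunk = re.sub(r'第(.*?)条', lambda m: m.group(1), chunk)   # -> "刑法二百六十四"
    chunk = cn2an.transform(chunk, "cn2an")                     # expected -> "刑法264"
    print(re.findall(r"\d+", chunk))                            # expected -> ['264']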
@@ -1,49 +1,49 @@
import math
import cn2an
import re

# 法律判决预测-刑期预测 (legal judgment prediction - prison term prediction)
def compute_ljp_imprison(data_dict):
    score_list, abstentions = [], 0

    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        # get the answer digit, which is the number between "刑期:" and "个月"
        if "死刑" in answer or "无期" in answer:
            # TODO: data imperfection
            continue

        assert answer.startswith("刑期:") and answer.endswith("个月"), f"answer: {answer}, question: {question}"
        answer = answer.replace("刑期:", "")
        answer = answer.replace("个月", "")
        answer_digit = int(answer)
        prediction = cn2an.transform(prediction, "cn2an")

        # use regular expressions to extract the digits from the prediction; only consider digits before "个月" or "月"
        prediction_digit_month_list = re.findall(r"\d+个月", prediction)
        prediction_digit_month_list = [int(digit.replace("个月", "")) for digit in prediction_digit_month_list]
        prediction_digit_month_list2 = re.findall(r"\d+月", prediction)
        prediction_digit_month_list2 = [int(digit.replace("月", "")) for digit in prediction_digit_month_list2]
        prediction_digit_month_list.extend(prediction_digit_month_list2)
        # catch the digits before "年"
        prediction_digit_year_list = re.findall(r"\d+年", prediction)
        prediction_digit_year_list = [int(digit.replace("年", "")) for digit in prediction_digit_year_list]

        if len(prediction_digit_month_list) > 0:
            prediction_digit_month = int(prediction_digit_month_list[0])
        elif len(prediction_digit_year_list) > 0:
            prediction_digit_month = int(prediction_digit_year_list[0]) * 12
        else:
            abstentions += 1
            prediction_digit_month = -1

        if prediction_digit_month != -1:
            score_list.append(abs(math.log(answer_digit + 1) - math.log(prediction_digit_month + 1)))
        else:
            score_list.append(math.log(216))

    # compute the average of score_list (log distance)
    log_distance = sum(score_list) / len(score_list)
    # normalize the score to between 0 and 1
    log_distance = (math.log(216) - log_distance) / math.log(216)
    return {"score": log_distance, "abstention_rate": abstentions/len(data_dict)}
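For reference (not part of this commit): the imprisonment metric is a normalized log distance. Each example contributes |ln(a + 1) - ln(p + 1)| in log-month space (or ln 216 when the model abstains, with 216 months apparently acting as the cap), and the final score maps the average distance d onto [0, 1] via (ln 216 - d) / ln 216. A quick numeric check with made-up values:

    import math

    answer_months, predicted_months = 24, 36
    d = abs(math.log(answer_months + 1) - math.log(predicted_months + 1))
    print(round(d, 3))                                    # 0.392
    print(round((math.log(216) - d) / math.log(216), 3))  # 0.927, the score if this were the only example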
@@ -1,64 +1,64 @@
from ..utils.function_utils import compute_f1_two_sets
from ..utils.rc_f1 import CJRCEvaluator


"""
task: event detection
metric: F1 score
事件检测 (event detection)
"""
option_list = ["支付/给付", "欺骗", "搜查/扣押", "要求/请求", "卖出", "买入", "获利", "拘捕", "鉴定", "同意/接受", "供述", "联络", "帮助/救助", "租用/借用", "受伤", "伪造", "卖淫", "伤害人身", "赔偿", "归还/偿还"]

def compute_sjjc(data_dict):
    """
    Compute the F1-score.
    The sjjc task covers 20 event types.
    A question may involve one or more event types.
    Given the list of event types from the ground truth and the list extracted from the prediction,
    we compute the F1-score between these two lists.
    """
    score_list, abstentions = [], 0

    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]

        answers = answer.split(";")

        prediction_list = []
        for option in option_list:
            if option in prediction:
                prediction_list.append(option)

        if len(prediction_list) == 0:
            abstentions += 1
        gt_set = set(answers)
        pred_set = set(prediction_list)
        score = compute_f1_two_sets(gt_set, pred_set)
        score_list.append(score)

    f1_score_average = sum(score_list) / len(score_list)
    return {"score": f1_score_average, "abstention_rate": abstentions/len(data_dict)}

"""
task: trigger word extraction
metric: F1 score
触发词抽取 (trigger word extraction)
"""
def compute_cfcy(data_dict):

    scores = 0

    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]

        answers = answer.split(";")
        predictions = prediction.split(";")
        intersected = [CJRCEvaluator.compute_f1(r, h) for r, h in zip(answers, predictions)]

        prec = sum(intersected) / len(predictions) if len(predictions) > 0 else 0
        rec = sum(intersected) / len(answers) if len(answers) > 0 else 0
        # print(prec, rec, intersected)
        scores += 2 * prec * rec / (prec + rec + 1e-10)

    f1_score_average = scores / len(data_dict)
    return {"score": f1_score_average}
@@ -1,42 +1,42 @@
"""
task: multiple choice classification
metric: F1 score
婚姻文本分类 (marriage-related text classification)
"""

def compute_wbfl(data_dict):
    """
    A reference (R) contains a list of options, each option is from the option_list.
    We extract the options appearing in the prediction and convert them into a set (P).
    We compute the F1 score between the prediction (P) and the reference (R).
    """

    score_list, abstentions = [], 0
    option_list = ["婚后有子女", "限制行为能力子女抚养", "有夫妻共同财产", "支付抚养费", "不动产分割", "婚后分局",
                   "二次起诉离婚", "按月给付抚养费", "准予离婚", "有夫妻共同债务", "婚前个人财产", "法定离婚", "不履行家庭义务",
                   "存在非婚生子", "适当帮助", "不履行离婚协议", "损害赔偿", "感情不和分居满二年", "子女随非抚养权人生活", "婚后个人财产"]
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        assert answer.startswith("类别:") and answer.endswith("。"), f"answer: {answer}, question: {question}"

        gt_list = (answer[3:-1].split("、"))
        for gt in gt_list:
            assert gt in option_list, f"gt: {gt}, question: {question}"
        gt_set = set(gt_list)

        prediction_list = []
        for option in option_list:
            if option in prediction:
                prediction_list.append(option)
        if len(prediction_list) == 0:
            abstentions += 1
        predict_set = set(prediction_list)
        precision = len(gt_set.intersection(predict_set)) / len(predict_set) if len(predict_set) != 0 else 0
        recall = len(gt_set.intersection(predict_set)) / len(gt_set) if len(gt_set) != 0 else 0
        f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0
        score_list.append(f1_score)

    # average the per-example F1 scores
    final_f1_score = sum(score_list) / len(score_list)
    return {'score': final_f1_score, 'abstention_rate': abstentions / len(data_dict)}
@@ -1,50 +1,50 @@
import re
import os
import subprocess

"""
Task: legal document grammar correction
Metric: F0.5 score
文书校对 (legal document proofreading)
"""
def compute_wsjd(data_dict):
    origins, references, predictions = [], [], []
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        if isinstance(question, list):
            question = question[0]['prompt']
        start = question.index('句子:\n') + 4
        origins.append(re.sub(r'\n|\t', '', question[start:].split('\n')[0]))
        # truncate predictions that run more than 5 characters longer than the reference
        prediction = re.sub(r'\n|\t', '', prediction)
        if len(prediction) - len(answer) > 5:
            prediction = prediction[:len(answer) + 5]
        if len(prediction) == 0:
            prediction = "无内容"
        predictions.append(prediction)
        references.append(re.sub(r'\n|\t', '', answer))

    # generate input files for ChERRANT
    preds = [f'{i} \t {origin} \t {prediction} \n' for i, (origin, prediction) in enumerate(zip(origins, predictions))]
    golds = [f'{i} \t {origin} \t {reference} \n' for i, (origin, reference) in enumerate(zip(origins, references))]

    now_path = os.path.abspath(os.getcwd())
    utils_path = os.path.abspath(os.path.join(__file__, '..', '..', 'utils'))
    uid = os.getuid()
    os.chdir(utils_path)
    with open(f'/tmp/tmp_pred_{uid}.para', 'w') as f:
        f.writelines(preds)
    with open(f'/tmp/tmp_gold_{uid}.para', 'w') as f:
        f.writelines(golds)
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
    os.system(f'python3 parallel_to_m2.py -f /tmp/tmp_pred_{uid}.para -o /tmp/tmp_pred_{uid}.para.m2 -g char')
    os.system(f'python3 parallel_to_m2.py -f /tmp/tmp_gold_{uid}.para -o /tmp/tmp_gold_{uid}.para.m2 -g char')
    output = subprocess.check_output(f"python3 compare_m2_for_evaluation.py -hyp /tmp/tmp_pred_{uid}.para.m2 -ref /tmp/tmp_gold_{uid}.para.m2", shell=True)
    score = float(output.decode().split('\t')[-1].split('\n')[0])
    # remove the temporary prediction and gold files
    os.remove(f'/tmp/tmp_pred_{uid}.para')
    os.remove(f'/tmp/tmp_gold_{uid}.para')
    os.remove(f'/tmp/tmp_pred_{uid}.para.m2')
    os.remove(f'/tmp/tmp_gold_{uid}.para.m2')
    os.chdir(now_path)
    return {"score": score}
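For reference (not part of this commit): the ChERRANT helpers (parallel_to_m2.py, compare_m2_for_evaluation.py) are driven through os.system and a shell-mode subprocess.check_output call. An equivalent sketch using argument lists instead of shell strings, shown here as an alternative rather than what the code above does, sidesteps shell quoting:

    import subprocess

    uid = 1000  # placeholder; the function above uses os.getuid()
    subprocess.run(["python3", "parallel_to_m2.py", "-f", f"/tmp/tmp_pred_{uid}.para",
                    "-o", f"/tmp/tmp_pred_{uid}.para.m2", "-g", "char"], check=True)
    # the gold .para file would need the same conversion before comparing
    out = subprocess.run(["python3", "compare_m2_for_evaluation.py",
                          "-hyp", f"/tmp/tmp_pred_{uid}.para.m2",
                          "-ref", f"/tmp/tmp_gold_{uid}.para.m2"],
                         capture_output=True, text=True, check=True).stdout
    score = float(out.split('\t')[-1].split('\n')[0])  # same parse as above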
@@ -1,17 +1,17 @@
from ..utils.comprehension_scores import compute_ie_f1


"""
task: information extraction
metric: F1 score
信息抽取 (information extraction)
"""
def compute_xxcq(data_dict):
    references, predictions = [], []
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        predictions.append(prediction)
        references.append(answer)

    return compute_ie_f1(predictions, references, {"犯罪嫌疑人", "受害人", "被盗货币", "物品价值", "盗窃获利",
                                                   "被盗物品", "作案工具", "时间", "地点", "组织机构"})
@@ -1,17 +1,17 @@
from ..utils.comprehension_scores import compute_rc_f1

"""
Task: machine reading comprehension
Metric: F1 score
法律阅读理解 (legal reading comprehension)
"""
def compute_ydlj(data_dict):
    references, predictions = [], []
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        answer = answer.replace("回答:", "")
        predictions.append(prediction)
        references.append(answer)

    f1_score = compute_rc_f1(predictions, references)
    return f1_score
@@ -1,18 +1,18 @@
from ..utils.function_utils import compute_rouge

# 舆情摘要 (public opinion summarization)
def compute_yqzy(data_dict):
    """
    Compute the ROUGE-L score between the prediction and the reference.
    """
    references, predictions = [], []
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        predictions.append(prediction)
        references.append(answer)

    # average the ROUGE-L F1 scores over all examples
    rouge_scores = compute_rouge(predictions, references)
    rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores]
    average_rouge_l = sum(rouge_ls) / len(rouge_ls)
    return {"score": average_rouge_l}
@ -1,27 +1,27 @@
from ..utils.function_utils import multi_choice_judge

"""
task: multiple choice classification
metric: accuracy
咨询分类
"""

def compute_zxfl(data_dict):
    """
    A reference (R) contains a list of options, each option is from the option_list.
    We will extract the options appearing in the prediction and convert them into a set (P).
    We compute the accuracy between the prediction (P) and the reference (R).
    """


    score_list, abstentions = [], 0
    option_list = ['婚姻家庭', '劳动纠纷', '交通事故', '债权债务', '刑事辩护', '合同纠纷', '房产纠纷', '侵权', '公司法', '医疗纠纷', '拆迁安置', '行政诉讼', '建设工程', '知识产权', '综合咨询', '人身损害', '涉外法律', '海事海商', '消费权益', '抵押担保']
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        judge = multi_choice_judge(prediction, option_list, answer)
        score_list.append(judge["score"])
        abstentions += judge["abstention"]

    # compute the accuracy of score_list
    final_accuracy_score = sum(score_list) / len(score_list)
    return {'score': final_accuracy_score, 'abstention_rate': abstentions / len(data_dict)}
from ..utils.function_utils import multi_choice_judge

"""
task: multiple choice classification
metric: accuracy
咨询分类
"""

def compute_zxfl(data_dict):
    """
    A reference (R) contains a list of options, each option is from the option_list.
    We will extract the options appearing in the prediction and convert them into a set (P).
    We compute the accuracy between the prediction (P) and the reference (R).
    """


    score_list, abstentions = [], 0
    option_list = ['婚姻家庭', '劳动纠纷', '交通事故', '债权债务', '刑事辩护', '合同纠纷', '房产纠纷', '侵权', '公司法', '医疗纠纷', '拆迁安置', '行政诉讼', '建设工程', '知识产权', '综合咨询', '人身损害', '涉外法律', '海事海商', '消费权益', '抵押担保']
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        judge = multi_choice_judge(prediction, option_list, answer)
        score_list.append(judge["score"])
        abstentions += judge["abstention"]

    # compute the accuracy of score_list
    final_accuracy_score = sum(score_list) / len(score_list)
    return {'score': final_accuracy_score, 'abstention_rate': abstentions / len(data_dict)}
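The scoring contract that compute_zxfl relies on is easiest to see with a small stand-in. The function below is hypothetical (the real multi_choice_judge lives in ..utils.function_utils and may differ in detail); it only illustrates the assumed behaviour: score 1 when exactly one option appears in the prediction and it equals the reference, otherwise score 0, with an abstention counted when no single option can be extracted.

def simple_multi_choice_judge(prediction, option_list, answer):
    # Hypothetical stand-in for multi_choice_judge, for illustration only.
    hits = [opt for opt in option_list if opt in prediction]
    if len(hits) != 1:
        return {"score": 0, "abstention": 1}
    return {"score": int(hits[0] == answer), "abstention": 0}

options = ['婚姻家庭', '劳动纠纷', '交通事故']
print(simple_multi_choice_judge('这个问题属于劳动纠纷。', options, '劳动纠纷'))  # {'score': 1, 'abstention': 0}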
@ -1,456 +1,456 @@
|
||||
### Copy from https://github.com/iqiyi/FASPell ###
|
||||
|
||||
"""
|
||||
Requirements:
|
||||
- java (required only if tree edit distance is used)
|
||||
- numpy
|
||||
"""
|
||||
import numpy as np
|
||||
from subprocess import Popen, PIPE, STDOUT
|
||||
import os
|
||||
import argparse
|
||||
|
||||
IDCS = {'\u2ff0': 2,  # the 12 ideographic description characters and the number of child nodes each takes
|
||||
'\u2ff1': 2,
|
||||
'\u2ff2': 3,
|
||||
'\u2ff3': 3,
|
||||
'\u2ff4': 2,
|
||||
'\u2ff5': 2,
|
||||
'\u2ff6': 2,
|
||||
'\u2ff7': 2,
|
||||
'\u2ff8': 2,
|
||||
'\u2ff9': 2,
|
||||
'\u2ffa': 2,
|
||||
'\u2ffb': 2, }
|
||||
|
||||
PINYIN = {'ā': ['a', 1], 'á': ['a', 2], 'ǎ': ['a', 3], 'à': ['a', 4],
|
||||
'ē': ['e', 1], 'é': ['e', 2], 'ě': ['e', 3], 'è': ['e', 4],
|
||||
'ī': ['i', 1], 'í': ['i', 2], 'ǐ': ['i', 3], 'ì': ['i', 4],
|
||||
'ō': ['o', 1], 'ó': ['o', 2], 'ǒ': ['o', 3], 'ò': ['o', 4],
|
||||
'ū': ['u', 1], 'ú': ['u', 2], 'ǔ': ['u', 3], 'ù': ['u', 4],
|
||||
'ǖ': ['ü', 1], 'ǘ': ['ü', 2], 'ǚ': ['ü', 3], 'ǜ': ['ü', 4],
|
||||
'': ['m', 2], 'ń': ['n', 2], 'ň': ['n', 3], 'ǹ': ['n', 4],
|
||||
}
|
||||
|
||||
# APTED_JAR_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'apted.jar')
|
||||
APTED_JAR_PATH = 'apted.jar'
|
||||
|
||||
|
||||
def tree_edit_distance(tree_a, tree_b):
|
||||
"""
|
||||
We use APTED algorithm proposed by M. Pawlik and N. Augsten
|
||||
github link: https://github.com/DatabaseGroup/apted
|
||||
"""
|
||||
p = Popen(['java', '-jar', APTED_JAR_PATH, '-t', tree_a, tree_b], stdout=PIPE, stderr=STDOUT)
|
||||
|
||||
res = [line for line in p.stdout]
|
||||
res = res[0]
|
||||
res = res.strip()
|
||||
res = float(res)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def edit_distance(string_a, string_b, name='Levenshtein'):
|
||||
"""
|
||||
>>> edit_distance('abcde', 'avbcude')
|
||||
2
|
||||
>>> edit_distance(['至', '刂'], ['亻', '至', '刂'])
|
||||
1
|
||||
>>> edit_distance('fang', 'qwe')
|
||||
4
|
||||
>>> edit_distance('fang', 'hen')
|
||||
3
|
||||
"""
|
||||
size_x = len(string_a) + 1
|
||||
size_y = len(string_b) + 1
|
||||
matrix = np.zeros((size_x, size_y), dtype=int)
|
||||
for x in range(size_x):
|
||||
matrix[x, 0] = x
|
||||
for y in range(size_y):
|
||||
matrix[0, y] = y
|
||||
|
||||
for x in range(1, size_x):
|
||||
for y in range(1, size_y):
|
||||
if string_a[x - 1] == string_b[y - 1]:
|
||||
matrix[x, y] = min(
|
||||
matrix[x - 1, y] + 1,
|
||||
matrix[x - 1, y - 1],
|
||||
matrix[x, y - 1] + 1
|
||||
)
|
||||
else:
|
||||
if name == 'Levenshtein':
|
||||
matrix[x, y] = min(
|
||||
matrix[x - 1, y] + 1,
|
||||
matrix[x - 1, y - 1] + 1,
|
||||
matrix[x, y - 1] + 1
|
||||
)
|
||||
else: # Canonical
|
||||
matrix[x, y] = min(
|
||||
matrix[x - 1, y] + 1,
|
||||
matrix[x - 1, y - 1] + 2,
|
||||
matrix[x, y - 1] + 1
|
||||
)
|
||||
|
||||
return matrix[size_x - 1, size_y - 1]
|
||||
|
||||
|
||||
class CharFuncs(object):
|
||||
def __init__(self, char_meta_fname):
|
||||
self.data = self.load_char_meta(char_meta_fname)
|
||||
self.char_dict = dict([(c, 0) for c in self.data])
|
||||
|
||||
self.safe = {'\u2ff0': 'A',
|
||||
# to work around a bug where, in Windows CMD, the characters ⿻ and ⿵ are encoded identically.
|
||||
'\u2ff1': 'B',
|
||||
'\u2ff2': 'C',
|
||||
'\u2ff3': 'D',
|
||||
'\u2ff4': 'E',
|
||||
'\u2ff5': 'F',
|
||||
'\u2ff6': 'G',
|
||||
'\u2ff7': 'H',
|
||||
'\u2ff8': 'I',
|
||||
'\u2ff9': 'J',
|
||||
'\u2ffa': 'L',
|
||||
'\u2ffb': 'M', }
|
||||
|
||||
@staticmethod
|
||||
def load_char_meta(fname):
|
||||
data = {}
|
||||
f = open(fname, 'r', encoding='utf-8')
|
||||
for line in f:
|
||||
items = line.strip().split('\t')
|
||||
code_point = items[0]
|
||||
char = items[1]
|
||||
pronunciation = items[2]
|
||||
decompositions = items[3:]
|
||||
assert char not in data
|
||||
data[char] = {"code_point": code_point, "pronunciation": pronunciation, "decompositions": decompositions}
|
||||
return data
|
||||
|
||||
def shape_distance(self, char1, char2, safe=True, as_tree=False):
|
||||
"""
|
||||
>>> c = CharFuncs('data/char_meta.txt')
|
||||
>>> c.shape_distance('田', '由')
|
||||
1
|
||||
>>> c.shape_distance('牛', '午')
|
||||
1
|
||||
"""
|
||||
assert char1 in self.data
|
||||
assert char2 in self.data
|
||||
|
||||
def safe_encode(decomp):
|
||||
tree = ''
|
||||
for c in string_to_tree(decomp):
|
||||
if c not in self.safe:
|
||||
tree += c
|
||||
else:
|
||||
tree += self.safe[c]
|
||||
return tree
|
||||
|
||||
def safe_encode_string(decomp):
|
||||
tree = ''
|
||||
for c in decomp:
|
||||
if c not in self.safe:
|
||||
tree += c
|
||||
else:
|
||||
tree += self.safe[c]
|
||||
return tree
|
||||
|
||||
decomps_1 = self.data[char1]["decompositions"]
|
||||
decomps_2 = self.data[char2]["decompositions"]
|
||||
|
||||
distance = 1e5
|
||||
if as_tree:
|
||||
for decomp1 in decomps_1:
|
||||
for decomp2 in decomps_2:
|
||||
if not safe:
|
||||
ted = tree_edit_distance(string_to_tree(decomp1), string_to_tree(decomp2))
|
||||
else:
|
||||
ted = tree_edit_distance(safe_encode(decomp1), safe_encode(decomp2))
|
||||
distance = min(distance, ted)
|
||||
else:
|
||||
for decomp1 in decomps_1:
|
||||
for decomp2 in decomps_2:
|
||||
if not safe:
|
||||
ed = edit_distance(decomp1, decomp2)
|
||||
else:
|
||||
ed = edit_distance(safe_encode_string(decomp1), safe_encode_string(decomp2))
|
||||
distance = min(distance, ed)
|
||||
|
||||
return distance
|
||||
|
||||
def pronunciation_distance(self, char1, char2):
|
||||
"""
|
||||
>>> c = CharFuncs('data/char_meta.txt')
|
||||
>>> c.pronunciation_distance('田', '由')
|
||||
3.4
|
||||
>>> c.pronunciation_distance('牛', '午')
|
||||
2.6
|
||||
"""
|
||||
assert char1 in self.data
|
||||
assert char2 in self.data
|
||||
pronunciations1 = self.data[char1]["pronunciation"]
|
||||
pronunciations2 = self.data[char2]["pronunciation"]
|
||||
|
||||
if pronunciations1 == 'null' or pronunciations2 == 'null':
|
||||
return 0.0
|
||||
else:
|
||||
|
||||
pronunciations1 = pronunciations1.split(';') # separate by lan
|
||||
pronunciations2 = pronunciations2.split(';') # separate by lan
|
||||
|
||||
distance = 0.0
|
||||
count = 0
|
||||
for pron_lan1, pron_lan2 in zip(pronunciations1, pronunciations2):
|
||||
if (pron_lan1 == 'null') or (pron_lan2 == 'null'):
|
||||
pass
|
||||
else:
|
||||
distance_lan = 1e5
|
||||
for p1 in pron_lan1.split(','):
|
||||
for p2 in pron_lan2.split(','):
|
||||
distance_lan = min(distance_lan, edit_distance(p1, p2))
|
||||
distance += distance_lan
|
||||
count += 1
|
||||
|
||||
return distance / count
|
||||
|
||||
@staticmethod
|
||||
def load_dict(fname):
|
||||
data = {}
|
||||
f = open(fname, 'r', encoding='utf-8')
|
||||
for line in f:
|
||||
char, freq = line.strip().split('\t')
|
||||
assert char not in data
|
||||
data[char] = freq
|
||||
|
||||
return data
|
||||
|
||||
def similarity(self, char1, char2, weights=(0.8, 0.2, 0.0), as_tree=False):
|
||||
"""
|
||||
this function returns weighted similarity. When used in FASPell, each weight can only be 0 or 1.
|
||||
"""
|
||||
|
||||
# assert char1 in self.char_dict
|
||||
# assert char2 in self.char_dict
|
||||
shape_w, sound_w, freq_w = weights
|
||||
|
||||
if char1 in self.char_dict and char2 in self.char_dict:
|
||||
|
||||
shape_sim = self.shape_similarity(char1, char2, as_tree=as_tree)
|
||||
sound_sim = self.pronunciation_similarity(char1, char2)
|
||||
freq_sim = 1.0 - self.char_dict[char2] / len(self.char_dict)
|
||||
|
||||
return shape_sim * shape_w + sound_sim * sound_w + freq_sim * freq_w
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def shape_similarity(self, char1, char2, safe=True, as_tree=False):
|
||||
"""
|
||||
>>> c = CharFuncs('data/char_meta.txt')
|
||||
>>> c.shape_similarity('牛', '午')
|
||||
0.8571428571428572
|
||||
>>> c.shape_similarity('田', '由')
|
||||
0.8888888888888888
|
||||
"""
|
||||
assert char1 in self.data
|
||||
assert char2 in self.data
|
||||
|
||||
def safe_encode(decomp):
|
||||
tree = ''
|
||||
for c in string_to_tree(decomp):
|
||||
if c not in self.safe:
|
||||
tree += c
|
||||
else:
|
||||
tree += self.safe[c]
|
||||
return tree
|
||||
|
||||
def safe_encode_string(decomp):
|
||||
tree = ''
|
||||
for c in decomp:
|
||||
if c not in self.safe:
|
||||
tree += c
|
||||
else:
|
||||
tree += self.safe[c]
|
||||
return tree
|
||||
|
||||
decomps_1 = self.data[char1]["decompositions"]
|
||||
decomps_2 = self.data[char2]["decompositions"]
|
||||
|
||||
similarity = 0.0
|
||||
if as_tree:
|
||||
for decomp1 in decomps_1:
|
||||
for decomp2 in decomps_2:
|
||||
if not safe:
|
||||
ted = tree_edit_distance(string_to_tree(decomp1), string_to_tree(decomp2))
|
||||
else:
|
||||
ted = tree_edit_distance(safe_encode(decomp1), safe_encode(decomp2))
|
||||
normalized_ted = 2 * ted / (len(decomp1) + len(decomp2) + ted)
|
||||
similarity = max(similarity, 1 - normalized_ted)
|
||||
else:
|
||||
for decomp1 in decomps_1:
|
||||
for decomp2 in decomps_2:
|
||||
if not safe:
|
||||
ed = edit_distance(decomp1, decomp2)
|
||||
else:
|
||||
ed = edit_distance(safe_encode_string(decomp1), safe_encode_string(decomp2))
|
||||
normalized_ed = ed / max(len(decomp1), len(decomp2))
|
||||
similarity = max(similarity, 1 - normalized_ed)
|
||||
|
||||
return similarity
|
||||
|
||||
def pronunciation_similarity(self, char1, char2):
|
||||
"""
|
||||
>>> c = CharFuncs('data/char_meta.txt')
|
||||
>>> c.pronunciation_similarity('牛', '午')
|
||||
0.27999999999999997
|
||||
>>> c.pronunciation_similarity('由', '田')
|
||||
0.09
|
||||
|
||||
"""
|
||||
assert char1 in self.data
|
||||
assert char2 in self.data
|
||||
pronunciations1 = self.data[char1]["pronunciation"]
|
||||
pronunciations2 = self.data[char2]["pronunciation"]
|
||||
|
||||
if pronunciations1 == 'null' or pronunciations2 == 'null':
|
||||
return 0.0
|
||||
else:
|
||||
|
||||
pronunciations1 = pronunciations1.split(';') # separate by lan
|
||||
pronunciations2 = pronunciations2.split(';') # separate by lan
|
||||
|
||||
similarity = 0.0
|
||||
count = 0
|
||||
for pron_lan1, pron_lan2 in zip(pronunciations1, pronunciations2):
|
||||
if (pron_lan1 == 'null') or (pron_lan2 == 'null'):
|
||||
pass
|
||||
else:
|
||||
similarity_lan = 0.0
|
||||
for p1 in pron_lan1.split(','):
|
||||
for p2 in pron_lan2.split(','):
|
||||
tmp_sim = 1 - edit_distance(p1, p2) / max(len(p1), len(p2))
|
||||
similarity_lan = max(similarity_lan, tmp_sim)
|
||||
similarity += similarity_lan
|
||||
count += 1
|
||||
|
||||
return similarity / count if count else 0
|
||||
|
||||
|
||||
def string_to_tree(string):
|
||||
"""
|
||||
This function converts ids string to a string that can be used as a tree input to APTED.
|
||||
Any Error raised by this function implies that the input string is invalid.
|
||||
>>> string_to_tree('⿱⿱⿰丿㇏⿰丿㇏⿱⿰丿㇏⿰丿㇏') # 炎
|
||||
'{⿱{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}}'
|
||||
>>> string_to_tree('⿱⿰丿㇏⿱一⿱⿻一丨一') # 全
|
||||
'{⿱{⿰{丿}{㇏}}{⿱{一}{⿱{⿻{一}{丨}}{一}}}}'
|
||||
>>> string_to_tree('⿱⿰丿㇏⿻⿱一⿱⿻一丨一丷') # 金
|
||||
'{⿱{⿰{丿}{㇏}}{⿻{⿱{一}{⿱{⿻{一}{丨}}{一}}}{丷}}}'
|
||||
>>> string_to_tree('⿻⿻⿻一丨一⿴⿱⿰丨𠃌一一') # 車
|
||||
'{⿻{⿻{⿻{一}{丨}}{一}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}'
|
||||
>>> string_to_tree('⿻⿻⿻一丨⿰丿㇏⿴⿱⿰丨𠃌一一') # 東
|
||||
'{⿻{⿻{⿻{一}{丨}}{⿰{丿}{㇏}}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}'
|
||||
>>> string_to_tree('丿') # 丿
|
||||
'{丿}'
|
||||
>>> string_to_tree('⿻') # ⿻
|
||||
'{⿻}'
|
||||
"""
|
||||
if string[0] in IDCS and len(string) != 1:
|
||||
bracket_stack = []
|
||||
tree = []
|
||||
|
||||
def add_brackets(num):
|
||||
if num == 2:
|
||||
bracket_stack.extend(['}', '{', '}'])
|
||||
else:
|
||||
bracket_stack.extend(['}', '{', '}', '{', '}'])
|
||||
tree.append('{')
|
||||
|
||||
global_just_put = '{'
|
||||
|
||||
for c in string:
|
||||
tree.append(c)
|
||||
if c in IDCS:
|
||||
assert global_just_put != '}'
|
||||
add_brackets(IDCS[c])
|
||||
global_just_put = '{'
|
||||
else:
|
||||
just_put = ''
|
||||
while just_put != '{' and bracket_stack:
|
||||
just_put = bracket_stack.pop(-1)
|
||||
tree.append(just_put)
|
||||
global_just_put = just_put
|
||||
|
||||
res = ''.join(tree)
|
||||
assert res[-1] == '}'
|
||||
else:
|
||||
assert len(string) == 1 or string == 'null'
|
||||
res = string[0]
|
||||
|
||||
return '{' + res + '}'
|
||||
|
||||
|
||||
def pinyin_map(standard_pinyin):
|
||||
"""
|
||||
>>> pinyin_map('xuě')
|
||||
'xue3'
|
||||
>>> pinyin_map('xue')
|
||||
'xue'
|
||||
>>> pinyin_map('lǜ')
|
||||
'lü4'
|
||||
>>> pinyin_map('fá')
|
||||
'fa2'
|
||||
"""
|
||||
tone = ''
|
||||
pinyin = ''
|
||||
|
||||
assert ' ' not in standard_pinyin
|
||||
for c in standard_pinyin:
|
||||
if c in PINYIN:
|
||||
pinyin += PINYIN[c][0]
|
||||
assert tone == ''
|
||||
tone = str(PINYIN[c][1])
|
||||
else:
|
||||
pinyin += c
|
||||
pinyin += tone
|
||||
return pinyin
|
||||
|
||||
|
||||
def parse_args():
|
||||
usage = '\n1. You can compute character similarity by:\n' \
|
||||
'python char_sim.py 午 牛 年 千\n' \
|
||||
'\n' \
|
||||
'2. You can use ted in computing character similarity by:\n' \
|
||||
'python char_sim.py 午 牛 年 千 -t\n' \
|
||||
'\n'
|
||||
parser = argparse.ArgumentParser(
|
||||
description='A script to compute Chinese character (Kanji) similarity', usage=usage)
|
||||
|
||||
parser.add_argument('multiargs', nargs='*', type=str, default=None,
|
||||
help='Chinese characters in question')
|
||||
parser.add_argument('--ted', '-t', action="store_true", default=False,
|
||||
help='True=to use tree edit distance (TED); '
|
||||
'False=to use string edit distance')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
c = CharFuncs('data/char_meta.txt')
|
||||
if not args.ted:
|
||||
for i, c1 in enumerate(args.multiargs):
|
||||
for c2 in args.multiargs[i:]:
|
||||
if c1 != c2:
|
||||
print(f'For character pair ({c1}, {c2}):')
|
||||
print(f' v-sim = {c.shape_similarity(c1, c2)}')
|
||||
print(f' p-sim = {c.pronunciation_similarity(c1, c2)}\n')
|
||||
else:
|
||||
for i, c1 in enumerate(args.multiargs):
|
||||
for c2 in args.multiargs[i:]:
|
||||
if c1 != c2:
|
||||
print(f'For character pair ({c1}, {c2}):')
|
||||
print(f'  v-sim = {c.shape_similarity(c1, c2, as_tree=True)}')
print(f'  p-sim = {c.pronunciation_similarity(c1, c2)}\n')
|
||||
### Copy from https://github.com/iqiyi/FASPell ###
|
||||
|
||||
"""
|
||||
Requirements:
|
||||
- java (required only if tree edit distance is used)
|
||||
- numpy
|
||||
"""
|
||||
import numpy as np
|
||||
from subprocess import Popen, PIPE, STDOUT
|
||||
import os
|
||||
import argparse
|
||||
|
||||
IDCS = {'\u2ff0': 2,  # the 12 ideographic description characters and the number of child nodes each takes
|
||||
'\u2ff1': 2,
|
||||
'\u2ff2': 3,
|
||||
'\u2ff3': 3,
|
||||
'\u2ff4': 2,
|
||||
'\u2ff5': 2,
|
||||
'\u2ff6': 2,
|
||||
'\u2ff7': 2,
|
||||
'\u2ff8': 2,
|
||||
'\u2ff9': 2,
|
||||
'\u2ffa': 2,
|
||||
'\u2ffb': 2, }
|
||||
|
||||
PINYIN = {'ā': ['a', 1], 'á': ['a', 2], 'ǎ': ['a', 3], 'à': ['a', 4],
|
||||
'ē': ['e', 1], 'é': ['e', 2], 'ě': ['e', 3], 'è': ['e', 4],
|
||||
'ī': ['i', 1], 'í': ['i', 2], 'ǐ': ['i', 3], 'ì': ['i', 4],
|
||||
'ō': ['o', 1], 'ó': ['o', 2], 'ǒ': ['o', 3], 'ò': ['o', 4],
|
||||
'ū': ['u', 1], 'ú': ['u', 2], 'ǔ': ['u', 3], 'ù': ['u', 4],
|
||||
'ǖ': ['ü', 1], 'ǘ': ['ü', 2], 'ǚ': ['ü', 3], 'ǜ': ['ü', 4],
|
||||
'': ['m', 2], 'ń': ['n', 2], 'ň': ['n', 3], 'ǹ': ['n', 4],
|
||||
}
|
||||
|
||||
# APTED_JAR_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'apted.jar')
|
||||
APTED_JAR_PATH = 'apted.jar'
|
||||
|
||||
|
||||
def tree_edit_distance(tree_a, tree_b):
|
||||
"""
|
||||
We use APTED algorithm proposed by M. Pawlik and N. Augsten
|
||||
github link: https://github.com/DatabaseGroup/apted
|
||||
"""
|
||||
p = Popen(['java', '-jar', APTED_JAR_PATH, '-t', tree_a, tree_b], stdout=PIPE, stderr=STDOUT)
|
||||
|
||||
res = [line for line in p.stdout]
|
||||
res = res[0]
|
||||
res = res.strip()
|
||||
res = float(res)
|
||||
|
||||
return res
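
# Note for the call above (descriptive comment, not in the FASPell original):
# apted.jar is resolved relative to the working directory (APTED_JAR_PATH) and
# Java must be on PATH; each tree argument is a bracket string such as
# '{⿰{丿}{㇏}}', i.e. the output of string_to_tree() defined below.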
|
||||
|
||||
|
||||
def edit_distance(string_a, string_b, name='Levenshtein'):
|
||||
"""
|
||||
>>> edit_distance('abcde', 'avbcude')
|
||||
2
|
||||
>>> edit_distance(['至', '刂'], ['亻', '至', '刂'])
|
||||
1
|
||||
>>> edit_distance('fang', 'qwe')
|
||||
4
|
||||
>>> edit_distance('fang', 'hen')
|
||||
3
|
||||
"""
|
||||
size_x = len(string_a) + 1
|
||||
size_y = len(string_b) + 1
|
||||
matrix = np.zeros((size_x, size_y), dtype=int)
|
||||
for x in range(size_x):
|
||||
matrix[x, 0] = x
|
||||
for y in range(size_y):
|
||||
matrix[0, y] = y
|
||||
|
||||
for x in range(1, size_x):
|
||||
for y in range(1, size_y):
|
||||
if string_a[x - 1] == string_b[y - 1]:
|
||||
matrix[x, y] = min(
|
||||
matrix[x - 1, y] + 1,
|
||||
matrix[x - 1, y - 1],
|
||||
matrix[x, y - 1] + 1
|
||||
)
|
||||
else:
|
||||
if name == 'Levenshtein':
|
||||
matrix[x, y] = min(
|
||||
matrix[x - 1, y] + 1,
|
||||
matrix[x - 1, y - 1] + 1,
|
||||
matrix[x, y - 1] + 1
|
||||
)
|
||||
else: # Canonical
|
||||
matrix[x, y] = min(
|
||||
matrix[x - 1, y] + 1,
|
||||
matrix[x - 1, y - 1] + 2,
|
||||
matrix[x, y - 1] + 1
|
||||
)
|
||||
|
||||
return matrix[size_x - 1, size_y - 1]
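
# Worked illustration of the two cost schemes above (added comment, not FASPell
# code): with name='Levenshtein' a substitution costs 1, so
# edit_distance('abc', 'abd') == 1; with the 'Canonical' scheme a substitution
# costs 2 (equivalent to a delete plus an insert), so the same pair has distance 2.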
|
||||
|
||||
|
||||
class CharFuncs(object):
|
||||
def __init__(self, char_meta_fname):
|
||||
self.data = self.load_char_meta(char_meta_fname)
|
||||
self.char_dict = dict([(c, 0) for c in self.data])
|
||||
|
||||
self.safe = {'\u2ff0': 'A',
|
||||
# to work around a bug where, in Windows CMD, the characters ⿻ and ⿵ are encoded identically.
|
||||
'\u2ff1': 'B',
|
||||
'\u2ff2': 'C',
|
||||
'\u2ff3': 'D',
|
||||
'\u2ff4': 'E',
|
||||
'\u2ff5': 'F',
|
||||
'\u2ff6': 'G',
|
||||
'\u2ff7': 'H',
|
||||
'\u2ff8': 'I',
|
||||
'\u2ff9': 'J',
|
||||
'\u2ffa': 'L',
|
||||
'\u2ffb': 'M', }
|
||||
|
||||
@staticmethod
|
||||
def load_char_meta(fname):
|
||||
data = {}
|
||||
f = open(fname, 'r', encoding='utf-8')
|
||||
for line in f:
|
||||
items = line.strip().split('\t')
|
||||
code_point = items[0]
|
||||
char = items[1]
|
||||
pronunciation = items[2]
|
||||
decompositions = items[3:]
|
||||
assert char not in data
|
||||
data[char] = {"code_point": code_point, "pronunciation": pronunciation, "decompositions": decompositions}
|
||||
return data
|
||||
|
||||
def shape_distance(self, char1, char2, safe=True, as_tree=False):
|
||||
"""
|
||||
>>> c = CharFuncs('data/char_meta.txt')
|
||||
>>> c.shape_distance('田', '由')
|
||||
1
|
||||
>>> c.shape_distance('牛', '午')
|
||||
1
|
||||
"""
|
||||
assert char1 in self.data
|
||||
assert char2 in self.data
|
||||
|
||||
def safe_encode(decomp):
|
||||
tree = ''
|
||||
for c in string_to_tree(decomp):
|
||||
if c not in self.safe:
|
||||
tree += c
|
||||
else:
|
||||
tree += self.safe[c]
|
||||
return tree
|
||||
|
||||
def safe_encode_string(decomp):
|
||||
tree = ''
|
||||
for c in decomp:
|
||||
if c not in self.safe:
|
||||
tree += c
|
||||
else:
|
||||
tree += self.safe[c]
|
||||
return tree
|
||||
|
||||
decomps_1 = self.data[char1]["decompositions"]
|
||||
decomps_2 = self.data[char2]["decompositions"]
|
||||
|
||||
distance = 1e5
|
||||
if as_tree:
|
||||
for decomp1 in decomps_1:
|
||||
for decomp2 in decomps_2:
|
||||
if not safe:
|
||||
ted = tree_edit_distance(string_to_tree(decomp1), string_to_tree(decomp2))
|
||||
else:
|
||||
ted = tree_edit_distance(safe_encode(decomp1), safe_encode(decomp2))
|
||||
distance = min(distance, ted)
|
||||
else:
|
||||
for decomp1 in decomps_1:
|
||||
for decomp2 in decomps_2:
|
||||
if not safe:
|
||||
ed = edit_distance(decomp1, decomp2)
|
||||
else:
|
||||
ed = edit_distance(safe_encode_string(decomp1), safe_encode_string(decomp2))
|
||||
distance = min(distance, ed)
|
||||
|
||||
return distance
|
||||
|
||||
def pronunciation_distance(self, char1, char2):
|
||||
"""
|
||||
>>> c = CharFuncs('data/char_meta.txt')
|
||||
>>> c.pronunciation_distance('田', '由')
|
||||
3.4
|
||||
>>> c.pronunciation_distance('牛', '午')
|
||||
2.6
|
||||
"""
|
||||
assert char1 in self.data
|
||||
assert char2 in self.data
|
||||
pronunciations1 = self.data[char1]["pronunciation"]
|
||||
pronunciations2 = self.data[char2]["pronunciation"]
|
||||
|
||||
if pronunciations1 == 'null' or pronunciations2 == 'null':
|
||||
return 0.0
|
||||
else:
|
||||
|
||||
pronunciations1 = pronunciations1.split(';') # separate by lan
|
||||
pronunciations2 = pronunciations2.split(';') # separate by lan
|
||||
|
||||
distance = 0.0
|
||||
count = 0
|
||||
for pron_lan1, pron_lan2 in zip(pronunciations1, pronunciations2):
|
||||
if (pron_lan1 == 'null') or (pron_lan2 == 'null'):
|
||||
pass
|
||||
else:
|
||||
distance_lan = 1e5
|
||||
for p1 in pron_lan1.split(','):
|
||||
for p2 in pron_lan2.split(','):
|
||||
distance_lan = min(distance_lan, edit_distance(p1, p2))
|
||||
distance += distance_lan
|
||||
count += 1
|
||||
|
||||
return distance / count
|
||||
|
||||
@staticmethod
|
||||
def load_dict(fname):
|
||||
data = {}
|
||||
f = open(fname, 'r', encoding='utf-8')
|
||||
for line in f:
|
||||
char, freq = line.strip().split('\t')
|
||||
assert char not in data
|
||||
data[char] = freq
|
||||
|
||||
return data
|
||||
|
||||
def similarity(self, char1, char2, weights=(0.8, 0.2, 0.0), as_tree=False):
|
||||
"""
|
||||
this function returns weighted similarity. When used in FASPell, each weight can only be 0 or 1.
|
||||
"""
|
||||
|
||||
# assert char1 in self.char_dict
|
||||
# assert char2 in self.char_dict
|
||||
shape_w, sound_w, freq_w = weights
|
||||
|
||||
if char1 in self.char_dict and char2 in self.char_dict:
|
||||
|
||||
shape_sim = self.shape_similarity(char1, char2, as_tree=as_tree)
|
||||
sound_sim = self.pronunciation_similarity(char1, char2)
|
||||
freq_sim = 1.0 - self.char_dict[char2] / len(self.char_dict)
|
||||
|
||||
return shape_sim * shape_w + sound_sim * sound_w + freq_sim * freq_w
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def shape_similarity(self, char1, char2, safe=True, as_tree=False):
|
||||
"""
|
||||
>>> c = CharFuncs('data/char_meta.txt')
|
||||
>>> c.shape_similarity('牛', '午')
|
||||
0.8571428571428572
|
||||
>>> c.shape_similarity('田', '由')
|
||||
0.8888888888888888
|
||||
"""
|
||||
assert char1 in self.data
|
||||
assert char2 in self.data
|
||||
|
||||
def safe_encode(decomp):
|
||||
tree = ''
|
||||
for c in string_to_tree(decomp):
|
||||
if c not in self.safe:
|
||||
tree += c
|
||||
else:
|
||||
tree += self.safe[c]
|
||||
return tree
|
||||
|
||||
def safe_encode_string(decomp):
|
||||
tree = ''
|
||||
for c in decomp:
|
||||
if c not in self.safe:
|
||||
tree += c
|
||||
else:
|
||||
tree += self.safe[c]
|
||||
return tree
|
||||
|
||||
decomps_1 = self.data[char1]["decompositions"]
|
||||
decomps_2 = self.data[char2]["decompositions"]
|
||||
|
||||
similarity = 0.0
|
||||
if as_tree:
|
||||
for decomp1 in decomps_1:
|
||||
for decomp2 in decomps_2:
|
||||
if not safe:
|
||||
ted = tree_edit_distance(string_to_tree(decomp1), string_to_tree(decomp2))
|
||||
else:
|
||||
ted = tree_edit_distance(safe_encode(decomp1), safe_encode(decomp2))
|
||||
normalized_ted = 2 * ted / (len(decomp1) + len(decomp2) + ted)
|
||||
similarity = max(similarity, 1 - normalized_ted)
|
||||
else:
|
||||
for decomp1 in decomps_1:
|
||||
for decomp2 in decomps_2:
|
||||
if not safe:
|
||||
ed = edit_distance(decomp1, decomp2)
|
||||
else:
|
||||
ed = edit_distance(safe_encode_string(decomp1), safe_encode_string(decomp2))
|
||||
normalized_ed = ed / max(len(decomp1), len(decomp2))
|
||||
similarity = max(similarity, 1 - normalized_ed)
|
||||
|
||||
return similarity
|
||||
|
||||
def pronunciation_similarity(self, char1, char2):
|
||||
"""
|
||||
>>> c = CharFuncs('data/char_meta.txt')
|
||||
>>> c.pronunciation_similarity('牛', '午')
|
||||
0.27999999999999997
|
||||
>>> c.pronunciation_similarity('由', '田')
|
||||
0.09
|
||||
|
||||
"""
|
||||
assert char1 in self.data
|
||||
assert char2 in self.data
|
||||
pronunciations1 = self.data[char1]["pronunciation"]
|
||||
pronunciations2 = self.data[char2]["pronunciation"]
|
||||
|
||||
if pronunciations1 == 'null' or pronunciations2 == 'null':
|
||||
return 0.0
|
||||
else:
|
||||
|
||||
pronunciations1 = pronunciations1.split(';') # separate by lan
|
||||
pronunciations2 = pronunciations2.split(';') # separate by lan
|
||||
|
||||
similarity = 0.0
|
||||
count = 0
|
||||
for pron_lan1, pron_lan2 in zip(pronunciations1, pronunciations2):
|
||||
if (pron_lan1 == 'null') or (pron_lan2 == 'null'):
|
||||
pass
|
||||
else:
|
||||
similarity_lan = 0.0
|
||||
for p1 in pron_lan1.split(','):
|
||||
for p2 in pron_lan2.split(','):
|
||||
tmp_sim = 1 - edit_distance(p1, p2) / max(len(p1), len(p2))
|
||||
similarity_lan = max(similarity_lan, tmp_sim)
|
||||
similarity += similarity_lan
|
||||
count += 1
|
||||
|
||||
return similarity / count if count else 0
|
||||
|
||||
|
||||
def string_to_tree(string):
|
||||
"""
|
||||
This function converts ids string to a string that can be used as a tree input to APTED.
|
||||
Any Error raised by this function implies that the input string is invalid.
|
||||
>>> string_to_tree('⿱⿱⿰丿㇏⿰丿㇏⿱⿰丿㇏⿰丿㇏') # 炎
|
||||
'{⿱{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}}'
|
||||
>>> string_to_tree('⿱⿰丿㇏⿱一⿱⿻一丨一') # 全
|
||||
'{⿱{⿰{丿}{㇏}}{⿱{一}{⿱{⿻{一}{丨}}{一}}}}'
|
||||
>>> string_to_tree('⿱⿰丿㇏⿻⿱一⿱⿻一丨一丷') # 金
|
||||
'{⿱{⿰{丿}{㇏}}{⿻{⿱{一}{⿱{⿻{一}{丨}}{一}}}{丷}}}'
|
||||
>>> string_to_tree('⿻⿻⿻一丨一⿴⿱⿰丨𠃌一一') # 車
|
||||
'{⿻{⿻{⿻{一}{丨}}{一}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}'
|
||||
>>> string_to_tree('⿻⿻⿻一丨⿰丿㇏⿴⿱⿰丨𠃌一一') # 東
|
||||
'{⿻{⿻{⿻{一}{丨}}{⿰{丿}{㇏}}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}'
|
||||
>>> string_to_tree('丿') # 丿
|
||||
'{丿}'
|
||||
>>> string_to_tree('⿻') # ⿻
|
||||
'{⿻}'
|
||||
"""
|
||||
if string[0] in IDCS and len(string) != 1:
|
||||
bracket_stack = []
|
||||
tree = []
|
||||
|
||||
def add_brackets(num):
|
||||
if num == 2:
|
||||
bracket_stack.extend(['}', '{', '}'])
|
||||
else:
|
||||
bracket_stack.extend(['}', '{', '}', '{', '}'])
|
||||
tree.append('{')
|
||||
|
||||
global_just_put = '{'
|
||||
|
||||
for c in string:
|
||||
tree.append(c)
|
||||
if c in IDCS:
|
||||
assert global_just_put != '}'
|
||||
add_brackets(IDCS[c])
|
||||
global_just_put = '{'
|
||||
else:
|
||||
just_put = ''
|
||||
while just_put != '{' and bracket_stack:
|
||||
just_put = bracket_stack.pop(-1)
|
||||
tree.append(just_put)
|
||||
global_just_put = just_put
|
||||
|
||||
res = ''.join(tree)
|
||||
assert res[-1] == '}'
|
||||
else:
|
||||
assert len(string) == 1 or string == 'null'
|
||||
res = string[0]
|
||||
|
||||
return '{' + res + '}'
|
||||
|
||||
|
||||
def pinyin_map(standard_pinyin):
|
||||
"""
|
||||
>>> pinyin_map('xuě')
|
||||
'xue3'
|
||||
>>> pinyin_map('xue')
|
||||
'xue'
|
||||
>>> pinyin_map('lǜ')
|
||||
'lü4'
|
||||
>>> pinyin_map('fá')
|
||||
'fa2'
|
||||
"""
|
||||
tone = ''
|
||||
pinyin = ''
|
||||
|
||||
assert ' ' not in standard_pinyin
|
||||
for c in standard_pinyin:
|
||||
if c in PINYIN:
|
||||
pinyin += PINYIN[c][0]
|
||||
assert tone == ''
|
||||
tone = str(PINYIN[c][1])
|
||||
else:
|
||||
pinyin += c
|
||||
pinyin += tone
|
||||
return pinyin
|
||||
|
||||
|
||||
def parse_args():
|
||||
usage = '\n1. You can compute character similarity by:\n' \
|
||||
'python char_sim.py 午 牛 年 千\n' \
|
||||
'\n' \
|
||||
'2. You can use ted in computing character similarity by:\n' \
|
||||
'python char_sim.py 午 牛 年 千 -t\n' \
|
||||
'\n'
|
||||
parser = argparse.ArgumentParser(
|
||||
description='A script to compute Chinese character (Kanji) similarity', usage=usage)
|
||||
|
||||
parser.add_argument('multiargs', nargs='*', type=str, default=None,
|
||||
help='Chinese characters in question')
|
||||
parser.add_argument('--ted', '-t', action="store_true", default=False,
|
||||
help='True=to use tree edit distance (TED); '
|
||||
'False=to use string edit distance')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
c = CharFuncs('data/char_meta.txt')
|
||||
if not args.ted:
|
||||
for i, c1 in enumerate(args.multiargs):
|
||||
for c2 in args.multiargs[i:]:
|
||||
if c1 != c2:
|
||||
print(f'For character pair ({c1}, {c2}):')
|
||||
print(f' v-sim = {c.shape_similarity(c1, c2)}')
|
||||
print(f' p-sim = {c.pronunciation_similarity(c1, c2)}\n')
|
||||
else:
|
||||
for i, c1 in enumerate(args.multiargs):
|
||||
for c2 in args.multiargs[i:]:
|
||||
if c1 != c2:
|
||||
print(f'For character pair ({c1}, {c2}):')
|
||||
print(f' v-sim = {c.shape_similarity(c1, c2, as_tree=True)}')
|
||||
print(f' p-sim = {c.pronunciation_similarity(c1, c2)}\n')
|
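The weights argument of CharFuncs.similarity is just a linear mix of the per-character shape, pronunciation and frequency scores. A minimal, self-contained sketch of that combination, plugging in the doctest values above so that no char_meta.txt file is needed, looks like this (the helper name is invented for illustration):

def weighted_similarity(shape_sim, sound_sim, freq_sim, weights=(0.8, 0.2, 0.0)):
    # Same linear combination as CharFuncs.similarity: shape, sound, frequency.
    shape_w, sound_w, freq_w = weights
    return shape_sim * shape_w + sound_sim * sound_w + freq_sim * freq_w

# 牛 vs 午, using the doctest values for shape and pronunciation similarity:
print(weighted_similarity(0.8571428571428572, 0.27999999999999997, 0.0))  # about 0.7417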
@ -1,433 +1,433 @@
|
||||
import argparse
|
||||
from collections import Counter
|
||||
|
||||
def main():
|
||||
# Parse command line args
|
||||
args = parse_args()
|
||||
# Open hypothesis and reference m2 files and split into chunks
|
||||
hyp_m2 = open(args.hyp).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.hyp).read().strip().split("\n\n")
|
||||
ref_m2 = open(args.ref).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.ref).read().strip().split("\n\n")
|
||||
# Make sure they have the same number of sentences
|
||||
assert len(hyp_m2) == len(ref_m2), f"{len(hyp_m2)} != {len(ref_m2)}"
|
||||
|
||||
# Store global corpus level best counts here
|
||||
best_dict = Counter({"tp":0, "fp":0, "fn":0})
|
||||
best_cats = {}
|
||||
# Process each sentence
|
||||
sents = zip(hyp_m2, ref_m2)
|
||||
for sent_id, sent in enumerate(sents):
|
||||
# Simplify the edits into lists of lists
|
||||
# if "A1" in sent[0] or "A1" in sent[1] or sent_id in sent_id_cons:
|
||||
# sent_id_cons.append(sent_id)
|
||||
src = sent[0].split("\n")[0]
|
||||
hyp_edits = simplify_edits(sent[0], args.max_answer_num)
|
||||
ref_edits = simplify_edits(sent[1], args.max_answer_num)
|
||||
# Process the edits for detection/correction based on args
|
||||
hyp_dict = process_edits(hyp_edits, args)
|
||||
ref_dict = process_edits(ref_edits, args)
|
||||
if args.reference_num is None or len(ref_dict.keys()) == args.reference_num:
|
||||
# Evaluate edits and get best TP, FP, FN hyp+ref combo.
|
||||
count_dict, cat_dict = evaluate_edits(src,
|
||||
hyp_dict, ref_dict, best_dict, sent_id, args)
|
||||
# Merge these dicts with best_dict and best_cats
|
||||
best_dict += Counter(count_dict)
|
||||
best_cats = merge_dict(best_cats, cat_dict)
|
||||
# Print results
|
||||
print_results(best_dict, best_cats, args)
|
||||
|
||||
# Parse command line args
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Calculate F-scores for error detection and/or correction.\n"
|
||||
"Flags let you evaluate at different levels of granularity.",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
usage="%(prog)s [options] -hyp HYP -ref REF")
|
||||
parser.add_argument(
|
||||
"-hyp",
|
||||
help="A hypothesis M2 file.",
|
||||
required=True)
|
||||
parser.add_argument(
|
||||
"-ref",
|
||||
help="A reference M2 file.",
|
||||
required=True)
|
||||
parser.add_argument(
|
||||
"--start",
|
||||
type=int,
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"--end",
|
||||
type=int,
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_answer_num",
|
||||
type=int,
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reference_num",
|
||||
type=int,
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--beta",
|
||||
help="Value of beta in F-score. (default: 0.5)",
|
||||
default=0.5,
|
||||
type=float)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
help="Print verbose output.",
|
||||
action="store_true")
|
||||
eval_type = parser.add_mutually_exclusive_group()
|
||||
eval_type.add_argument(
|
||||
"-dt",
|
||||
help="Evaluate Detection in terms of Tokens.",
|
||||
action="store_true")
|
||||
eval_type.add_argument(
|
||||
"-ds",
|
||||
help="Evaluate Detection in terms of Spans.",
|
||||
action="store_true")
|
||||
eval_type.add_argument(
|
||||
"-cs",
|
||||
help="Evaluate Correction in terms of Spans. (default)",
|
||||
action="store_true")
|
||||
eval_type.add_argument(
|
||||
"-cse",
|
||||
help="Evaluate Correction in terms of Spans and Error types.",
|
||||
action="store_true")
|
||||
parser.add_argument(
|
||||
"-single",
|
||||
help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1",
|
||||
action="store_true")
|
||||
parser.add_argument(
|
||||
"-multi",
|
||||
help="Only evaluate multi token edits; i.e. 2+:n or n:2+",
|
||||
action="store_true")
|
||||
parser.add_argument(
|
||||
"-multi_hyp_avg",
|
||||
help="When there are multiple hypotheses for a sentence, calculate their average F-score for this sentence.",
|
||||
action="store_true") # For IAA calculation
|
||||
parser.add_argument(
|
||||
"-multi_hyp_max",
|
||||
help="When there are multiple hypotheses for a sentence, calculate their F-scores and select the maximum for this sentence.",
|
||||
action="store_true") # For multiple hypotheses system evaluation
|
||||
parser.add_argument(
|
||||
"-filt",
|
||||
help="Do not evaluate the specified error types.",
|
||||
nargs="+",
|
||||
default=[])
|
||||
parser.add_argument(
|
||||
"-cat",
|
||||
help="Show error category scores.\n"
|
||||
"1: Only show operation tier scores; e.g. R.\n"
|
||||
"2: Only show main tier scores; e.g. NOUN.\n"
|
||||
"3: Show all category scores; e.g. R:NOUN.",
|
||||
choices=[1, 2, 3],
|
||||
type=int)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
# Input: An m2 format sentence with edits.
|
||||
# Output: A list of lists. Each edit: [start, end, cat, cor, coder]
|
||||
def simplify_edits(sent, max_answer_num):
|
||||
out_edits = []
|
||||
# Get the edit lines from an m2 block.
|
||||
edits = sent.split("\n")
|
||||
# Loop through the edits
|
||||
for edit in edits:
|
||||
# Preprocessing
|
||||
if edit.startswith("A "):
|
||||
edit = edit[2:].split("|||") # Ignore "A " then split.
|
||||
span = edit[0].split()
|
||||
start = int(span[0])
|
||||
end = int(span[1])
|
||||
cat = edit[1]
|
||||
cor = edit[2].replace(" ", "")
|
||||
coder = int(edit[-1])
|
||||
out_edit = [start, end, cat, cor, coder]
|
||||
out_edits.append(out_edit)
|
||||
# return [edit for edit in out_edits if edit[-1] in [0,1]]
|
||||
if max_answer_num is None:
|
||||
return out_edits
|
||||
elif max_answer_num == 1:
|
||||
return [edit for edit in out_edits if edit[-1] == 0]
|
||||
elif max_answer_num == 2:
|
||||
return [edit for edit in out_edits if edit[-1] in [0, 1]]
|
||||
elif max_answer_num == 3:
|
||||
return [edit for edit in out_edits if edit[-1] in [0, 1, 2]]
|
||||
|
||||
# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
|
||||
# Input 2: Command line args
|
||||
# Output: A dict; key is coder, value is edit dict.
|
||||
def process_edits(edits, args):
|
||||
coder_dict = {}
|
||||
# Add an explicit noop edit if there are no edits.
|
||||
if not edits: edits = [[-1, -1, "noop", "-NONE-", 0]]
|
||||
# Loop through the edits
|
||||
for edit in edits:
|
||||
# Name the edit elements for clarity
|
||||
start = edit[0]
|
||||
end = edit[1]
|
||||
cat = edit[2]
|
||||
cor = edit[3]
|
||||
coder = edit[4]
|
||||
# Add the coder to the coder_dict if necessary
|
||||
if coder not in coder_dict: coder_dict[coder] = {}
|
||||
|
||||
# Optionally apply filters based on args
|
||||
# 1. UNK type edits are only useful for detection, not correction.
|
||||
if not args.dt and not args.ds and cat == "UNK": continue
|
||||
# 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
|
||||
if args.single and (end-start >= 2 or len(cor.split()) >= 2): continue
|
||||
# 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
|
||||
if args.multi and end-start < 2 and len(cor.split()) < 2: continue
|
||||
# 4. If there is a filter, ignore the specified error types
|
||||
if args.filt and cat in args.filt: continue
|
||||
|
||||
# Token Based Detection
|
||||
if args.dt:
|
||||
# Preserve noop edits.
|
||||
if start == -1:
|
||||
if (start, start) in coder_dict[coder].keys():
|
||||
coder_dict[coder][(start, start)].append(cat)
|
||||
else:
|
||||
coder_dict[coder][(start, start)] = [cat]
|
||||
# Insertions defined as affecting the token on the right
|
||||
elif start == end and start >= 0:
|
||||
if (start, start+1) in coder_dict[coder].keys():
|
||||
coder_dict[coder][(start, start+1)].append(cat)
|
||||
else:
|
||||
coder_dict[coder][(start, start+1)] = [cat]
|
||||
# Edit spans are split for each token in the range.
|
||||
else:
|
||||
for tok_id in range(start, end):
|
||||
if (tok_id, tok_id+1) in coder_dict[coder].keys():
|
||||
coder_dict[coder][(tok_id, tok_id+1)].append(cat)
|
||||
else:
|
||||
coder_dict[coder][(tok_id, tok_id+1)] = [cat]
|
||||
|
||||
# Span Based Detection
|
||||
elif args.ds:
|
||||
if (start, end) in coder_dict[coder].keys():
|
||||
coder_dict[coder][(start, end)].append(cat)
|
||||
else:
|
||||
coder_dict[coder][(start, end)] = [cat]
|
||||
|
||||
# Span Based Correction
|
||||
else:
|
||||
# With error type classification
|
||||
if args.cse:
|
||||
if (start, end, cat, cor) in coder_dict[coder].keys():
|
||||
coder_dict[coder][(start, end, cat, cor)].append(cat)
|
||||
else:
|
||||
coder_dict[coder][(start, end, cat, cor)] = [cat]
|
||||
# Without error type classification
|
||||
else:
|
||||
if (start, end, cor) in coder_dict[coder].keys():
|
||||
coder_dict[coder][(start, end, cor)].append(cat)
|
||||
else:
|
||||
coder_dict[coder][(start, end, cor)] = [cat]
|
||||
return coder_dict
|
||||
|
||||
# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
|
||||
# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
|
||||
# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
|
||||
# Input 4: Sentence ID (for verbose output only)
|
||||
# Input 5: Command line args
|
||||
# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
|
||||
# Output 2: The corresponding error type dict for the above dict.
|
||||
def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args):
|
||||
# Store the best sentence level scores and hyp+ref combination IDs
|
||||
# best_f is initialised as -1 cause 0 is a valid result.
|
||||
best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
|
||||
best_cat = {}
|
||||
# skip not annotatable sentence
|
||||
if len(ref_dict.keys()) == 1:
|
||||
ref_id = list(ref_dict.keys())[0]
|
||||
if len(ref_dict[ref_id].keys()) == 1:
|
||||
cat = list(ref_dict[ref_id].values())[0][0]
|
||||
if cat == "NA":
|
||||
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
|
||||
return best_dict, best_cat
|
||||
|
||||
# Compare each hyp and ref combination
|
||||
for hyp_id in hyp_dict.keys():
|
||||
for ref_id in ref_dict.keys():
|
||||
# Get the local counts for the current combination.
|
||||
tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
|
||||
# Compute the local sentence scores (for verbose output only)
|
||||
loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
|
||||
# Compute the global sentence scores
|
||||
p, r, f = computeFScore(
|
||||
tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
|
||||
# Save the scores if they are better in terms of:
|
||||
# 1. Higher F-score
|
||||
# 2. Same F-score, higher TP
|
||||
# 3. Same F-score and TP, lower FP
|
||||
# 4. Same F-score, TP and FP, lower FN
|
||||
if (f > best_f) or \
|
||||
(f == best_f and tp > best_tp) or \
|
||||
(f == best_f and tp == best_tp and fp < best_fp) or \
|
||||
(f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
|
||||
best_tp, best_fp, best_fn = tp, fp, fn
|
||||
best_f, best_hyp, best_ref = f, hyp_id, ref_id
|
||||
best_cat = cat_dict
|
||||
# Verbose output
|
||||
if args.verbose:
|
||||
# Prepare verbose output edits.
|
||||
hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
|
||||
ref_verb = list(sorted(ref_dict[ref_id].keys()))
|
||||
# Ignore noop edits
|
||||
if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
|
||||
if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
|
||||
# Print verbose info
|
||||
print('{:-^40}'.format(""))
|
||||
print("SENTENCE "+str(sent_id)+src[1:])
|
||||
print('{:-^40}'.format(""))
|
||||
print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id))
|
||||
print("HYPOTHESIS EDITS :", hyp_verb)
|
||||
print("REFERENCE EDITS :", ref_verb)
|
||||
print("Local TP/FP/FN :", str(tp), str(fp), str(fn))
|
||||
print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f))
|
||||
print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"]))
|
||||
print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f))
|
||||
# Verbose output: display the best hyp+ref combination
|
||||
if args.verbose:
|
||||
print('{:-^40}'.format(""))
|
||||
print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id))
|
||||
# Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
|
||||
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
|
||||
return best_dict, best_cat
|
||||
|
||||
# Input 1: A dictionary of hypothesis edits for a single system.
|
||||
# Input 2: A dictionary of reference edits for a single annotator.
|
||||
# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
|
||||
# Output 4: A dictionary of the error type counts.
|
||||
def compareEdits(hyp_edits, ref_edits):
|
||||
tp = 0 # True Positives
|
||||
fp = 0 # False Positives
|
||||
fn = 0 # False Negatives
|
||||
cat_dict = {} # {cat: [tp, fp, fn], ...}
|
||||
|
||||
for h_edit, h_cats in hyp_edits.items():
|
||||
# noop hyp edits cannot be TP or FP
|
||||
if h_cats[0] == "noop": continue
|
||||
# TRUE POSITIVES
|
||||
if h_edit in ref_edits.keys():
|
||||
# On occasion, multiple tokens at same span.
|
||||
for h_cat in ref_edits[h_edit]: # Use ref dict for TP
|
||||
tp += 1
|
||||
# Each dict value [TP, FP, FN]
|
||||
if h_cat in cat_dict.keys():
|
||||
cat_dict[h_cat][0] += 1
|
||||
else:
|
||||
cat_dict[h_cat] = [1, 0, 0]
|
||||
# FALSE POSITIVES
|
||||
else:
|
||||
# On occasion, multiple tokens at same span.
|
||||
for h_cat in h_cats:
|
||||
fp += 1
|
||||
# Each dict value [TP, FP, FN]
|
||||
if h_cat in cat_dict.keys():
|
||||
cat_dict[h_cat][1] += 1
|
||||
else:
|
||||
cat_dict[h_cat] = [0, 1, 0]
|
||||
for r_edit, r_cats in ref_edits.items():
|
||||
# noop ref edits cannot be FN
|
||||
if r_cats[0] == "noop": continue
|
||||
# FALSE NEGATIVES
|
||||
if r_edit not in hyp_edits.keys():
|
||||
# On occasion, multiple tokens at same span.
|
||||
for r_cat in r_cats:
|
||||
fn += 1
|
||||
# Each dict value [TP, FP, FN]
|
||||
if r_cat in cat_dict.keys():
|
||||
cat_dict[r_cat][2] += 1
|
||||
else:
|
||||
cat_dict[r_cat] = [0, 0, 1]
|
||||
return tp, fp, fn, cat_dict
|
||||
|
||||
# Input 1-3: True positives, false positives, false negatives
|
||||
# Input 4: Value of beta in F-score.
|
||||
# Output 1-3: Precision, Recall and F-score rounded to 4dp.
|
||||
def computeFScore(tp, fp, fn, beta):
|
||||
p = float(tp)/(tp+fp) if fp else 1.0
|
||||
r = float(tp)/(tp+fn) if fn else 1.0
|
||||
f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0
|
||||
return round(p, 4), round(r, 4), round(f, 4)
|
||||
|
||||
# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
|
||||
# Output: The dictionaries combined with cumulative TP, FP, FN.
|
||||
def merge_dict(dict1, dict2):
|
||||
for cat, stats in dict2.items():
|
||||
if cat in dict1.keys():
|
||||
dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)]
|
||||
else:
|
||||
dict1[cat] = stats
|
||||
return dict1
|
||||
|
||||
# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn]
|
||||
# Input 2: Integer value denoting level of error category granularity.
|
||||
# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything.
|
||||
# Output: A dictionary of category TP, FP and FN based on Input 2.
|
||||
def processCategories(cat_dict, setting):
|
||||
# Group the raw categories according to the requested granularity.
|
||||
proc_cat_dict = {}
|
||||
for cat, cnt in cat_dict.items():
|
||||
if cat == "UNK":
|
||||
proc_cat_dict[cat] = cnt
|
||||
continue
|
||||
# M, U, R or UNK combined only.
|
||||
if setting == 1:
|
||||
if cat[0] in proc_cat_dict.keys():
|
||||
proc_cat_dict[cat[0]] = [x+y for x, y in zip(proc_cat_dict[cat[0]], cnt)]
|
||||
else:
|
||||
proc_cat_dict[cat[0]] = cnt
|
||||
# Everything without M, U or R.
|
||||
elif setting == 2:
|
||||
if cat[2:] in proc_cat_dict.keys():
|
||||
proc_cat_dict[cat[2:]] = [x+y for x, y in zip(proc_cat_dict[cat[2:]], cnt)]
|
||||
else:
|
||||
proc_cat_dict[cat[2:]] = cnt
|
||||
# All error category combinations
|
||||
else:
|
||||
return cat_dict
|
||||
return proc_cat_dict
|
||||
|
||||
# Input 1: A dict of global best TP, FP and FNs
|
||||
# Input 2: A dict of error types and counts for those TP, FP and FNs
|
||||
# Input 3: Command line args
|
||||
def print_results(best, best_cats, args):
|
||||
# Prepare output title.
|
||||
if args.dt: title = " Token-Based Detection "
|
||||
elif args.ds: title = " Span-Based Detection "
|
||||
elif args.cse: title = " Span-Based Correction + Classification "
|
||||
else: title = " Span-Based Correction "
|
||||
|
||||
# Category Scores
|
||||
if args.cat:
|
||||
best_cats = processCategories(best_cats, args.cat)
|
||||
print("")
|
||||
print('{:=^66}'.format(title))
|
||||
print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
|
||||
"P".ljust(8), "R".ljust(8), "F"+str(args.beta))
|
||||
for cat, cnts in sorted(best_cats.items()):
|
||||
cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
|
||||
print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
|
||||
str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)
|
||||
|
||||
# Print the overall results.
|
||||
print("")
|
||||
print('{:=^46}'.format(title))
|
||||
print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
|
||||
print("\t".join(map(str, [best["tp"], best["fp"],
|
||||
best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
|
||||
print('{:=^46}'.format(""))
|
||||
print("")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the program
|
||||
main()
|
||||
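The scorer defaults to beta = 0.5, which weights precision twice as heavily as recall. A standalone re-statement of computeFScore's arithmetic, with invented counts, makes the effect concrete:

def compute_f_score(tp, fp, fn, beta=0.5):
    # Mirrors computeFScore above: precision, recall, then the F-beta combination.
    p = float(tp) / (tp + fp) if fp else 1.0
    r = float(tp) / (tp + fn) if fn else 1.0
    f = float((1 + beta**2) * p * r) / ((beta**2) * p + r) if p + r else 0.0
    return round(p, 4), round(r, 4), round(f, 4)

print(compute_f_score(6, 2, 4))  # (0.75, 0.6, 0.7143) -- precision dominates at beta = 0.5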
import argparse
|
||||
from collections import Counter
|
||||
|
||||
def main():
|
||||
# Parse command line args
|
||||
args = parse_args()
|
||||
# Open hypothesis and reference m2 files and split into chunks
|
||||
hyp_m2 = open(args.hyp).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.hyp).read().strip().split("\n\n")
|
||||
ref_m2 = open(args.ref).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.ref).read().strip().split("\n\n")
|
||||
# Make sure they have the same number of sentences
|
||||
assert len(hyp_m2) == len(ref_m2), f"{len(hyp_m2)} != {len(ref_m2)}"
|
||||
|
||||
# Store global corpus level best counts here
|
||||
best_dict = Counter({"tp":0, "fp":0, "fn":0})
|
||||
best_cats = {}
|
||||
# Process each sentence
|
||||
sents = zip(hyp_m2, ref_m2)
|
||||
for sent_id, sent in enumerate(sents):
|
||||
# Simplify the edits into lists of lists
|
||||
# if "A1" in sent[0] or "A1" in sent[1] or sent_id in sent_id_cons:
|
||||
# sent_id_cons.append(sent_id)
|
||||
src = sent[0].split("\n")[0]
|
||||
hyp_edits = simplify_edits(sent[0], args.max_answer_num)
|
||||
ref_edits = simplify_edits(sent[1], args.max_answer_num)
|
||||
# Process the edits for detection/correction based on args
|
||||
hyp_dict = process_edits(hyp_edits, args)
|
||||
ref_dict = process_edits(ref_edits, args)
|
||||
if args.reference_num is None or len(ref_dict.keys()) == args.reference_num:
|
||||
# Evaluate edits and get best TP, FP, FN hyp+ref combo.
|
||||
count_dict, cat_dict = evaluate_edits(src,
|
||||
hyp_dict, ref_dict, best_dict, sent_id, args)
|
||||
# Merge these dicts with best_dict and best_cats
|
||||
best_dict += Counter(count_dict)
|
||||
best_cats = merge_dict(best_cats, cat_dict)
|
||||
# Print results
|
||||
print_results(best_dict, best_cats, args)
|
||||
|
||||
# Parse command line args
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Calculate F-scores for error detection and/or correction.\n"
|
||||
"Flags let you evaluate at different levels of granularity.",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
usage="%(prog)s [options] -hyp HYP -ref REF")
|
||||
parser.add_argument(
|
||||
"-hyp",
|
||||
help="A hypothesis M2 file.",
|
||||
required=True)
|
||||
parser.add_argument(
|
||||
"-ref",
|
||||
help="A reference M2 file.",
|
||||
required=True)
|
||||
parser.add_argument(
|
||||
"--start",
|
||||
type=int,
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"--end",
|
||||
type=int,
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_answer_num",
|
||||
type=int,
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reference_num",
|
||||
type=int,
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--beta",
|
||||
help="Value of beta in F-score. (default: 0.5)",
|
||||
default=0.5,
|
||||
type=float)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
help="Print verbose output.",
|
||||
action="store_true")
|
||||
eval_type = parser.add_mutually_exclusive_group()
|
||||
eval_type.add_argument(
|
||||
"-dt",
|
||||
help="Evaluate Detection in terms of Tokens.",
|
||||
action="store_true")
|
||||
eval_type.add_argument(
|
||||
"-ds",
|
||||
help="Evaluate Detection in terms of Spans.",
|
||||
action="store_true")
|
||||
eval_type.add_argument(
|
||||
"-cs",
|
||||
help="Evaluate Correction in terms of Spans. (default)",
|
||||
action="store_true")
|
||||
eval_type.add_argument(
|
||||
"-cse",
|
||||
help="Evaluate Correction in terms of Spans and Error types.",
|
||||
action="store_true")
|
||||
parser.add_argument(
|
||||
"-single",
|
||||
help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1",
|
||||
action="store_true")
|
||||
parser.add_argument(
|
||||
"-multi",
|
||||
help="Only evaluate multi token edits; i.e. 2+:n or n:2+",
|
||||
action="store_true")
|
||||
parser.add_argument(
|
||||
"-multi_hyp_avg",
|
||||
help="When there are multiple hypotheses for a sentence, calculate their average F-score for this sentence.",
|
||||
action="store_true") # For IAA calculation
|
||||
parser.add_argument(
|
||||
"-multi_hyp_max",
|
||||
help="When there are multiple hypotheses for a sentence, calculate their F-scores and select the maximum for this sentence.",
|
||||
action="store_true") # For multiple hypotheses system evaluation
|
||||
parser.add_argument(
|
||||
"-filt",
|
||||
help="Do not evaluate the specified error types.",
|
||||
nargs="+",
|
||||
default=[])
|
||||
parser.add_argument(
|
||||
"-cat",
|
||||
help="Show error category scores.\n"
|
||||
"1: Only show operation tier scores; e.g. R.\n"
|
||||
"2: Only show main tier scores; e.g. NOUN.\n"
|
||||
"3: Show all category scores; e.g. R:NOUN.",
|
||||
choices=[1, 2, 3],
|
||||
type=int)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
# Input: An m2 format sentence with edits.
|
||||
# Output: A list of lists. Each edit: [start, end, cat, cor, coder]
|
||||
def simplify_edits(sent, max_answer_num):
|
||||
out_edits = []
|
||||
# Get the edit lines from an m2 block.
|
||||
edits = sent.split("\n")
|
||||
# Loop through the edits
|
||||
for edit in edits:
|
||||
# Preprocessing
|
||||
if edit.startswith("A "):
|
||||
edit = edit[2:].split("|||") # Ignore "A " then split.
|
||||
span = edit[0].split()
|
||||
start = int(span[0])
|
||||
end = int(span[1])
|
||||
cat = edit[1]
|
||||
cor = edit[2].replace(" ", "")
|
||||
coder = int(edit[-1])
|
||||
out_edit = [start, end, cat, cor, coder]
|
||||
out_edits.append(out_edit)
|
||||
# return [edit for edit in out_edits if edit[-1] in [0,1]]
|
||||
if max_answer_num is None:
|
||||
return out_edits
|
||||
elif max_answer_num == 1:
|
||||
return [edit for edit in out_edits if edit[-1] == 0]
|
||||
elif max_answer_num == 2:
|
||||
return [edit for edit in out_edits if edit[-1] in [0, 1]]
|
||||
elif max_answer_num == 3:
|
||||
return [edit for edit in out_edits if edit[-1] in [0, 1, 2]]
|
||||
|
||||
# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
|
||||
# Input 2: Command line args
|
||||
# Output: A dict; key is coder, value is edit dict.
|
||||
def process_edits(edits, args):
|
||||
coder_dict = {}
|
||||
# Add an explicit noop edit if there are no edits.
|
||||
if not edits: edits = [[-1, -1, "noop", "-NONE-", 0]]
|
||||
# Loop through the edits
|
||||
for edit in edits:
|
||||
# Name the edit elements for clarity
|
||||
start = edit[0]
|
||||
end = edit[1]
|
||||
cat = edit[2]
|
||||
cor = edit[3]
|
||||
coder = edit[4]
|
||||
# Add the coder to the coder_dict if necessary
|
||||
if coder not in coder_dict: coder_dict[coder] = {}
|
||||
|
||||
# Optionally apply filters based on args
|
||||
# 1. UNK type edits are only useful for detection, not correction.
|
||||
if not args.dt and not args.ds and cat == "UNK": continue
|
||||
# 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
|
||||
if args.single and (end-start >= 2 or len(cor.split()) >= 2): continue
|
||||
# 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
|
||||
if args.multi and end-start < 2 and len(cor.split()) < 2: continue
|
||||
# 4. If there is a filter, ignore the specified error types
|
||||
if args.filt and cat in args.filt: continue
|
||||
|
||||
# Token Based Detection
|
||||
if args.dt:
|
||||
# Preserve noop edits.
|
||||
if start == -1:
|
||||
if (start, start) in coder_dict[coder].keys():
|
||||
coder_dict[coder][(start, start)].append(cat)
|
||||
else:
|
||||
coder_dict[coder][(start, start)] = [cat]
|
||||
# Insertions defined as affecting the token on the right
|
||||
elif start == end and start >= 0:
|
||||
if (start, start+1) in coder_dict[coder].keys():
|
||||
coder_dict[coder][(start, start+1)].append(cat)
|
||||
else:
|
||||
coder_dict[coder][(start, start+1)] = [cat]
|
||||
# Edit spans are split for each token in the range.
|
||||
else:
|
||||
for tok_id in range(start, end):
|
||||
if (tok_id, tok_id+1) in coder_dict[coder].keys():
|
||||
coder_dict[coder][(tok_id, tok_id+1)].append(cat)
|
||||
else:
|
||||
coder_dict[coder][(tok_id, tok_id+1)] = [cat]
|
||||
|
||||
# Span Based Detection
|
||||
elif args.ds:
|
||||
if (start, end) in coder_dict[coder].keys():
|
||||
coder_dict[coder][(start, end)].append(cat)
|
||||
else:
|
||||
coder_dict[coder][(start, end)] = [cat]
|
||||
|
||||
# Span Based Correction
|
||||
else:
|
||||
# With error type classification
|
||||
if args.cse:
|
||||
if (start, end, cat, cor) in coder_dict[coder].keys():
|
||||
coder_dict[coder][(start, end, cat, cor)].append(cat)
|
||||
else:
|
||||
coder_dict[coder][(start, end, cat, cor)] = [cat]
|
||||
# Without error type classification
|
||||
else:
|
||||
if (start, end, cor) in coder_dict[coder].keys():
|
||||
coder_dict[coder][(start, end, cor)].append(cat)
|
||||
else:
|
||||
coder_dict[coder][(start, end, cor)] = [cat]
|
||||
return coder_dict
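
# Example of the resulting structure (span-based correction, the default mode):
# the single edit [3, 4, 'R:WORD', '例子', 0] from the comment above becomes
#   coder_dict = {0: {(3, 4, '例子'): ['R:WORD']}}
# i.e. keyed by annotator id, then by (start, end, correction), with the error
# categories collected in the value list.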
|
||||
|
||||
# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
|
||||
# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
|
||||
# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
|
||||
# Input 4: Sentence ID (for verbose output only)
|
||||
# Input 5: Command line args
|
||||
# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
|
||||
# Output 2: The corresponding error type dict for the above dict.
|
||||
def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args):
|
||||
# Store the best sentence level scores and hyp+ref combination IDs
|
||||
# best_f is initialised as -1 cause 0 is a valid result.
|
||||
best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
|
||||
best_cat = {}
|
||||
# skip not annotatable sentence
|
||||
if len(ref_dict.keys()) == 1:
|
||||
ref_id = list(ref_dict.keys())[0]
|
||||
if len(ref_dict[ref_id].keys()) == 1:
|
||||
cat = list(ref_dict[ref_id].values())[0][0]
|
||||
if cat == "NA":
|
||||
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
|
||||
return best_dict, best_cat
|
||||
|
||||
# Compare each hyp and ref combination
|
||||
for hyp_id in hyp_dict.keys():
|
||||
for ref_id in ref_dict.keys():
|
||||
# Get the local counts for the current combination.
|
||||
tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
|
||||
# Compute the local sentence scores (for verbose output only)
|
||||
loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
|
||||
# Compute the global sentence scores
|
||||
p, r, f = computeFScore(
|
||||
tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
|
||||
# Save the scores if they are better in terms of:
|
||||
# 1. Higher F-score
|
||||
# 2. Same F-score, higher TP
|
||||
# 3. Same F-score and TP, lower FP
|
||||
# 4. Same F-score, TP and FP, lower FN
|
||||
if (f > best_f) or \
|
||||
(f == best_f and tp > best_tp) or \
|
||||
(f == best_f and tp == best_tp and fp < best_fp) or \
|
||||
(f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
|
||||
best_tp, best_fp, best_fn = tp, fp, fn
|
||||
best_f, best_hyp, best_ref = f, hyp_id, ref_id
|
||||
best_cat = cat_dict
|
||||
# Verbose output
|
||||
if args.verbose:
|
||||
# Prepare verbose output edits.
|
||||
hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
|
||||
ref_verb = list(sorted(ref_dict[ref_id].keys()))
|
||||
# Ignore noop edits
|
||||
if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
|
||||
if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
|
||||
# Print verbose info
|
||||
print('{:-^40}'.format(""))
|
||||
print("SENTENCE "+str(sent_id)+src[1:])
|
||||
print('{:-^40}'.format(""))
|
||||
print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id))
|
||||
print("HYPOTHESIS EDITS :", hyp_verb)
|
||||
print("REFERENCE EDITS :", ref_verb)
|
||||
print("Local TP/FP/FN :", str(tp), str(fp), str(fn))
|
||||
print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f))
|
||||
print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"]))
|
||||
print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f))
|
||||
# Verbose output: display the best hyp+ref combination
|
||||
if args.verbose:
|
||||
print('{:-^40}'.format(""))
|
||||
print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id))
|
||||
# Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
|
||||
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
|
||||
return best_dict, best_cat
|
||||
|
||||
# Input 1: A dictionary of hypothesis edits for a single system.
|
||||
# Input 2: A dictionary of reference edits for a single annotator.
|
||||
# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
|
||||
# Output 4: A dictionary of the error type counts.
|
||||
def compareEdits(hyp_edits, ref_edits):
|
||||
tp = 0 # True Positives
|
||||
fp = 0 # False Positives
|
||||
fn = 0 # False Negatives
|
||||
cat_dict = {} # {cat: [tp, fp, fn], ...}
|
||||
|
||||
for h_edit, h_cats in hyp_edits.items():
|
||||
# noop hyp edits cannot be TP or FP
|
||||
if h_cats[0] == "noop": continue
|
||||
# TRUE POSITIVES
|
||||
if h_edit in ref_edits.keys():
|
||||
# On occasion, multiple tokens at same span.
|
||||
for h_cat in ref_edits[h_edit]: # Use ref dict for TP
|
||||
tp += 1
|
||||
# Each dict value [TP, FP, FN]
|
||||
if h_cat in cat_dict.keys():
|
||||
cat_dict[h_cat][0] += 1
|
||||
else:
|
||||
cat_dict[h_cat] = [1, 0, 0]
|
||||
# FALSE POSITIVES
|
||||
else:
|
||||
# On occasion, multiple tokens at same span.
|
||||
for h_cat in h_cats:
|
||||
fp += 1
|
||||
# Each dict value [TP, FP, FN]
|
||||
if h_cat in cat_dict.keys():
|
||||
cat_dict[h_cat][1] += 1
|
||||
else:
|
||||
cat_dict[h_cat] = [0, 1, 0]
|
||||
for r_edit, r_cats in ref_edits.items():
|
||||
# noop ref edits cannot be FN
|
||||
if r_cats[0] == "noop": continue
|
||||
# FALSE NEGATIVES
|
||||
if r_edit not in hyp_edits.keys():
|
||||
# On occasion, multiple tokens at same span.
|
||||
for r_cat in r_cats:
|
||||
fn += 1
|
||||
# Each dict value [TP, FP, FN]
|
||||
if r_cat in cat_dict.keys():
|
||||
cat_dict[r_cat][2] += 1
|
||||
else:
|
||||
cat_dict[r_cat] = [0, 0, 1]
|
||||
return tp, fp, fn, cat_dict
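
# Illustrative sketch (not part of the synced file): comparing hypothetical edit
# dicts; keys follow the span-based correction format (start, end, correction)
# built above, and the category names are made-up examples.
hyp = {(1, 2, "气"): ["S:SPELL"], (4, 4, "的"): ["M:AUX"]}
ref = {(1, 2, "气"): ["S:SPELL"], (6, 7, "-NONE-"): ["R:VERB"]}
assert compareEdits(hyp, ref) == (1, 1, 1, {"S:SPELL": [1, 0, 0], "M:AUX": [0, 1, 0], "R:VERB": [0, 0, 1]})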
|
||||
|
||||
# Input 1-3: True positives, false positives, false negatives
# Input 4: Value of beta in F-score.
# Output 1-3: Precision, Recall and F-score rounded to 4dp.
def computeFScore(tp, fp, fn, beta):
    p = float(tp)/(tp+fp) if fp else 1.0
    r = float(tp)/(tp+fn) if fn else 1.0
    f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0
    return round(p, 4), round(r, 4), round(f, 4)
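
# Illustrative sketch (not part of the synced file): a quick sanity check of the
# F-beta computation above with hypothetical counts. Note the convention that
# precision (recall) defaults to 1.0 when there are no false positives (negatives).
# Example: tp=3, fp=1, fn=2, beta=0.5 -> P=0.75, R=0.6, F0.5=0.7143
assert computeFScore(3, 1, 2, 0.5) == (0.75, 0.6, 0.7143)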
|
||||
|
||||
# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
# Output: The dictionaries combined with cumulative TP, FP, FN.
def merge_dict(dict1, dict2):
    for cat, stats in dict2.items():
        if cat in dict1.keys():
            dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)]
        else:
            dict1[cat] = stats
    return dict1
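
# Illustrative sketch (not part of the synced file): merging per-sentence category
# counts into a running corpus-level total; the category names are made up.
running = {"M:NOUN": [1, 0, 2]}
sentence = {"M:NOUN": [0, 1, 0], "R:VERB": [2, 0, 0]}
assert merge_dict(running, sentence) == {"M:NOUN": [1, 1, 2], "R:VERB": [2, 0, 0]}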
|
||||
|
||||
# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn]
# Input 2: Integer value denoting level of error category granularity.
# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB. 3: Everything.
# Output: A dictionary of category TP, FP and FN based on Input 2.
def processCategories(cat_dict, setting):
    # Collapse the categories according to the requested granularity.
    proc_cat_dict = {}
    for cat, cnt in cat_dict.items():
        if cat == "UNK":
            proc_cat_dict[cat] = cnt
            continue
        # M, U, R or UNK combined only.
        if setting == 1:
            if cat[0] in proc_cat_dict.keys():
                proc_cat_dict[cat[0]] = [x+y for x, y in zip(proc_cat_dict[cat[0]], cnt)]
            else:
                proc_cat_dict[cat[0]] = cnt
        # Everything without M, U or R.
        elif setting == 2:
            if cat[2:] in proc_cat_dict.keys():
                proc_cat_dict[cat[2:]] = [x+y for x, y in zip(proc_cat_dict[cat[2:]], cnt)]
            else:
                proc_cat_dict[cat[2:]] = cnt
        # All error category combinations
        else:
            return cat_dict
    return proc_cat_dict
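
# Illustrative sketch (not part of the synced file): collapsing hypothetical
# category counts to the operation tier (setting=1) and the main tier (setting=2).
cats = {"M:NOUN": [1, 0, 0], "R:NOUN": [0, 1, 0], "R:VERB": [0, 0, 1]}
assert processCategories(dict(cats), 1) == {"M": [1, 0, 0], "R": [0, 1, 1]}
assert processCategories(dict(cats), 2) == {"NOUN": [1, 1, 0], "VERB": [0, 0, 1]}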
|
||||
|
||||
# Input 1: A dict of global best TP, FP and FNs
|
||||
# Input 2: A dict of error types and counts for those TP, FP and FNs
|
||||
# Input 3: Command line args
|
||||
def print_results(best, best_cats, args):
|
||||
# Prepare output title.
|
||||
if args.dt: title = " Token-Based Detection "
|
||||
elif args.ds: title = " Span-Based Detection "
|
||||
elif args.cse: title = " Span-Based Correction + Classification "
|
||||
else: title = " Span-Based Correction "
|
||||
|
||||
# Category Scores
|
||||
if args.cat:
|
||||
best_cats = processCategories(best_cats, args.cat)
|
||||
print("")
|
||||
print('{:=^66}'.format(title))
|
||||
print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
|
||||
"P".ljust(8), "R".ljust(8), "F"+str(args.beta))
|
||||
for cat, cnts in sorted(best_cats.items()):
|
||||
cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
|
||||
print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
|
||||
str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)
|
||||
|
||||
# Print the overall results.
|
||||
print("")
|
||||
print('{:=^46}'.format(title))
|
||||
print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
|
||||
print("\t".join(map(str, [best["tp"], best["fp"],
|
||||
best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
|
||||
print('{:=^46}'.format(""))
|
||||
print("")
|
||||
|
||||
if __name__ == "__main__":
    # Run the program
    main()
@ -1,49 +1,49 @@
from rouge_chinese import Rouge
|
||||
import jieba
|
||||
from nltk.translate.gleu_score import corpus_gleu
|
||||
|
||||
def compute_f1_two_sets(pred_set, gt_set):
|
||||
precision = len(pred_set.intersection(gt_set)) / len(pred_set) if len(pred_set) > 0 else 0
|
||||
recall = len(pred_set.intersection(gt_set)) / len(gt_set) if len(gt_set) > 0 else 0
|
||||
f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
|
||||
return f1
|
||||
|
||||
def multi_choice_judge(prediction, option_list, answer_token):
|
||||
# a dict, key: letters in the option list, value: count of the letter in the prediction
|
||||
count_dict, abstention, accuracy = {}, 0, 0
|
||||
for option in option_list:
|
||||
option_count = prediction.count(option)
|
||||
count_dict[option] = 1 if option_count > 0 else 0 # multiple occurrences of the same letter are counted as 1
|
||||
|
||||
if sum(count_dict.values()) == 0:
|
||||
abstention = 1
|
||||
# if the answer token is the only predicted token, the prediction is correct
|
||||
elif count_dict[answer_token] == 1 and sum(count_dict.values()) == 1:
|
||||
accuracy = 1
|
||||
return {"score": accuracy, "abstention": abstention}
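
# Illustrative sketch (not part of the synced file): scoring a hypothetical
# multiple-choice prediction against answer token "B".
assert multi_choice_judge("答案是B", ["A", "B", "C", "D"], "B") == {"score": 1, "abstention": 0}
assert multi_choice_judge("不确定", ["A", "B", "C", "D"], "B") == {"score": 0, "abstention": 1}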
|
||||
|
||||
"""
|
||||
compute the rouge score.
|
||||
hyps and refs are lists of hypothesis and reference strings
|
||||
empty predictions are replaced with 无内容
|
||||
"""
|
||||
|
||||
|
||||
def compute_rouge(hyps, refs):
|
||||
assert(len(hyps) == len(refs))
|
||||
hyps = [' '.join(jieba.cut(h)) for h in hyps]
|
||||
hyps = [h if h.strip() != "" else "无内容" for h in hyps]
|
||||
refs = [' '.join(jieba.cut(r)) for r in refs]
|
||||
return Rouge().get_scores(hyps, refs)
|
||||
|
||||
"""
|
||||
compute the gleu score.
|
||||
hyps and refs are lists of hypothesis and reference strings
|
||||
empty predictions are replaced with 无内容
|
||||
"""
|
||||
def compute_gleu(hyps, refs):
|
||||
assert(len(hyps) == len(refs))
|
||||
hyps = [' '.join(jieba.cut(h)) for h in hyps]
|
||||
hyps = [h if h.strip() != "" else "无内容" for h in hyps]
|
||||
refs = [[' '.join(jieba.cut(r))] for r in refs]
|
||||
return corpus_gleu(refs, hyps)
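
# Illustrative usage sketch (not part of the synced file; assumes rouge_chinese,
# jieba and nltk are installed). The strings below are made-up examples.
_scores = compute_rouge(["合同无效", ""], ["该合同应认定为无效", "无答案"])
_avg_rouge_l = sum(s["rouge-l"]["f"] for s in _scores) / len(_scores)
_gleu = compute_gleu(["合同无效"], ["该合同应认定为无效"])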
|
||||
from rouge_chinese import Rouge
|
||||
import jieba
|
||||
from nltk.translate.gleu_score import corpus_gleu
|
||||
|
||||
def compute_f1_two_sets(pred_set, gt_set):
|
||||
precision = len(pred_set.intersection(gt_set)) / len(pred_set) if len(pred_set) > 0 else 0
|
||||
recall = len(pred_set.intersection(gt_set)) / len(gt_set) if len(gt_set) > 0 else 0
|
||||
f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
|
||||
return f1
|
||||
|
||||
def multi_choice_judge(prediction, option_list, answer_token):
|
||||
# a dict, key: letters in the option list, value: count of the letter in the prediction
|
||||
count_dict, abstention, accuracy = {}, 0, 0
|
||||
for option in option_list:
|
||||
option_count = prediction.count(option)
|
||||
count_dict[option] = 1 if option_count > 0 else 0 # multiple occurrences of the same letter are counted as 1
|
||||
|
||||
if sum(count_dict.values()) == 0:
|
||||
abstention = 1
|
||||
# if the answer token is the only predicted token, the prediction is correct
|
||||
elif count_dict[answer_token] == 1 and sum(count_dict.values()) == 1:
|
||||
accuracy = 1
|
||||
return {"score": accuracy, "abstention": abstention}
|
||||
|
||||
"""
|
||||
compute the rouge score.
|
||||
hyps and refs are lists of hypothesis and reference strings
|
||||
empty predictions are replaced with 无内容
|
||||
"""
|
||||
|
||||
|
||||
def compute_rouge(hyps, refs):
|
||||
assert(len(hyps) == len(refs))
|
||||
hyps = [' '.join(jieba.cut(h)) for h in hyps]
|
||||
hyps = [h if h.strip() != "" else "无内容" for h in hyps]
|
||||
refs = [' '.join(jieba.cut(r)) for r in refs]
|
||||
return Rouge().get_scores(hyps, refs)
|
||||
|
||||
"""
|
||||
compute the gleu score.
|
||||
hyps and refs are lists of hypothesis and reference strings
|
||||
empty predictions are replaced with 无内容
|
||||
"""
|
||||
def compute_gleu(hyps, refs):
|
||||
assert(len(hyps) == len(refs))
|
||||
hyps = [' '.join(jieba.cut(h)) for h in hyps]
|
||||
hyps = [h if h.strip() != "" else "无内容" for h in hyps]
|
||||
refs = [[' '.join(jieba.cut(r))] for r in refs]
|
||||
return corpus_gleu(refs, hyps)
@ -1,334 +1,334 @@
import numpy as np
|
||||
from typing import List, Tuple, Dict
|
||||
from modules.tokenizer import Tokenizer
|
||||
import os
|
||||
from string import punctuation
|
||||
|
||||
REAL_PATH = os.path.split(os.path.realpath(__file__))[0]
|
||||
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏"
|
||||
english_punct = punctuation
|
||||
punct = chinese_punct + english_punct
|
||||
|
||||
def check_all_chinese(word):
|
||||
"""
|
||||
判断一个单词是否全部由中文组成
|
||||
:param word:
|
||||
:return:
|
||||
"""
|
||||
return all(['\u4e00' <= ch <= '\u9fff' for ch in word])
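
# Illustrative sketch (not part of the synced file): the helper returns True only
# when every character falls in the CJK Unified Ideographs range.
assert check_all_chinese("虾仁") is True
assert check_all_chinese("虾仁a") is False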
|
||||
|
||||
def read_cilin():
|
||||
"""
|
||||
Cilin 詞林 is a thesaurus with semantic information
|
||||
"""
|
||||
# TODO -- fix this path
|
||||
project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path
|
||||
lines = open(os.path.join(project_dir, "data", "cilin.txt"), "r", encoding="gbk").read().strip().split("\n")
|
||||
semantic_dict = {}
|
||||
semantic_classes = {}
|
||||
for line in lines:
|
||||
code, *words = line.split(" ")
|
||||
for word in words:
|
||||
semantic_dict[word] = code
|
||||
# make reverse dict
|
||||
if code in semantic_classes:
|
||||
semantic_classes[code] += words
|
||||
else:
|
||||
semantic_classes[code] = words
|
||||
return semantic_dict, semantic_classes
|
||||
|
||||
|
||||
def read_confusion():
|
||||
confusion_dict = {}
|
||||
project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path
|
||||
with open(os.path.join(project_dir, "data", "confusion_dict.txt"), "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
li = line.rstrip('\n').split(" ")
|
||||
confusion_dict[li[0]] = li[1:]
|
||||
return confusion_dict
|
||||
|
||||
class Alignment:
|
||||
"""
|
||||
对齐错误句子和正确句子,
|
||||
使用编辑距离算法抽取编辑操作
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
semantic_dict: Dict,
|
||||
confusion_dict: Dict,
|
||||
granularity: str = "word",
|
||||
) -> None:
|
||||
"""
|
||||
构造函数
|
||||
:param semantic_dict: 语义词典(大词林)
|
||||
:param confusion_dict: 字符混淆集
|
||||
"""
|
||||
self.insertion_cost = 1
|
||||
self.deletion_cost = 1
|
||||
self.semantic_dict = semantic_dict
|
||||
self.confusion_dict = confusion_dict
|
||||
# Because we use character level tokenization, this doesn't currently use POS
|
||||
self._open_pos = {} # 如果是词级别,还可以利用词性是否相同来计算cost
|
||||
self.granularity = granularity # word-level or character-level
|
||||
self.align_seqs = []
|
||||
|
||||
def __call__(self,
|
||||
src: List[Tuple],
|
||||
tgt: List[Tuple],
|
||||
verbose: bool = False):
|
||||
cost_matrix, oper_matrix = self.align(src, tgt)
|
||||
align_seq = self.get_cheapest_align_seq(oper_matrix)
|
||||
|
||||
if verbose:
|
||||
print("========== Seg. and POS: ==========")
|
||||
print(src)
|
||||
print(tgt)
|
||||
print("========== Cost Matrix ==========")
|
||||
print(cost_matrix)
|
||||
print("========== Oper Matrix ==========")
|
||||
print(oper_matrix)
|
||||
print("========== Alignment ==========")
|
||||
print(align_seq)
|
||||
print("========== Results ==========")
|
||||
for a in align_seq:
|
||||
print(a[0], src[a[1]: a[2]], tgt[a[3]: a[4]])
|
||||
return align_seq
|
||||
|
||||
def _get_semantic_class(self, word):
|
||||
"""
|
||||
NOTE: Based on the paper:
|
||||
Improved-Edit-Distance Kernel for Chinese Relation Extraction
|
||||
获取每个词语的语义类别(基于大词林,有三个级别)
|
||||
"""
|
||||
if word in self.semantic_dict:
|
||||
code = self.semantic_dict[word]
|
||||
high, mid, low = code[0], code[1], code[2:4]
|
||||
return high, mid, low
|
||||
else: # unknown
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_class_diff(a_class, b_class):
|
||||
"""
|
||||
d == 3 for equivalent semantics
|
||||
d == 0 for completely different semantics
|
||||
根据大词林的信息,计算两个词的语义类别的差距
|
||||
"""
|
||||
d = sum([a == b for a, b in zip(a_class, b_class)])
|
||||
return d
|
||||
|
||||
def _get_semantic_cost(self, a, b):
|
||||
"""
|
||||
计算基于语义信息的替换操作cost
|
||||
:param a: 单词a的语义类别
|
||||
:param b: 单词b的语义类别
|
||||
:return: 替换编辑代价
|
||||
"""
|
||||
a_class = self._get_semantic_class(a)
|
||||
b_class = self._get_semantic_class(b)
|
||||
# unknown class, default to 1
|
||||
if a_class is None or b_class is None:
|
||||
return 4
|
||||
elif a_class == b_class:
|
||||
return 0
|
||||
else:
|
||||
return 2 * (3 - self._get_class_diff(a_class, b_class))
|
||||
|
||||
def _get_pos_cost(self, a_pos, b_pos):
|
||||
"""
|
||||
计算基于词性信息的编辑距离cost
|
||||
:param a_pos: 单词a的词性
|
||||
:param b_pos: 单词b的词性
|
||||
:return: 替换编辑代价
|
||||
"""
|
||||
if a_pos == b_pos:
|
||||
return 0
|
||||
elif a_pos in self._open_pos and b_pos in self._open_pos:
|
||||
return 0.25
|
||||
else:
|
||||
return 0.499
|
||||
|
||||
def _get_char_cost(self, a, b, pinyin_a, pinyin_b):
|
||||
"""
|
||||
NOTE: This is a replacement of ERRANTS lemma cost for Chinese
|
||||
计算基于字符相似度的编辑距离cost
|
||||
"""
|
||||
if not (check_all_chinese(a) and check_all_chinese(b)):
|
||||
return 0.5
|
||||
if len(a) > len(b):
|
||||
a, b = b, a
|
||||
pinyin_a, pinyin_b = pinyin_b, pinyin_a
|
||||
if a == b:
|
||||
return 0
|
||||
else:
|
||||
return self._get_spell_cost(a, b, pinyin_a, pinyin_b)
|
||||
|
||||
def _get_spell_cost(self, a, b, pinyin_a, pinyin_b):
|
||||
"""
|
||||
计算两个单词拼写相似度,分别由字形相似度和字音相似度组成
|
||||
:param a: 单词a
|
||||
:param b: 单词b,且单词a的长度小于等于b
|
||||
:param pinyin_a: 单词a的拼音
|
||||
:param pinyin_b: 单词b的拼音
|
||||
:return: 替换操作cost
|
||||
"""
|
||||
count = 0
|
||||
for i in range(len(a)):
|
||||
for j in range(len(b)):
|
||||
if a[i] == b[j] or (set(pinyin_a) & set(pinyin_b)) or (b[j] in self.confusion_dict.keys() and a[i] in self.confusion_dict[b[j]]) or (a[i] in self.confusion_dict.keys() and b[j] in self.confusion_dict[a[i]]):
|
||||
count += 1
|
||||
break
|
||||
return (len(a) - count) / (len(a) * 2)
|
||||
|
||||
def get_sub_cost(self, a_seg, b_seg):
|
||||
"""
|
||||
Calculate the substitution cost between words a and b
|
||||
计算两个单词替换操作的编辑cost,最大为2,等于一次删除和一次添加
|
||||
"""
|
||||
if a_seg[0] == b_seg[0]:
|
||||
return 0
|
||||
|
||||
if self.granularity == "word": # 词级别可以额外利用词性信息
|
||||
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
|
||||
pos_cost = self._get_pos_cost(a_seg[1], b_seg[1])
|
||||
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
|
||||
return semantic_cost + pos_cost + char_cost
|
||||
else: # 字级别只能利用字义信息(从大词林中获取)和字面相似度信息
|
||||
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
|
||||
if a_seg[0] in punct and b_seg[0] in punct:
|
||||
pos_cost = 0.0
|
||||
elif a_seg[0] not in punct and b_seg[0] not in punct:
|
||||
pos_cost = 0.25
|
||||
else:
|
||||
pos_cost = 0.499
|
||||
# pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5
|
||||
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
|
||||
return semantic_cost + char_cost + pos_cost
|
||||
|
||||
def align(self,
|
||||
src: List[Tuple],
|
||||
tgt: List[Tuple]):
|
||||
"""
|
||||
Based on ERRANT's alignment
|
||||
基于改进的动态规划算法,为原句子的每个字打上编辑标签,以便使它能够成功转换为目标句子。
|
||||
编辑操作类别:
|
||||
1) M:Match,即KEEP,即当前字保持不变
|
||||
2) D:Delete,删除,即当前字需要被删除
|
||||
3) I:Insert,插入,即当前字需要被插入
|
||||
4) T:Transposition,移位操作,即涉及到词序问题
|
||||
"""
|
||||
cost_matrix = np.zeros((len(src) + 1, len(tgt) + 1)) # 编辑cost矩阵
|
||||
oper_matrix = np.full(
|
||||
(len(src) + 1, len(tgt) + 1), "O", dtype=object
|
||||
) # 操作矩阵
|
||||
# Fill in the edges
|
||||
for i in range(1, len(src) + 1):
|
||||
cost_matrix[i][0] = cost_matrix[i - 1][0] + 1
|
||||
oper_matrix[i][0] = ["D"]
|
||||
for j in range(1, len(tgt) + 1):
|
||||
cost_matrix[0][j] = cost_matrix[0][j - 1] + 1
|
||||
oper_matrix[0][j] = ["I"]
|
||||
|
||||
# Loop through the cost matrix
|
||||
for i in range(len(src)):
|
||||
for j in range(len(tgt)):
|
||||
# Matches
|
||||
if src[i][0] == tgt[j][0]: # 如果两个字相等,则匹配成功(Match),编辑距离为0
|
||||
cost_matrix[i + 1][j + 1] = cost_matrix[i][j]
|
||||
oper_matrix[i + 1][j + 1] = ["M"]
|
||||
# Non-matches
|
||||
else:
|
||||
del_cost = cost_matrix[i][j + 1] + self.deletion_cost # 由删除动作得到的总cost
|
||||
ins_cost = cost_matrix[i + 1][j] + self.insertion_cost # 由插入动作得到的总cost
|
||||
sub_cost = cost_matrix[i][j] + self.get_sub_cost(
|
||||
src[i], tgt[j]
|
||||
) # 由替换动作得到的总cost
|
||||
# Calculate transposition cost
|
||||
# 计算移位操作的总cost
|
||||
trans_cost = float("inf")
|
||||
k = 1
|
||||
while (
|
||||
i - k >= 0
|
||||
and j - k >= 0
|
||||
and cost_matrix[i - k + 1][j - k + 1]
|
||||
!= cost_matrix[i - k][j - k]
|
||||
):
|
||||
p1 = sorted([a[0] for a in src][i - k: i + 1])
|
||||
p2 = sorted([b[0] for b in tgt][j - k: j + 1])
|
||||
if p1 == p2:
|
||||
trans_cost = cost_matrix[i - k][j - k] + k
|
||||
break
|
||||
k += 1
|
||||
|
||||
costs = [trans_cost, sub_cost, ins_cost, del_cost]
|
||||
ind = costs.index(min(costs))
|
||||
cost_matrix[i + 1][j + 1] = costs[ind]
|
||||
# ind = costs.index(costs[ind], ind+1)
|
||||
for idx, cost in enumerate(costs):
|
||||
if cost == costs[ind]:
|
||||
if idx == 0:
|
||||
if oper_matrix[i + 1][j + 1] == "O":
|
||||
oper_matrix[i + 1][j + 1] = ["T" + str(k + 1)]
|
||||
else:
|
||||
oper_matrix[i + 1][j + 1].append("T" + str(k + 1))
|
||||
elif idx == 1:
|
||||
if oper_matrix[i + 1][j + 1] == "O":
|
||||
oper_matrix[i + 1][j + 1] = ["S"]
|
||||
else:
|
||||
oper_matrix[i + 1][j + 1].append("S")
|
||||
elif idx == 2:
|
||||
if oper_matrix[i + 1][j + 1] == "O":
|
||||
oper_matrix[i + 1][j + 1] = ["I"]
|
||||
else:
|
||||
oper_matrix[i + 1][j + 1].append("I")
|
||||
else:
|
||||
if oper_matrix[i + 1][j + 1] == "O":
|
||||
oper_matrix[i + 1][j + 1] = ["D"]
|
||||
else:
|
||||
oper_matrix[i + 1][j + 1].append("D")
|
||||
return cost_matrix, oper_matrix
|
||||
|
||||
def _dfs(self, i, j, align_seq_now, oper_matrix, strategy="all"):
|
||||
"""
|
||||
深度优先遍历,获取最小编辑距离相同的所有序列
|
||||
"""
|
||||
if i + j == 0:
|
||||
self.align_seqs.append(align_seq_now)
|
||||
else:
|
||||
ops = oper_matrix[i][j] # 可以类比成搜索一棵树从根结点到叶子结点的所有路径
|
||||
if strategy != "all": ops = ops[:1]
|
||||
for op in ops:
|
||||
if op in {"M", "S"}:
|
||||
self._dfs(i - 1, j - 1, align_seq_now + [(op, i - 1, i, j - 1, j)], oper_matrix, strategy)
|
||||
elif op == "D":
|
||||
self._dfs(i - 1, j, align_seq_now + [(op, i - 1, i, j, j)], oper_matrix, strategy)
|
||||
elif op == "I":
|
||||
self._dfs(i, j - 1, align_seq_now + [(op, i, i, j - 1, j)], oper_matrix, strategy)
|
||||
else:
|
||||
k = int(op[1:])
|
||||
self._dfs(i - k, j - k, align_seq_now + [(op, i - k, i, j - k, j)], oper_matrix, strategy)
|
||||
|
||||
def get_cheapest_align_seq(self, oper_matrix):
|
||||
"""
|
||||
回溯获得编辑距离最小的编辑序列
|
||||
"""
|
||||
self.align_seqs = []
|
||||
i = oper_matrix.shape[0] - 1
|
||||
j = oper_matrix.shape[1] - 1
|
||||
if abs(i - j) > 10:
|
||||
self._dfs(i, j , [], oper_matrix, "first")
|
||||
else:
|
||||
self._dfs(i, j , [], oper_matrix, "all")
|
||||
final_align_seqs = [seq[::-1] for seq in self.align_seqs]
|
||||
return final_align_seqs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = Tokenizer("word")
|
||||
semantic_dict, semantic_class = read_cilin()
|
||||
confusion_dict = read_confusion()
|
||||
alignment = Alignment(semantic_dict, confusion_dict)
|
||||
sents = ["首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。".replace(" ", ""), "首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。".replace(" ", "")]
|
||||
src, tgt = tokenizer(sents)
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Dict
|
||||
from modules.tokenizer import Tokenizer
|
||||
import os
|
||||
from string import punctuation
|
||||
|
||||
REAL_PATH = os.path.split(os.path.realpath(__file__))[0]
|
||||
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏"
|
||||
english_punct = punctuation
|
||||
punct = chinese_punct + english_punct
|
||||
|
||||
def check_all_chinese(word):
|
||||
"""
|
||||
判断一个单词是否全部由中文组成
|
||||
:param word:
|
||||
:return:
|
||||
"""
|
||||
return all(['\u4e00' <= ch <= '\u9fff' for ch in word])
|
||||
|
||||
def read_cilin():
|
||||
"""
|
||||
Cilin 詞林 is a thesaurus with semantic information
|
||||
"""
|
||||
# TODO -- fix this path
|
||||
project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path
|
||||
lines = open(os.path.join(project_dir, "data", "cilin.txt"), "r", encoding="gbk").read().strip().split("\n")
|
||||
semantic_dict = {}
|
||||
semantic_classes = {}
|
||||
for line in lines:
|
||||
code, *words = line.split(" ")
|
||||
for word in words:
|
||||
semantic_dict[word] = code
|
||||
# make reverse dict
|
||||
if code in semantic_classes:
|
||||
semantic_classes[code] += words
|
||||
else:
|
||||
semantic_classes[code] = words
|
||||
return semantic_dict, semantic_classes
|
||||
|
||||
|
||||
def read_confusion():
|
||||
confusion_dict = {}
|
||||
project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path
|
||||
with open(os.path.join(project_dir, "data", "confusion_dict.txt"), "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
li = line.rstrip('\n').split(" ")
|
||||
confusion_dict[li[0]] = li[1:]
|
||||
return confusion_dict
|
||||
|
||||
class Alignment:
|
||||
"""
|
||||
对齐错误句子和正确句子,
|
||||
使用编辑距离算法抽取编辑操作
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
semantic_dict: Dict,
|
||||
confusion_dict: Dict,
|
||||
granularity: str = "word",
|
||||
) -> None:
|
||||
"""
|
||||
构造函数
|
||||
:param semantic_dict: 语义词典(大词林)
|
||||
:param confusion_dict: 字符混淆集
|
||||
"""
|
||||
self.insertion_cost = 1
|
||||
self.deletion_cost = 1
|
||||
self.semantic_dict = semantic_dict
|
||||
self.confusion_dict = confusion_dict
|
||||
# Because we use character level tokenization, this doesn't currently use POS
|
||||
self._open_pos = {} # 如果是词级别,还可以利用词性是否相同来计算cost
|
||||
self.granularity = granularity # word-level or character-level
|
||||
self.align_seqs = []
|
||||
|
||||
def __call__(self,
|
||||
src: List[Tuple],
|
||||
tgt: List[Tuple],
|
||||
verbose: bool = False):
|
||||
cost_matrix, oper_matrix = self.align(src, tgt)
|
||||
align_seq = self.get_cheapest_align_seq(oper_matrix)
|
||||
|
||||
if verbose:
|
||||
print("========== Seg. and POS: ==========")
|
||||
print(src)
|
||||
print(tgt)
|
||||
print("========== Cost Matrix ==========")
|
||||
print(cost_matrix)
|
||||
print("========== Oper Matrix ==========")
|
||||
print(oper_matrix)
|
||||
print("========== Alignment ==========")
|
||||
print(align_seq)
|
||||
print("========== Results ==========")
|
||||
for a in align_seq:
|
||||
print(a[0], src[a[1]: a[2]], tgt[a[3]: a[4]])
|
||||
return align_seq
|
||||
|
||||
def _get_semantic_class(self, word):
|
||||
"""
|
||||
NOTE: Based on the paper:
|
||||
Improved-Edit-Distance Kernel for Chinese Relation Extraction
|
||||
获取每个词语的语义类别(基于大词林,有三个级别)
|
||||
"""
|
||||
if word in self.semantic_dict:
|
||||
code = self.semantic_dict[word]
|
||||
high, mid, low = code[0], code[1], code[2:4]
|
||||
return high, mid, low
|
||||
else: # unknown
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_class_diff(a_class, b_class):
|
||||
"""
|
||||
d == 3 for equivalent semantics
|
||||
d == 0 for completely different semantics
|
||||
根据大词林的信息,计算两个词的语义类别的差距
|
||||
"""
|
||||
d = sum([a == b for a, b in zip(a_class, b_class)])
|
||||
return d
|
||||
|
||||
def _get_semantic_cost(self, a, b):
|
||||
"""
|
||||
计算基于语义信息的替换操作cost
|
||||
:param a: 单词a的语义类别
|
||||
:param b: 单词b的语义类别
|
||||
:return: 替换编辑代价
|
||||
"""
|
||||
a_class = self._get_semantic_class(a)
|
||||
b_class = self._get_semantic_class(b)
|
||||
# unknown class, default to 1
|
||||
if a_class is None or b_class is None:
|
||||
return 4
|
||||
elif a_class == b_class:
|
||||
return 0
|
||||
else:
|
||||
return 2 * (3 - self._get_class_diff(a_class, b_class))
|
||||
|
||||
def _get_pos_cost(self, a_pos, b_pos):
|
||||
"""
|
||||
计算基于词性信息的编辑距离cost
|
||||
:param a_pos: 单词a的词性
|
||||
:param b_pos: 单词b的词性
|
||||
:return: 替换编辑代价
|
||||
"""
|
||||
if a_pos == b_pos:
|
||||
return 0
|
||||
elif a_pos in self._open_pos and b_pos in self._open_pos:
|
||||
return 0.25
|
||||
else:
|
||||
return 0.499
|
||||
|
||||
def _get_char_cost(self, a, b, pinyin_a, pinyin_b):
|
||||
"""
|
||||
NOTE: This is a replacement of ERRANTS lemma cost for Chinese
|
||||
计算基于字符相似度的编辑距离cost
|
||||
"""
|
||||
if not (check_all_chinese(a) and check_all_chinese(b)):
|
||||
return 0.5
|
||||
if len(a) > len(b):
|
||||
a, b = b, a
|
||||
pinyin_a, pinyin_b = pinyin_b, pinyin_a
|
||||
if a == b:
|
||||
return 0
|
||||
else:
|
||||
return self._get_spell_cost(a, b, pinyin_a, pinyin_b)
|
||||
|
||||
def _get_spell_cost(self, a, b, pinyin_a, pinyin_b):
|
||||
"""
|
||||
计算两个单词拼写相似度,分别由字形相似度和字音相似度组成
|
||||
:param a: 单词a
|
||||
:param b: 单词b,且单词a的长度小于等于b
|
||||
:param pinyin_a: 单词a的拼音
|
||||
:param pinyin_b: 单词b的拼音
|
||||
:return: 替换操作cost
|
||||
"""
|
||||
count = 0
|
||||
for i in range(len(a)):
|
||||
for j in range(len(b)):
|
||||
if a[i] == b[j] or (set(pinyin_a) & set(pinyin_b)) or (b[j] in self.confusion_dict.keys() and a[i] in self.confusion_dict[b[j]]) or (a[i] in self.confusion_dict.keys() and b[j] in self.confusion_dict[a[i]]):
|
||||
count += 1
|
||||
break
|
||||
return (len(a) - count) / (len(a) * 2)
|
||||
|
||||
def get_sub_cost(self, a_seg, b_seg):
|
||||
"""
|
||||
Calculate the substitution cost between words a and b
|
||||
计算两个单词替换操作的编辑cost,最大为2,等于一次删除和一次添加
|
||||
"""
|
||||
if a_seg[0] == b_seg[0]:
|
||||
return 0
|
||||
|
||||
if self.granularity == "word": # 词级别可以额外利用词性信息
|
||||
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
|
||||
pos_cost = self._get_pos_cost(a_seg[1], b_seg[1])
|
||||
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
|
||||
return semantic_cost + pos_cost + char_cost
|
||||
else: # 字级别只能利用字义信息(从大词林中获取)和字面相似度信息
|
||||
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
|
||||
if a_seg[0] in punct and b_seg[0] in punct:
|
||||
pos_cost = 0.0
|
||||
elif a_seg[0] not in punct and b_seg[0] not in punct:
|
||||
pos_cost = 0.25
|
||||
else:
|
||||
pos_cost = 0.499
|
||||
# pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5
|
||||
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
|
||||
return semantic_cost + char_cost + pos_cost
|
||||
|
||||
def align(self,
|
||||
src: List[Tuple],
|
||||
tgt: List[Tuple]):
|
||||
"""
|
||||
Based on ERRANT's alignment
|
||||
基于改进的动态规划算法,为原句子的每个字打上编辑标签,以便使它能够成功转换为目标句子。
|
||||
编辑操作类别:
|
||||
1) M:Match,即KEEP,即当前字保持不变
|
||||
2) D:Delete,删除,即当前字需要被删除
|
||||
3) I:Insert,插入,即当前字需要被插入
|
||||
4) T:Transposition,移位操作,即涉及到词序问题
|
||||
"""
|
||||
cost_matrix = np.zeros((len(src) + 1, len(tgt) + 1)) # 编辑cost矩阵
|
||||
oper_matrix = np.full(
|
||||
(len(src) + 1, len(tgt) + 1), "O", dtype=object
|
||||
) # 操作矩阵
|
||||
# Fill in the edges
|
||||
for i in range(1, len(src) + 1):
|
||||
cost_matrix[i][0] = cost_matrix[i - 1][0] + 1
|
||||
oper_matrix[i][0] = ["D"]
|
||||
for j in range(1, len(tgt) + 1):
|
||||
cost_matrix[0][j] = cost_matrix[0][j - 1] + 1
|
||||
oper_matrix[0][j] = ["I"]
|
||||
|
||||
# Loop through the cost matrix
|
||||
for i in range(len(src)):
|
||||
for j in range(len(tgt)):
|
||||
# Matches
|
||||
if src[i][0] == tgt[j][0]: # 如果两个字相等,则匹配成功(Match),编辑距离为0
|
||||
cost_matrix[i + 1][j + 1] = cost_matrix[i][j]
|
||||
oper_matrix[i + 1][j + 1] = ["M"]
|
||||
# Non-matches
|
||||
else:
|
||||
del_cost = cost_matrix[i][j + 1] + self.deletion_cost # 由删除动作得到的总cost
|
||||
ins_cost = cost_matrix[i + 1][j] + self.insertion_cost # 由插入动作得到的总cost
|
||||
sub_cost = cost_matrix[i][j] + self.get_sub_cost(
|
||||
src[i], tgt[j]
|
||||
) # 由替换动作得到的总cost
|
||||
# Calculate transposition cost
|
||||
# 计算移位操作的总cost
|
||||
trans_cost = float("inf")
|
||||
k = 1
|
||||
while (
|
||||
i - k >= 0
|
||||
and j - k >= 0
|
||||
and cost_matrix[i - k + 1][j - k + 1]
|
||||
!= cost_matrix[i - k][j - k]
|
||||
):
|
||||
p1 = sorted([a[0] for a in src][i - k: i + 1])
|
||||
p2 = sorted([b[0] for b in tgt][j - k: j + 1])
|
||||
if p1 == p2:
|
||||
trans_cost = cost_matrix[i - k][j - k] + k
|
||||
break
|
||||
k += 1
|
||||
|
||||
costs = [trans_cost, sub_cost, ins_cost, del_cost]
|
||||
ind = costs.index(min(costs))
|
||||
cost_matrix[i + 1][j + 1] = costs[ind]
|
||||
# ind = costs.index(costs[ind], ind+1)
|
||||
for idx, cost in enumerate(costs):
|
||||
if cost == costs[ind]:
|
||||
if idx == 0:
|
||||
if oper_matrix[i + 1][j + 1] == "O":
|
||||
oper_matrix[i + 1][j + 1] = ["T" + str(k + 1)]
|
||||
else:
|
||||
oper_matrix[i + 1][j + 1].append("T" + str(k + 1))
|
||||
elif idx == 1:
|
||||
if oper_matrix[i + 1][j + 1] == "O":
|
||||
oper_matrix[i + 1][j + 1] = ["S"]
|
||||
else:
|
||||
oper_matrix[i + 1][j + 1].append("S")
|
||||
elif idx == 2:
|
||||
if oper_matrix[i + 1][j + 1] == "O":
|
||||
oper_matrix[i + 1][j + 1] = ["I"]
|
||||
else:
|
||||
oper_matrix[i + 1][j + 1].append("I")
|
||||
else:
|
||||
if oper_matrix[i + 1][j + 1] == "O":
|
||||
oper_matrix[i + 1][j + 1] = ["D"]
|
||||
else:
|
||||
oper_matrix[i + 1][j + 1].append("D")
|
||||
return cost_matrix, oper_matrix
|
||||
|
||||
def _dfs(self, i, j, align_seq_now, oper_matrix, strategy="all"):
|
||||
"""
|
||||
深度优先遍历,获取最小编辑距离相同的所有序列
|
||||
"""
|
||||
if i + j == 0:
|
||||
self.align_seqs.append(align_seq_now)
|
||||
else:
|
||||
ops = oper_matrix[i][j] # 可以类比成搜索一棵树从根结点到叶子结点的所有路径
|
||||
if strategy != "all": ops = ops[:1]
|
||||
for op in ops:
|
||||
if op in {"M", "S"}:
|
||||
self._dfs(i - 1, j - 1, align_seq_now + [(op, i - 1, i, j - 1, j)], oper_matrix, strategy)
|
||||
elif op == "D":
|
||||
self._dfs(i - 1, j, align_seq_now + [(op, i - 1, i, j, j)], oper_matrix, strategy)
|
||||
elif op == "I":
|
||||
self._dfs(i, j - 1, align_seq_now + [(op, i, i, j - 1, j)], oper_matrix, strategy)
|
||||
else:
|
||||
k = int(op[1:])
|
||||
self._dfs(i - k, j - k, align_seq_now + [(op, i - k, i, j - k, j)], oper_matrix, strategy)
|
||||
|
||||
def get_cheapest_align_seq(self, oper_matrix):
|
||||
"""
|
||||
回溯获得编辑距离最小的编辑序列
|
||||
"""
|
||||
self.align_seqs = []
|
||||
i = oper_matrix.shape[0] - 1
|
||||
j = oper_matrix.shape[1] - 1
|
||||
if abs(i - j) > 10:
|
||||
self._dfs(i, j , [], oper_matrix, "first")
|
||||
else:
|
||||
self._dfs(i, j , [], oper_matrix, "all")
|
||||
final_align_seqs = [seq[::-1] for seq in self.align_seqs]
|
||||
return final_align_seqs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = Tokenizer("word")
|
||||
semantic_dict, semantic_class = read_cilin()
|
||||
confusion_dict = read_confusion()
|
||||
alignment = Alignment(semantic_dict, confusion_dict)
|
||||
sents = ["首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。".replace(" ", ""), "首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。".replace(" ", "")]
|
||||
src, tgt = tokenizer(sents)
|
||||
alignment(src, tgt, verbose=True)
@ -1,76 +1,76 @@
from typing import List, Tuple
|
||||
from modules.alignment import read_cilin, read_confusion, Alignment
|
||||
from modules.merger import Merger
|
||||
from modules.classifier import Classifier
|
||||
|
||||
class Annotator:
|
||||
def __init__(self,
|
||||
align: Alignment,
|
||||
merger: Merger,
|
||||
classifier: Classifier,
|
||||
granularity: str = "word",
|
||||
strategy: str = "first"):
|
||||
self.align = align
|
||||
self.merger = merger
|
||||
self.classifier = classifier
|
||||
self.granularity = granularity
|
||||
self.strategy = strategy
|
||||
|
||||
@classmethod
|
||||
def create_default(cls, granularity: str = "word", strategy: str = "first"):
|
||||
"""
|
||||
Default parameters used in the paper
|
||||
"""
|
||||
semantic_dict, semantic_class = read_cilin()
|
||||
confusion_dict = read_confusion()
|
||||
align = Alignment(semantic_dict, confusion_dict, granularity)
|
||||
merger = Merger(granularity)
|
||||
classifier = Classifier(granularity)
|
||||
return cls(align, merger, classifier, granularity, strategy)
|
||||
|
||||
def __call__(self,
|
||||
src: List[Tuple],
|
||||
tgt: List[Tuple],
|
||||
annotator_id: int = 0,
|
||||
verbose: bool = False):
|
||||
"""
|
||||
Align sentences and annotate them with error type information
|
||||
"""
|
||||
src_tokens = [x[0] for x in src]
|
||||
tgt_tokens = [x[0] for x in tgt]
|
||||
src_str = "".join(src_tokens)
|
||||
tgt_str = "".join(tgt_tokens)
|
||||
# convert to text form
|
||||
annotations_out = ["S " + " ".join(src_tokens) + "\n"]
|
||||
if tgt_str == "没有错误" or src_str == tgt_str: # Error Free Case
|
||||
annotations_out.append(f"T{annotator_id} 没有错误\n")
|
||||
cors = [tgt_str]
|
||||
op, toks, inds = "noop", "-NONE-", (-1, -1)
|
||||
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
|
||||
annotations_out.append(a_str)
|
||||
elif tgt_str == "无法标注": # Not Annotatable Case
|
||||
annotations_out.append(f"T{annotator_id} 无法标注\n")
|
||||
cors = [tgt_str]
|
||||
op, toks, inds = "NA", "-NONE-", (-1, -1)
|
||||
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
|
||||
annotations_out.append(a_str)
|
||||
else: # Other
|
||||
align_objs = self.align(src, tgt)
|
||||
edit_objs = []
|
||||
align_idx = 0
|
||||
if self.strategy == "first":
|
||||
align_objs = align_objs[:1]
|
||||
for align_obj in align_objs:
|
||||
edits = self.merger(align_obj, src, tgt, verbose)
|
||||
if edits not in edit_objs:
|
||||
edit_objs.append(edits)
|
||||
annotations_out.append(f"T{annotator_id}-A{align_idx} " + " ".join(tgt_tokens) + "\n")
|
||||
align_idx += 1
|
||||
cors = self.classifier(src, tgt, edits, verbose)
|
||||
# annotations_out = []
|
||||
for cor in cors:
|
||||
op, toks, inds = cor.op, cor.toks, cor.inds
|
||||
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
|
||||
annotations_out.append(a_str)
|
||||
annotations_out.append("\n")
|
||||
return annotations_out, cors
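
# Illustrative note (not part of the synced file): each "A ..." line built above
# follows an M2-style layout. For example, a substitution classified as S:SPELL
# covering source span (1, 2) with target token "气" and annotator_id 0 would be
# emitted as:
#   A 1 2|||S:SPELL|||气|||REQUIRED|||-NONE-|||0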
|
||||
from typing import List, Tuple
|
||||
from modules.alignment import read_cilin, read_confusion, Alignment
|
||||
from modules.merger import Merger
|
||||
from modules.classifier import Classifier
|
||||
|
||||
class Annotator:
|
||||
def __init__(self,
|
||||
align: Alignment,
|
||||
merger: Merger,
|
||||
classifier: Classifier,
|
||||
granularity: str = "word",
|
||||
strategy: str = "first"):
|
||||
self.align = align
|
||||
self.merger = merger
|
||||
self.classifier = classifier
|
||||
self.granularity = granularity
|
||||
self.strategy = strategy
|
||||
|
||||
@classmethod
|
||||
def create_default(cls, granularity: str = "word", strategy: str = "first"):
|
||||
"""
|
||||
Default parameters used in the paper
|
||||
"""
|
||||
semantic_dict, semantic_class = read_cilin()
|
||||
confusion_dict = read_confusion()
|
||||
align = Alignment(semantic_dict, confusion_dict, granularity)
|
||||
merger = Merger(granularity)
|
||||
classifier = Classifier(granularity)
|
||||
return cls(align, merger, classifier, granularity, strategy)
|
||||
|
||||
def __call__(self,
|
||||
src: List[Tuple],
|
||||
tgt: List[Tuple],
|
||||
annotator_id: int = 0,
|
||||
verbose: bool = False):
|
||||
"""
|
||||
Align sentences and annotate them with error type information
|
||||
"""
|
||||
src_tokens = [x[0] for x in src]
|
||||
tgt_tokens = [x[0] for x in tgt]
|
||||
src_str = "".join(src_tokens)
|
||||
tgt_str = "".join(tgt_tokens)
|
||||
# convert to text form
|
||||
annotations_out = ["S " + " ".join(src_tokens) + "\n"]
|
||||
if tgt_str == "没有错误" or src_str == tgt_str: # Error Free Case
|
||||
annotations_out.append(f"T{annotator_id} 没有错误\n")
|
||||
cors = [tgt_str]
|
||||
op, toks, inds = "noop", "-NONE-", (-1, -1)
|
||||
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
|
||||
annotations_out.append(a_str)
|
||||
elif tgt_str == "无法标注": # Not Annotatable Case
|
||||
annotations_out.append(f"T{annotator_id} 无法标注\n")
|
||||
cors = [tgt_str]
|
||||
op, toks, inds = "NA", "-NONE-", (-1, -1)
|
||||
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
|
||||
annotations_out.append(a_str)
|
||||
else: # Other
|
||||
align_objs = self.align(src, tgt)
|
||||
edit_objs = []
|
||||
align_idx = 0
|
||||
if self.strategy == "first":
|
||||
align_objs = align_objs[:1]
|
||||
for align_obj in align_objs:
|
||||
edits = self.merger(align_obj, src, tgt, verbose)
|
||||
if edits not in edit_objs:
|
||||
edit_objs.append(edits)
|
||||
annotations_out.append(f"T{annotator_id}-A{align_idx} " + " ".join(tgt_tokens) + "\n")
|
||||
align_idx += 1
|
||||
cors = self.classifier(src, tgt, edits, verbose)
|
||||
# annotations_out = []
|
||||
for cor in cors:
|
||||
op, toks, inds = cor.op, cor.toks, cor.inds
|
||||
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
|
||||
annotations_out.append(a_str)
|
||||
annotations_out.append("\n")
|
||||
return annotations_out, cors
@ -1,151 +1,151 @@
from char_smi import CharFuncs
|
||||
from collections import namedtuple
|
||||
from pypinyin import pinyin, Style
|
||||
import os
|
||||
Correction = namedtuple(
|
||||
"Correction",
|
||||
[
|
||||
"op",
|
||||
"toks",
|
||||
"inds",
|
||||
],
|
||||
)
|
||||
file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
char_smi = CharFuncs(os.path.join(file_path.replace("modules", ""), 'data/char_meta.txt'))
|
||||
|
||||
def check_spell_error(src_span: str,
|
||||
tgt_span: str,
|
||||
threshold: float = 0.8) -> bool:
|
||||
if len(src_span) != len(tgt_span):
|
||||
return False
|
||||
src_chars = [ch for ch in src_span]
|
||||
tgt_chars = [ch for ch in tgt_span]
|
||||
if sorted(src_chars) == sorted(tgt_chars): # 词内部字符异位
|
||||
return True
|
||||
for src_char, tgt_char in zip(src_chars, tgt_chars):
|
||||
if src_char != tgt_char:
|
||||
if src_char not in char_smi.data or tgt_char not in char_smi.data:
|
||||
return False
|
||||
v_sim = char_smi.shape_similarity(src_char, tgt_char)
|
||||
p_sim = char_smi.pronunciation_similarity(src_char, tgt_char)
|
||||
if v_sim + p_sim < threshold and not (
|
||||
set(pinyin(src_char, style=Style.NORMAL, heteronym=True)[0]) & set(pinyin(tgt_char, style=Style.NORMAL, heteronym=True)[0])):
|
||||
return False
|
||||
return True
|
||||
|
||||
class Classifier:
|
||||
"""
|
||||
错误类型分类器
|
||||
"""
|
||||
def __init__(self,
|
||||
granularity: str = "word"):
|
||||
|
||||
self.granularity = granularity
|
||||
|
||||
@staticmethod
|
||||
def get_pos_type(pos):
|
||||
if pos in {"n", "nd"}:
|
||||
return "NOUN"
|
||||
if pos in {"nh", "ni", "nl", "ns", "nt", "nz"}:
|
||||
return "NOUN-NE"
|
||||
if pos in {"v"}:
|
||||
return "VERB"
|
||||
if pos in {"a", "b"}:
|
||||
return "ADJ"
|
||||
if pos in {"c"}:
|
||||
return "CONJ"
|
||||
if pos in {"r"}:
|
||||
return "PRON"
|
||||
if pos in {"d"}:
|
||||
return "ADV"
|
||||
if pos in {"u"}:
|
||||
return "AUX"
|
||||
# if pos in {"k"}: # TODO 后缀词比例太少,暂且分入其它
|
||||
# return "SUFFIX"
|
||||
if pos in {"m"}:
|
||||
return "NUM"
|
||||
if pos in {"p"}:
|
||||
return "PREP"
|
||||
if pos in {"q"}:
|
||||
return "QUAN"
|
||||
if pos in {"wp"}:
|
||||
return "PUNCT"
|
||||
return "OTHER"
|
||||
|
||||
def __call__(self,
|
||||
src,
|
||||
tgt,
|
||||
edits,
|
||||
verbose: bool = False):
|
||||
"""
|
||||
为编辑操作划分错误类型
|
||||
:param src: 错误句子信息
|
||||
:param tgt: 正确句子信息
|
||||
:param edits: 编辑操作
|
||||
:param verbose: 是否打印信息
|
||||
:return: 划分完错误类型后的编辑操作
|
||||
"""
|
||||
results = []
|
||||
src_tokens = [x[0] for x in src]
|
||||
tgt_tokens = [x[0] for x in tgt]
|
||||
for edit in edits:
|
||||
error_type = edit[0]
|
||||
src_span = " ".join(src_tokens[edit[1]: edit[2]])
|
||||
tgt_span = " ".join(tgt_tokens[edit[3]: edit[4]])
|
||||
# print(tgt_span)
|
||||
cor = None
|
||||
if error_type[0] == "T":
|
||||
cor = Correction("W", tgt_span, (edit[1], edit[2]))
|
||||
elif error_type[0] == "D":
|
||||
if self.granularity == "word": # 词级别可以细分错误类型
|
||||
if edit[2] - edit[1] > 1: # 词组冗余暂时分为OTHER
|
||||
cor = Correction("R:OTHER", "-NONE-", (edit[1], edit[2]))
|
||||
else:
|
||||
pos = self.get_pos_type(src[edit[1]][1])
|
||||
pos = "NOUN" if pos == "NOUN-NE" else pos
|
||||
pos = "MC" if tgt_span == "[缺失成分]" else pos
|
||||
cor = Correction("R:{:s}".format(pos), "-NONE-", (edit[1], edit[2]))
|
||||
else: # 字级别可以只需要根据操作划分类型即可
|
||||
cor = Correction("R", "-NONE-", (edit[1], edit[2]))
|
||||
elif error_type[0] == "I":
|
||||
if self.granularity == "word": # 词级别可以细分错误类型
|
||||
if edit[4] - edit[3] > 1: # 词组丢失暂时分为OTHER
|
||||
cor = Correction("M:OTHER", tgt_span, (edit[1], edit[2]))
|
||||
else:
|
||||
pos = self.get_pos_type(tgt[edit[3]][1])
|
||||
pos = "NOUN" if pos == "NOUN-NE" else pos
|
||||
pos = "MC" if tgt_span == "[缺失成分]" else pos
|
||||
cor = Correction("M:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
|
||||
else: # 字级别可以只需要根据操作划分类型即可
|
||||
cor = Correction("M", tgt_span, (edit[1], edit[2]))
|
||||
elif error_type[0] == "S":
|
||||
if self.granularity == "word": # 词级别可以细分错误类型
|
||||
if check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")):
|
||||
cor = Correction("S:SPELL", tgt_span, (edit[1], edit[2]))
|
||||
# Todo 暂且不单独区分命名实体拼写错误
|
||||
# if edit[4] - edit[3] > 1:
|
||||
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
|
||||
# else:
|
||||
# pos = self.get_pos_type(tgt[edit[3]][1])
|
||||
# if pos == "NOUN-NE": # 命名实体拼写有误
|
||||
# cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2]))
|
||||
# else: # 普通词语拼写有误
|
||||
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
|
||||
else:
|
||||
if edit[4] - edit[3] > 1: # 词组被替换暂时分为OTHER
|
||||
cor = Correction("S:OTHER", tgt_span, (edit[1], edit[2]))
|
||||
else:
|
||||
pos = self.get_pos_type(tgt[edit[3]][1])
|
||||
pos = "NOUN" if pos == "NOUN-NE" else pos
|
||||
pos = "MC" if tgt_span == "[缺失成分]" else pos
|
||||
cor = Correction("S:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
|
||||
else: # 字级别可以只需要根据操作划分类型即可
|
||||
cor = Correction("S", tgt_span, (edit[1], edit[2]))
|
||||
results.append(cor)
|
||||
if verbose:
|
||||
print("========== Corrections ==========")
|
||||
for cor in results:
|
||||
print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks))
|
||||
return results
|
||||
|
||||
# print(pinyin("朝", style=Style.NORMAL))
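
# Illustrative sketch (not part of the synced file): mapping a few POS tags to the
# coarse types used by the classifier above; treating these as LTP-style tags is an
# assumption based on the tag set in get_pos_type.
assert Classifier.get_pos_type("v") == "VERB"
assert Classifier.get_pos_type("ns") == "NOUN-NE"
assert Classifier.get_pos_type("x") == "OTHER"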
|
||||
from char_smi import CharFuncs
|
||||
from collections import namedtuple
|
||||
from pypinyin import pinyin, Style
|
||||
import os
|
||||
Correction = namedtuple(
|
||||
"Correction",
|
||||
[
|
||||
"op",
|
||||
"toks",
|
||||
"inds",
|
||||
],
|
||||
)
|
||||
file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
char_smi = CharFuncs(os.path.join(file_path.replace("modules", ""), 'data/char_meta.txt'))
|
||||
|
||||
def check_spell_error(src_span: str,
|
||||
tgt_span: str,
|
||||
threshold: float = 0.8) -> bool:
|
||||
if len(src_span) != len(tgt_span):
|
||||
return False
|
||||
src_chars = [ch for ch in src_span]
|
||||
tgt_chars = [ch for ch in tgt_span]
|
||||
if sorted(src_chars) == sorted(tgt_chars): # 词内部字符异位
|
||||
return True
|
||||
for src_char, tgt_char in zip(src_chars, tgt_chars):
|
||||
if src_char != tgt_char:
|
||||
if src_char not in char_smi.data or tgt_char not in char_smi.data:
|
||||
return False
|
||||
v_sim = char_smi.shape_similarity(src_char, tgt_char)
|
||||
p_sim = char_smi.pronunciation_similarity(src_char, tgt_char)
|
||||
if v_sim + p_sim < threshold and not (
|
||||
set(pinyin(src_char, style=Style.NORMAL, heteronym=True)[0]) & set(pinyin(tgt_char, style=Style.NORMAL, heteronym=True)[0])):
|
||||
return False
|
||||
return True
|
||||
|
||||
class Classifier:
|
||||
"""
|
||||
错误类型分类器
|
||||
"""
|
||||
def __init__(self,
|
||||
granularity: str = "word"):
|
||||
|
||||
self.granularity = granularity
|
||||
|
||||
@staticmethod
|
||||
def get_pos_type(pos):
|
||||
if pos in {"n", "nd"}:
|
||||
return "NOUN"
|
||||
if pos in {"nh", "ni", "nl", "ns", "nt", "nz"}:
|
||||
return "NOUN-NE"
|
||||
if pos in {"v"}:
|
||||
return "VERB"
|
||||
if pos in {"a", "b"}:
|
||||
return "ADJ"
|
||||
if pos in {"c"}:
|
||||
return "CONJ"
|
||||
if pos in {"r"}:
|
||||
return "PRON"
|
||||
if pos in {"d"}:
|
||||
return "ADV"
|
||||
if pos in {"u"}:
|
||||
return "AUX"
|
||||
# if pos in {"k"}: # TODO 后缀词比例太少,暂且分入其它
|
||||
# return "SUFFIX"
|
||||
if pos in {"m"}:
|
||||
return "NUM"
|
||||
if pos in {"p"}:
|
||||
return "PREP"
|
||||
if pos in {"q"}:
|
||||
return "QUAN"
|
||||
if pos in {"wp"}:
|
||||
return "PUNCT"
|
||||
return "OTHER"
|
||||
|
||||
def __call__(self,
|
||||
src,
|
||||
tgt,
|
||||
edits,
|
||||
verbose: bool = False):
|
||||
"""
|
||||
为编辑操作划分错误类型
|
||||
:param src: 错误句子信息
|
||||
:param tgt: 正确句子信息
|
||||
:param edits: 编辑操作
|
||||
:param verbose: 是否打印信息
|
||||
:return: 划分完错误类型后的编辑操作
|
||||
"""
|
||||
results = []
|
||||
src_tokens = [x[0] for x in src]
|
||||
tgt_tokens = [x[0] for x in tgt]
|
||||
for edit in edits:
|
||||
error_type = edit[0]
|
||||
src_span = " ".join(src_tokens[edit[1]: edit[2]])
|
||||
tgt_span = " ".join(tgt_tokens[edit[3]: edit[4]])
|
||||
# print(tgt_span)
|
||||
cor = None
|
||||
if error_type[0] == "T":
|
||||
cor = Correction("W", tgt_span, (edit[1], edit[2]))
|
||||
elif error_type[0] == "D":
|
||||
if self.granularity == "word": # 词级别可以细分错误类型
|
||||
if edit[2] - edit[1] > 1: # 词组冗余暂时分为OTHER
|
||||
cor = Correction("R:OTHER", "-NONE-", (edit[1], edit[2]))
|
||||
else:
|
||||
pos = self.get_pos_type(src[edit[1]][1])
|
||||
pos = "NOUN" if pos == "NOUN-NE" else pos
|
||||
pos = "MC" if tgt_span == "[缺失成分]" else pos
|
||||
cor = Correction("R:{:s}".format(pos), "-NONE-", (edit[1], edit[2]))
|
||||
else: # 字级别可以只需要根据操作划分类型即可
|
||||
cor = Correction("R", "-NONE-", (edit[1], edit[2]))
|
||||
elif error_type[0] == "I":
|
||||
if self.granularity == "word": # 词级别可以细分错误类型
|
||||
if edit[4] - edit[3] > 1: # 词组丢失暂时分为OTHER
|
||||
cor = Correction("M:OTHER", tgt_span, (edit[1], edit[2]))
|
||||
else:
|
||||
pos = self.get_pos_type(tgt[edit[3]][1])
|
||||
pos = "NOUN" if pos == "NOUN-NE" else pos
|
||||
pos = "MC" if tgt_span == "[缺失成分]" else pos
|
||||
cor = Correction("M:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
|
||||
else: # 字级别可以只需要根据操作划分类型即可
|
||||
cor = Correction("M", tgt_span, (edit[1], edit[2]))
|
||||
elif error_type[0] == "S":
|
||||
if self.granularity == "word": # 词级别可以细分错误类型
|
||||
if check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")):
|
||||
cor = Correction("S:SPELL", tgt_span, (edit[1], edit[2]))
|
||||
# Todo 暂且不单独区分命名实体拼写错误
|
||||
# if edit[4] - edit[3] > 1:
|
||||
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
|
||||
# else:
|
||||
# pos = self.get_pos_type(tgt[edit[3]][1])
|
||||
# if pos == "NOUN-NE": # 命名实体拼写有误
|
||||
# cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2]))
|
||||
# else: # 普通词语拼写有误
|
||||
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
|
||||
else:
|
||||
if edit[4] - edit[3] > 1: # 词组被替换暂时分为OTHER
|
||||
cor = Correction("S:OTHER", tgt_span, (edit[1], edit[2]))
|
||||
else:
|
||||
pos = self.get_pos_type(tgt[edit[3]][1])
|
||||
pos = "NOUN" if pos == "NOUN-NE" else pos
|
||||
pos = "MC" if tgt_span == "[缺失成分]" else pos
|
||||
cor = Correction("S:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
|
||||
else: # 字级别可以只需要根据操作划分类型即可
|
||||
cor = Correction("S", tgt_span, (edit[1], edit[2]))
|
||||
results.append(cor)
|
||||
if verbose:
|
||||
print("========== Corrections ==========")
|
||||
for cor in results:
|
||||
print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks))
|
||||
return results
|
||||
|
||||
# print(pinyin("朝", style=Style.NORMAL))
@ -1,273 +1,273 @@
from itertools import groupby
|
||||
from string import punctuation
|
||||
from typing import List
|
||||
from modules.tokenizer import Tokenizer
|
||||
from modules.alignment import Alignment, read_cilin, read_confusion
|
||||
import Levenshtein
|
||||
|
||||
class Merger:
|
||||
"""
|
||||
合并编辑操作,从Token-Level转换为Span-Level
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
granularity: str = "word",
|
||||
merge: bool = False):
|
||||
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧."
|
||||
self.punctuation = punctuation + chinese_punct
|
||||
self.not_merge_token = [punct for punct in self.punctuation]
|
||||
self.granularity = granularity
|
||||
self.merge = merge
|
||||
|
||||
@staticmethod
|
||||
def _merge_edits(seq, tag="X"):
|
||||
if seq:
|
||||
return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
|
||||
else:
|
||||
return seq
|
||||
|
||||
@staticmethod
|
||||
def _check_revolve(span_a, span_b):
|
||||
span_a = span_a + span_a
|
||||
return span_b in span_a
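
# Illustrative note (not part of the synced file): _check_revolve treats span_b as
# a rotation of span_a, e.g. _check_revolve("冬阴功", "阴功冬") is True because
# "阴功冬" occurs inside "冬阴功冬阴功", whereas _check_revolve("冬阴功", "功阴冬")
# is False.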
|
||||
|
||||
def _process_seq(self, seq, src_tokens, tgt_tokens):
|
||||
if len(seq) <= 1:
|
||||
return seq
|
||||
|
||||
ops = [op[0] for op in seq]
|
||||
if set(ops) == {"D"} or set(ops) == {"I"}:
|
||||
return self._merge_edits(seq, set(ops).pop())
|
||||
|
||||
if set(ops) == {"D", "I"} or set(ops) == {"I", "D"}:
|
||||
# do not merge this pattern_from_qua.txt
|
||||
return seq
|
||||
|
||||
if set(ops) == {"S"}:
|
||||
if self.granularity == "word":
|
||||
return seq
|
||||
else:
|
||||
return self._merge_edits(seq, "S")
|
||||
|
||||
if set(ops) == {"M"}:
|
||||
return self._merge_edits(seq, "M")
|
||||
|
||||
return self._merge_edits(seq, "S")
|
||||
|
||||
def __call__(self,
|
||||
align_obj,
|
||||
src: List,
|
||||
tgt: List,
|
||||
verbose: bool = False):
|
||||
"""
|
||||
Based on ERRANT's merge, adapted for Chinese
|
||||
"""
|
||||
src_tokens = [x[0] for x in src]
|
||||
tgt_tokens = [x[0] for x in tgt]
|
||||
edits = []
|
||||
# Split alignment into groups of M, T and rest. (T has a number after it)
|
||||
# Todo 一旦插入、删除、替换的对象中含有标点,那么不与其它编辑合并
|
||||
# Todo 缺失成分标签也不与其它编辑合并
|
||||
for op, group in groupby(
|
||||
align_obj,
|
||||
lambda x: x[0][0] if x[0][0] in {"M", "T"} else False,
|
||||
):
|
||||
group = list(group)
|
||||
# T is always split TODO: Evaluate this
|
||||
if op == "T":
|
||||
for seq in group:
|
||||
edits.append(seq)
|
||||
# Process D, I and S subsequence
|
||||
else:
|
||||
# Turn the processed sequence into edits
|
||||
processed = self._process_seq(group, src_tokens, tgt_tokens)
|
||||
for seq in processed:
|
||||
edits.append(seq)
|
||||
|
||||
filtered_edits = []
|
||||
i = 0
|
||||
while i < len(edits):
|
||||
e1 = edits[i][0][0]
|
||||
|
||||
if i < len(edits) - 2:
|
||||
e2 = edits[i + 1][0][0]
|
||||
e3 = edits[i + 2][0][0]
|
||||
|
||||
# Find "S M S" patterns
|
||||
# Ex:
|
||||
# S M S
|
||||
# 冬阴功 对 外国人
|
||||
# 外国人 对 冬阴功
|
||||
if e1 == "S" and e2 == "M" and e3 == "S":
|
||||
w1 = "".join(src_tokens[edits[i][1]: edits[i][2]])
|
||||
w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]])
|
||||
w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]])
|
||||
w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]])
|
||||
if min([len(w1), len(w2), len(w3), len(w4)]) == 1:
|
||||
if w1 == w4 and w2 == w3:
|
||||
group = [edits[i], edits[i + 1], edits[i + 2]]
|
||||
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
|
||||
for seq in processed:
|
||||
filtered_edits.append(seq)
|
||||
i += 3
|
||||
else:
|
||||
filtered_edits.append(edits[i])
|
||||
i += 1
|
||||
else:
|
||||
if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1:
|
||||
group = [edits[i], edits[i + 1], edits[i + 2]]
|
||||
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
|
||||
for seq in processed:
|
||||
filtered_edits.append(seq)
|
||||
i += 3
|
||||
else:
|
||||
filtered_edits.append(edits[i])
|
||||
i += 1
|
||||
# Find "D M I" or "I M D" patterns
|
||||
# Ex:
|
||||
# D M I
|
||||
# 旅游 去 陌生 的 地方
|
||||
# 去 陌生 的 地方 旅游
|
||||
elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"):
|
||||
if e1 == "D":
|
||||
delete_token = src_tokens[edits[i][1]: edits[i][2]]
|
||||
insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]
|
||||
else:
|
||||
delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]]
|
||||
insert_token = tgt_tokens[edits[i][3]: edits[i][4]]
|
||||
a, b = "".join(delete_token), "".join(insert_token)
|
||||
if len(a) < len(b):
|
||||
a, b = b, a
|
||||
if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1:
|
||||
if len(b) == 1:
|
||||
if a == b:
|
||||
group = [edits[i], edits[i + 1], edits[i + 2]]
|
||||
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
|
||||
for seq in processed:
|
||||
filtered_edits.append(seq)
|
||||
i += 3
|
||||
else:
|
||||
filtered_edits.append(edits[i])
|
||||
i += 1
|
||||
else:
|
||||
if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)):
|
||||
group = [edits[i], edits[i + 1], edits[i + 2]]
|
||||
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
|
||||
for seq in processed:
|
||||
filtered_edits.append(seq)
|
||||
i += 3
|
||||
else:
|
||||
filtered_edits.append(edits[i])
|
||||
i += 1
|
||||
else:
|
||||
filtered_edits.append(edits[i])
|
||||
i += 1
|
||||
else:
|
||||
if e1 != "M":
|
||||
filtered_edits.append(edits[i])
|
||||
i += 1
|
||||
else:
|
||||
if e1 != "M":
|
||||
filtered_edits.append(edits[i])
|
||||
i += 1
|
||||
# In rare cases with word-level tokenization, the following error can occur:
|
||||
# M D S M
|
||||
# 有 時 住 上層
|
||||
# 有 時住 上層
|
||||
# Which results in S: 時住 --> 時住
|
||||
# We need to filter this case out
|
||||
second_filter = []
|
||||
for edit in filtered_edits: # 避免因为分词错误导致的mismatch现象
|
||||
span1 = "".join(src_tokens[edit[1] : edit[2]])
|
||||
span2 = "".join(tgt_tokens[edit[3] : edit[4]])
|
||||
|
||||
if span1 != span2:
|
||||
if edit[0] == "S":
|
||||
b = True
|
||||
# In rare cases with word-level tokenization, the following error can occur:
|
||||
# S I I M
|
||||
# 负责任 老师
|
||||
# 负 责任 的 老师
|
||||
# Which results in S: 负责任 --> 负 责任 的
|
||||
# We need to convert this edit to I: --> 的
|
||||
|
||||
# 首部有重叠
|
||||
common_str = ""
|
||||
tmp_new_start_1 = edit[1]
|
||||
for i in range(edit[1], edit[2]):
|
||||
if not span2.startswith(common_str + src_tokens[i]):
|
||||
break
|
||||
common_str += src_tokens[i]
|
||||
tmp_new_start_1 = i + 1
|
||||
new_start_1, new_start_2 = edit[1], edit[3]
|
||||
if common_str:
|
||||
tmp_str = ""
|
||||
for i in range(edit[3], edit[4]):
|
||||
tmp_str += tgt_tokens[i]
|
||||
if tmp_str == common_str:
|
||||
new_start_1, new_start_2 = tmp_new_start_1, i + 1
|
||||
# second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4]))
|
||||
b = False
|
||||
break
|
||||
elif len(tmp_str) > len(common_str):
|
||||
break
|
||||
# 尾部有重叠
|
||||
common_str = ""
|
||||
new_end_1, new_end_2 = edit[2], edit[4]
|
||||
tmp_new_end_1 = edit[2]
|
||||
for i in reversed(range(new_start_1, edit[2])):
|
||||
if not span2.endswith(src_tokens[i] + common_str):
|
||||
break
|
||||
common_str = src_tokens[i] + common_str
|
||||
tmp_new_end_1 = i
|
||||
if common_str:
|
||||
tmp_str = ""
|
||||
for i in reversed(range(new_start_2, edit[4])):
|
||||
tmp_str = tgt_tokens[i] + tmp_str
|
||||
if tmp_str == common_str:
|
||||
new_end_1, new_end_2 = tmp_new_end_1, i
|
||||
b = False
|
||||
break
|
||||
elif len(tmp_str) > len(common_str):
|
||||
break
|
||||
if b:
|
||||
second_filter.append(edit)
|
||||
else:
|
||||
if new_start_1 == new_end_1:
|
||||
new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2)
|
||||
elif new_start_2 == new_end_2:
|
||||
new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2)
|
||||
else:
|
||||
new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2)
|
||||
second_filter.append(new_edit)
|
||||
else:
|
||||
second_filter.append(edit)
|
||||
if verbose:
|
||||
print("========== Parallels ==========")
|
||||
print("".join(src_tokens))
|
||||
print("".join(tgt_tokens))
|
||||
print("========== Results ==========")
|
||||
for edit in second_filter:
|
||||
op = edit[0]
|
||||
s = " ".join(src_tokens[edit[1]: edit[2]])
|
||||
t = " ".join(tgt_tokens[edit[3]: edit[4]])
|
||||
print(f"{op}:\t{s}\t-->\t{t}")
|
||||
print("========== Infos ==========")
|
||||
print(str(src))
|
||||
print(str(tgt))
|
||||
return second_filter
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = Tokenizer("char")
|
||||
semantic_dict, semantic_class = read_cilin()
|
||||
confusion_dict = read_confusion()
|
||||
alignment = Alignment(semantic_dict, confusion_dict)
|
||||
sents = [
|
||||
"所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace(
|
||||
" ", ""),
|
||||
"所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace(
|
||||
" ", "")]
|
||||
src, tgt = tokenizer(sents)
|
||||
align_obj = alignment(src, tgt)
|
||||
m = Merger()
    m(align_obj, src, tgt, verbose=True)
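The two static helpers on Merger are easy to sanity-check in isolation; the literal edit tuples below are illustrative only and are not taken from a real alignment:

    # A rotation of the same characters counts as a transposition candidate.
    assert Merger._check_revolve("旅游", "游旅")
    assert not Merger._check_revolve("旅游", "陌生")
    # Adjacent token-level edits collapse into a single span-level tuple of the
    # form (tag, src_start, src_end, tgt_start, tgt_end).
    assert Merger._merge_edits([("D", 0, 1, 0, 0), ("D", 1, 2, 0, 0)], "D") == [("D", 0, 2, 0, 0)]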
@ -1,92 +1,92 @@
from ltp import LTP
from typing import List
from pypinyin import pinyin, Style, lazy_pinyin
import torch
import os
import functools


class Tokenizer:
    """
    Tokenizer
    """

    def __init__(self,
                 granularity: str = "word",
                 device: str = "cpu",
                 segmented: bool = False,
                 bpe: bool = False,
                 ) -> None:
        """
        Constructor
        :param granularity: tokenization granularity, either character level ("char") or word level ("word")
        """
        self.ltp = None
        if granularity == "word":
            self.ltp = LTP(device=torch.device(device) if torch.cuda.is_available() else torch.device("cpu"))
            self.ltp.add_words(words=["[缺失成分]"], max_window=6)
        self.segmented = segmented
        self.granularity = granularity
        if self.granularity == "word":
            self.tokenizer = self.split_word
        elif self.granularity == "char":
            self.tokenizer = functools.partial(self.split_char, bpe=bpe)
        else:
            raise NotImplementedError

    def __repr__(self) -> str:
        return "{:s}\nMode: {:s}".format(str(self.__class__.__name__), self.granularity)

    def __call__(self,
                 input_strings: List[str]
                 ) -> List:
        """
        Tokenize a batch of strings.
        :param input_strings: list of strings to tokenize
        :return: list of tokenized results; each item is a list of (token, pos_tag, pinyin) tuples
        """
        if not self.segmented:
            input_strings = ["".join(s.split(" ")) for s in input_strings]
        results = self.tokenizer(input_strings)
        return results

    def split_char(self, input_strings: List[str], bpe=False) -> List:
        """
        Character-level tokenization.
        :param input_strings: strings to split into characters
        :return: character-level results
        """
        if bpe:
            from . import tokenization
            project_dir = os.path.dirname(os.path.dirname(__file__))
            tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(project_dir, "data", "chinese_vocab.txt"), do_lower_case=False)
        results = []
        for input_string in input_strings:
            if not self.segmented:  # if the input is not pre-segmented, split on every character (no special handling of English punctuation, no BPE); otherwise keep the original segmentation
                segment_string = " ".join([char for char in input_string] if not bpe else tokenizer.tokenize(input_string))
            else:
                segment_string = input_string
            # print(segment_string)
            segment_string = segment_string.replace("[ 缺 失 成 分 ]", "[缺失成分]").split(" ")  # treat the "missing component" placeholder as a single token
            results.append([(char, "unk", pinyin(char, style=Style.NORMAL, heteronym=True)[0]) for char in segment_string])
        return results

    def split_word(self, input_strings: List[str]) -> List:
        """
        Word-level tokenization.
        :param input_strings: strings to tokenize
        :return: word-level results
        """
        if self.segmented:
            seg, hidden = self.ltp.seg([input_string.split(" ") for input_string in input_strings], is_preseged=True)
        else:
            seg, hidden = self.ltp.seg(input_strings)
        pos = self.ltp.pos(hidden)
        result = []
        for s, p in zip(seg, pos):
            pinyin = [lazy_pinyin(word) for word in s]
            result.append(list(zip(s, p, pinyin)))
        return result


if __name__ == "__main__":
    tokenizer = Tokenizer("word")
    print(tokenizer(["LAC是个优秀的分词工具", "百度是一家高科技公司"]))
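A quick way to see the data shape the downstream alignment and merging steps consume is to run the char-level tokenizer; the pinyin candidates shown in the comment are indicative only, since they depend on the installed pypinyin data:

    tok = Tokenizer("char")
    tokens = tok(["牛肉很好吃"])[0]
    # Each element is a (token, pos_tag, pinyin_candidates) tuple, roughly:
    # [('牛', 'unk', ['niu']), ('肉', 'unk', ['rou']), ('很', 'unk', ['hen']), ...]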
@ -1,221 +1,221 @@
import os
from modules.annotator import Annotator
from modules.tokenizer import Tokenizer
import argparse
from collections import Counter
from tqdm import tqdm
import torch
from collections import defaultdict
from multiprocessing import Pool
from opencc import OpenCC
import timeout_decorator

os.environ["TOKENIZERS_PARALLELISM"] = "false"

annotator, sentence_to_tokenized = None, None
cc = OpenCC("t2s")


@timeout_decorator.timeout(10)
def annotate_with_time_out(line):
    """
    :param line:
    :return:
    """
    sent_list = line.split("\t")[1:]
    source = sent_list[0]
    if args.segmented:
        source = source.strip()
    else:
        source = "".join(source.strip().split())
    output_str = ""
    for idx, target in enumerate(sent_list[1:]):
        try:
            if args.segmented:
                target = target.strip()
            else:
                target = "".join(target.strip().split())
            if not args.no_simplified:
                target = cc.convert(target)
            source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
            out, cors = annotator(source_tokenized, target_tokenized, idx)
            if idx == 0:
                output_str += "".join(out[:-1])
            else:
                output_str += "".join(out[1:-1])
        except Exception:
            raise Exception
    return output_str


def annotate(line):
    """
    :param line:
    :return:
    """
    sent_list = line.split("\t")[1:]
    source = sent_list[0]
    if args.segmented:
        source = source.strip()
    else:
        source = "".join(source.strip().split())
    output_str = ""
    for idx, target in enumerate(sent_list[1:]):
        try:
            if args.segmented:
                target = target.strip()
            else:
                target = "".join(target.strip().split())
            if not args.no_simplified:
                target = cc.convert(target)
            source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
            out, cors = annotator(source_tokenized, target_tokenized, idx)
            if idx == 0:
                output_str += "".join(out[:-1])
            else:
                output_str += "".join(out[1:-1])
        except Exception:
            raise Exception
    return output_str


def firsttime_process(args):
    tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
    global annotator, sentence_to_tokenized
    annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
    lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n")  # format: id src tgt1 tgt2...
    # error_types = []

    with open(args.output, "w", encoding="utf-8") as f:
        count = 0
        sentence_set = set()
        sentence_to_tokenized = {}
        for line in lines:
            sent_list = line.split("\t")[1:]
            for idx, sent in enumerate(sent_list):
                if args.segmented:
                    # print(sent)
                    sent = sent.strip()
                else:
                    sent = "".join(sent.split()).strip()
                if idx >= 1:
                    if not args.no_simplified:
                        sentence_set.add(cc.convert(sent))
                    else:
                        sentence_set.add(sent)
                else:
                    sentence_set.add(sent)
        batch = []
        for sent in tqdm(sentence_set):
            count += 1
            if sent:
                batch.append(sent)
            if count % args.batch_size == 0:
                results = tokenizer(batch)
                for s, r in zip(batch, results):
                    sentence_to_tokenized[s] = r  # Get tokenization map.
                batch = []
        if batch:
            results = tokenizer(batch)
            for s, r in zip(batch, results):
                sentence_to_tokenized[s] = r  # Get tokenization map.

        timeout_indices = []

        # single-process mode
        for idx, line in enumerate(tqdm(lines)):
            try:
                ret = annotate_with_time_out(line)
            except Exception:
                timeout_indices.append(idx)
        return timeout_indices


def main(args):
    timeout_indices = firsttime_process(args)
    tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
    global annotator, sentence_to_tokenized
    annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
    lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n")
    new_lines = []  # format: id src tgt1 tgt2...

    with open(args.output, "w", encoding="utf-8") as f:
        count = 0
        sentence_set = set()
        sentence_to_tokenized = {}
        for line_idx, line in enumerate(lines):

            if line_idx in timeout_indices:
                # print(f"line before split: {line}")
                line_split = line.split("\t")
                line_number, sent_list = line_split[0], line_split[1:]
                assert len(sent_list) == 2
                sent_list[-1] = " 无"
                line = line_number + "\t" + "\t".join(sent_list)
                # print(f"line time out: {line}")
                new_lines.append(line)
            else:
                new_lines.append(line)

            sent_list = line.split("\t")[1:]
            for idx, sent in enumerate(sent_list):
                if args.segmented:
                    # print(sent)
                    sent = sent.strip()
                else:
                    sent = "".join(sent.split()).strip()
                if idx >= 1:
                    if not args.no_simplified:
                        sentence_set.add(cc.convert(sent))
                    else:
                        sentence_set.add(sent)
                else:
                    sentence_set.add(sent)
        batch = []
        for sent in tqdm(sentence_set):
            count += 1
            if sent:
                batch.append(sent)
            if count % args.batch_size == 0:
                results = tokenizer(batch)
                for s, r in zip(batch, results):
                    sentence_to_tokenized[s] = r  # Get tokenization map.
                batch = []
        if batch:
            results = tokenizer(batch)
            for s, r in zip(batch, results):
                sentence_to_tokenized[s] = r  # Get tokenization map.

        # single-process mode
        lines = new_lines
        for idx, line in enumerate(tqdm(lines)):
            ret = annotate(line)
            f.write(ret)
            f.write("\n")

        # multi-process mode: only tested on Linux; recommended on a Linux server
        # with Pool(args.worker_num) as pool:
        #     for ret in pool.imap(annotate, tqdm(lines), chunksize=8):
        #         if ret:
        #             f.write(ret)
        #             f.write("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Choose input file to annotate")
    parser.add_argument("-f", "--file", type=str, required=True, help="Input parallel file")
    parser.add_argument("-o", "--output", type=str, help="Output file", required=True)
    parser.add_argument("-b", "--batch_size", type=int, help="The size of batch", default=128)
    parser.add_argument("-d", "--device", type=int, help="The ID of GPU", default=0)
    parser.add_argument("-w", "--worker_num", type=int, help="The number of workers", default=16)
    parser.add_argument("-g", "--granularity", type=str, help="Choose char-level or word-level evaluation", default="char")
    parser.add_argument("-m", "--merge", help="Whether merge continuous replacement/deletion/insertion", action="store_true")
    parser.add_argument("-s", "--multi_cheapest_strategy", type=str, choices=["first", "all"], default="all")
    parser.add_argument("--segmented", help="Whether tokens have been segmented", action="store_true")  # supports pre-tokenized input, with tokens separated by spaces
    parser.add_argument("--no_simplified", help="Whether simplifying chinese", action="store_true")  # by default all corrections are converted to Simplified Chinese
    parser.add_argument("--bpe", help="Whether to use bpe", action="store_true")  # supports BPE segmentation of English words
    args = parser.parse_args()
    main(args)
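For reference, a typical invocation of the annotation script above might look like the following; the file names and the script name parallel_to_m2.py are assumptions, only the flags come from the argparse definition and the expected input format comes from the "id src tgt1 tgt2..." comment:

    # demo.para holds one sample per line: id<TAB>source<TAB>target1[<TAB>target2 ...]
    python parallel_to_m2.py -f demo.para -o demo.m2 -g char -d 0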
@ -1,10 +1,13 @@
|
||||
import datetime
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import mmengine
|
||||
from mmengine.config import ConfigDict
|
||||
@ -43,6 +46,11 @@ class DLCRunner(BaseRunner):
|
||||
self.max_num_workers = max_num_workers
|
||||
self.retry = retry
|
||||
|
||||
logger = get_logger()
|
||||
logger.warning(
|
||||
'To ensure the integrity of the log results, the log displayed '
|
||||
f'by {self.__class__.__name__} has a 10-second delay.')
|
||||
|
||||
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
|
||||
"""Launch multiple tasks.
|
||||
|
||||
@ -63,18 +71,23 @@ class DLCRunner(BaseRunner):
|
||||
status = [self._launch(task, random_sleep=False) for task in tasks]
|
||||
return status
|
||||
|
||||
def _launch(self, cfg: ConfigDict, random_sleep: bool = True):
|
||||
def _launch(self, cfg: ConfigDict, random_sleep: Optional[bool] = None):
|
||||
"""Launch a single task.
|
||||
|
||||
Args:
|
||||
cfg (ConfigDict): Task config.
|
||||
random_sleep (bool): Whether to sleep for a random time before
|
||||
running the command. This avoids cluster error when launching
|
||||
multiple tasks at the same time. Default: True.
|
||||
running the command. When Aliyun has many tasks to schedule,
|
||||
its stability decreases. Therefore, when we need to submit a
|
||||
large number of tasks at once, we adopt the "random_sleep"
|
||||
strategy. Tasks that would have been submitted all at once are
|
||||
now evenly spread out over a 10-second period. Default: None.
|
||||
|
||||
Returns:
|
||||
tuple[str, int]: Task name and exit code.
|
||||
"""
|
||||
if random_sleep is None:
|
||||
random_sleep = (self.max_num_workers > 32)
|
||||
|
||||
task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type']))
|
||||
num_gpus = task.num_gpus
|
||||
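Condensed, the staggered-submission policy described in the docstring behaves roughly as follows; this is a paraphrase of the diff for readability (the helper name is hypothetical), not an exact copy of the runner code:

    import random
    import time

    def maybe_stagger(random_sleep, max_num_workers):
        """Stagger submissions only when many tasks could hit the scheduler at once."""
        if random_sleep is None:
            random_sleep = max_num_workers > 32
        if random_sleep:
            time.sleep(random.randint(0, 10))  # spread launches over ~10 seconds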
@ -116,7 +129,7 @@ class DLCRunner(BaseRunner):
|
||||
|
||||
# Run command with retry
|
||||
if self.debug:
|
||||
stdout = None
|
||||
stdout = sys.stdout
|
||||
else:
|
||||
out_path = task.get_log_path(file_extension='out')
|
||||
mmengine.mkdir_or_exist(osp.split(out_path)[0])
|
||||
@ -124,30 +137,92 @@ class DLCRunner(BaseRunner):
|
||||
|
||||
if random_sleep:
|
||||
time.sleep(random.randint(0, 10))
|
||||
result = subprocess.run(cmd,
|
||||
shell=True,
|
||||
text=True,
|
||||
stdout=stdout,
|
||||
stderr=stdout)
|
||||
|
||||
def _run_within_retry():
|
||||
try:
|
||||
process = subprocess.Popen(cmd,
|
||||
shell=True,
|
||||
text=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
job_id = None
|
||||
job_allocated = False
|
||||
job_finished = False
|
||||
last_end_time = datetime.datetime.now().strftime(
|
||||
'%Y-%m-%dT%H:%M:%SZ')
|
||||
while True:
|
||||
if not job_allocated:
|
||||
line = process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
match = re.search(r'(dlc[0-9a-z]+)', line)
|
||||
if match and job_id is None:
|
||||
job_id = match.group(1)
|
||||
stdout.write(line)
|
||||
match = re.search(r'Job .* is \[Running\]', line)
|
||||
if match:
|
||||
job_allocated = True
|
||||
else:
|
||||
try:
|
||||
process.wait(10)
|
||||
except subprocess.TimeoutExpired:
|
||||
pass
|
||||
else:
|
||||
job_finished = True
|
||||
if job_finished:
|
||||
this_end_time = datetime.datetime.now(
|
||||
).strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
else:
|
||||
this_end_time = (
|
||||
datetime.datetime.now() -
|
||||
datetime.timedelta(seconds=10)
|
||||
).strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
logs_cmd = (
|
||||
'dlc logs'
|
||||
f' {job_id} {job_id}-worker-0'
|
||||
f' --start_time {last_end_time}'
|
||||
f' --end_time {this_end_time}'
|
||||
f" -c {self.aliyun_cfg['dlc_config_path']}")
|
||||
log_process = subprocess.Popen(
|
||||
logs_cmd,
|
||||
shell=True,
|
||||
text=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
log_output, log_err = log_process.communicate()
|
||||
log_output = '\n'.join(log_output.split('\n')[2:])
|
||||
stdout.write(log_output)
|
||||
last_end_time = this_end_time
|
||||
stdout.flush()
|
||||
if job_finished:
|
||||
break
|
||||
process.wait()
|
||||
return process.returncode
|
||||
finally:
|
||||
if job_id is not None:
|
||||
cancel_cmd = (
|
||||
'dlc stop job'
|
||||
f' {job_id}'
|
||||
f" -c {self.aliyun_cfg['dlc_config_path']}"
|
||||
' -f')
|
||||
subprocess.run(cancel_cmd,
|
||||
shell=True,
|
||||
text=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
|
||||
return_code = _run_within_retry()
|
||||
retry = self.retry
|
||||
output_paths = task.get_output_paths()
|
||||
while self._job_failed(result.returncode,
|
||||
output_paths) and retry > 0:
|
||||
while self._job_failed(return_code, output_paths) and retry > 0:
|
||||
retry -= 1
|
||||
if random_sleep:
|
||||
time.sleep(random.randint(0, 10))
|
||||
# Re-generate command to refresh ports.
|
||||
cmd = get_cmd()
|
||||
result = subprocess.run(cmd,
|
||||
shell=True,
|
||||
text=True,
|
||||
stdout=stdout,
|
||||
stderr=stdout)
|
||||
return_code = _run_within_retry()
|
||||
finally:
|
||||
# Clean up
|
||||
os.remove(param_file)
|
||||
return task_name, result.returncode
|
||||
|
||||
return task_name, return_code
|
||||
|
||||
def _job_failed(self, return_code: int, output_paths: List[str]) -> bool:
|
||||
return return_code != 0 or not all(