mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Add lawbench (#460)
* add lawbench * update requirements * update
This commit is contained in:
parent
fbf5089c40
commit
861942ab1b
@ -3,7 +3,9 @@ exclude: |
|
|||||||
tests/data/|
|
tests/data/|
|
||||||
opencompass/models/internal/|
|
opencompass/models/internal/|
|
||||||
opencompass/utils/internal/|
|
opencompass/utils/internal/|
|
||||||
opencompass/openicl/icl_evaluator/hf_metrics/
|
opencompass/openicl/icl_evaluator/hf_metrics/|
|
||||||
|
opencompass/datasets/lawbench/utils|
|
||||||
|
opencompass/datasets/lawbench/evaluation_functions/
|
||||||
)
|
)
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/PyCQA/flake8
|
- repo: https://github.com/PyCQA/flake8
|
||||||
|
62
configs/datasets/lawbench/lawbench_one_shot_gen_002588.py
Normal file
62
configs/datasets/lawbench/lawbench_one_shot_gen_002588.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
|
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||||
|
from opencompass.datasets import LawBenchDataset
|
||||||
|
|
||||||
|
names = [
|
||||||
|
["1-1", "article_recitation"],
|
||||||
|
["1-2", "knowledge_question_answering"],
|
||||||
|
["2-1", "document_proofreading"],
|
||||||
|
["2-2", "dispute_focus_identification"],
|
||||||
|
["2-3", "marital_disputes_identification"],
|
||||||
|
["2-4", "issue_topic_identification"],
|
||||||
|
["2-5", "reading_comprehension"],
|
||||||
|
["2-6", "named_entity_recognition"],
|
||||||
|
["2-7", "opinion_summarization"],
|
||||||
|
["2-8", "argument_mining"],
|
||||||
|
["2-9", "event_detection"],
|
||||||
|
["2-10", "trigger_word_extraction"],
|
||||||
|
["3-1", "fact_based_article_prediction"],
|
||||||
|
["3-2", "scene_based_article_prediction"],
|
||||||
|
["3-3", "charge_prediction"],
|
||||||
|
["3-4", "prison_term_prediction_wo_article"],
|
||||||
|
["3-5", "prison_term_prediction_w_article"],
|
||||||
|
["3-6", "case_analysis"],
|
||||||
|
["3-7", "criminal_damages_calculation"],
|
||||||
|
["3-8", "consultation"],
|
||||||
|
]
|
||||||
|
|
||||||
|
lawbench_datasets = []
|
||||||
|
for index, name in names:
|
||||||
|
lawbench_reader_cfg = dict(
|
||||||
|
input_columns=['instruction', 'question'],
|
||||||
|
output_column='answer')
|
||||||
|
|
||||||
|
lawbench_infer_cfg = dict(
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(
|
||||||
|
round=[
|
||||||
|
dict(role="HUMAN", prompt="{instruction}\n{question}"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
retriever=dict(type=ZeroRetriever),
|
||||||
|
inferencer=dict(type=GenInferencer, max_out_len=1024)
|
||||||
|
)
|
||||||
|
|
||||||
|
lawbench_eval_cfg = dict(
|
||||||
|
evaluator=dict(type='LawBenchEvaluator_' + index.replace('-', '_'))
|
||||||
|
)
|
||||||
|
|
||||||
|
lawbench_datasets.append(
|
||||||
|
dict(
|
||||||
|
abbr='lawbench-' + index + '-' + name + '-1-shot',
|
||||||
|
type=LawBenchDataset,
|
||||||
|
path='./data/lawbench/one_shot',
|
||||||
|
index=index,
|
||||||
|
reader_cfg=lawbench_reader_cfg,
|
||||||
|
infer_cfg=lawbench_infer_cfg,
|
||||||
|
eval_cfg=lawbench_eval_cfg
|
||||||
|
)
|
||||||
|
)
|
62
configs/datasets/lawbench/lawbench_zero_shot_gen_002588.py
Normal file
62
configs/datasets/lawbench/lawbench_zero_shot_gen_002588.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
|
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||||
|
from opencompass.datasets import LawBenchDataset
|
||||||
|
|
||||||
|
names = [
|
||||||
|
["1-1", "article_recitation"],
|
||||||
|
["1-2", "knowledge_question_answering"],
|
||||||
|
["2-1", "document_proofreading"],
|
||||||
|
["2-2", "dispute_focus_identification"],
|
||||||
|
["2-3", "marital_disputes_identification"],
|
||||||
|
["2-4", "issue_topic_identification"],
|
||||||
|
["2-5", "reading_comprehension"],
|
||||||
|
["2-6", "named_entity_recognition"],
|
||||||
|
["2-7", "opinion_summarization"],
|
||||||
|
["2-8", "argument_mining"],
|
||||||
|
["2-9", "event_detection"],
|
||||||
|
["2-10", "trigger_word_extraction"],
|
||||||
|
["3-1", "fact_based_article_prediction"],
|
||||||
|
["3-2", "scene_based_article_prediction"],
|
||||||
|
["3-3", "charge_prediction"],
|
||||||
|
["3-4", "prison_term_prediction_wo_article"],
|
||||||
|
["3-5", "prison_term_prediction_w_article"],
|
||||||
|
["3-6", "case_analysis"],
|
||||||
|
["3-7", "criminal_damages_calculation"],
|
||||||
|
["3-8", "consultation"],
|
||||||
|
]
|
||||||
|
|
||||||
|
lawbench_datasets = []
|
||||||
|
for index, name in names:
|
||||||
|
lawbench_reader_cfg = dict(
|
||||||
|
input_columns=['instruction', 'question'],
|
||||||
|
output_column='answer')
|
||||||
|
|
||||||
|
lawbench_infer_cfg = dict(
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(
|
||||||
|
round=[
|
||||||
|
dict(role="HUMAN", prompt="{instruction}\n{question}"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
retriever=dict(type=ZeroRetriever),
|
||||||
|
inferencer=dict(type=GenInferencer, max_out_len=1024)
|
||||||
|
)
|
||||||
|
|
||||||
|
lawbench_eval_cfg = dict(
|
||||||
|
evaluator=dict(type='LawBenchEvaluator_' + index.replace('-', '_'))
|
||||||
|
)
|
||||||
|
|
||||||
|
lawbench_datasets.append(
|
||||||
|
dict(
|
||||||
|
abbr='lawbench-' + index + '-' + name + '-0-shot',
|
||||||
|
type=LawBenchDataset,
|
||||||
|
path='./data/lawbench/zero_shot',
|
||||||
|
index=index,
|
||||||
|
reader_cfg=lawbench_reader_cfg,
|
||||||
|
infer_cfg=lawbench_infer_cfg,
|
||||||
|
eval_cfg=lawbench_eval_cfg
|
||||||
|
)
|
||||||
|
)
|
11
configs/eval_qwen_7b_chat_lawbench.py
Normal file
11
configs/eval_qwen_7b_chat_lawbench.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
with read_base():
|
||||||
|
from .models.qwen.hf_qwen_7b_chat import models
|
||||||
|
from .datasets.lawbench.lawbench_zero_shot_gen_002588 import lawbench_datasets as lawbench_zero_shot_datasets
|
||||||
|
from .datasets.lawbench.lawbench_one_shot_gen_002588 import lawbench_datasets as lawbench_one_shot_datasets
|
||||||
|
from .summarizers.lawbench import summarizer
|
||||||
|
|
||||||
|
datasets = lawbench_zero_shot_datasets + lawbench_one_shot_datasets
|
||||||
|
for d in datasets:
|
||||||
|
d["infer_cfg"]["inferencer"]["save_every"] = 1
|
29
configs/summarizers/groups/lawbench.py
Normal file
29
configs/summarizers/groups/lawbench.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
names = [
|
||||||
|
["1-1", "article_recitation"],
|
||||||
|
["1-2", "knowledge_question_answering"],
|
||||||
|
["2-1", "document_proofreading"],
|
||||||
|
["2-2", "dispute_focus_identification"],
|
||||||
|
["2-3", "marital_disputes_identification"],
|
||||||
|
["2-4", "issue_topic_identification"],
|
||||||
|
["2-5", "reading_comprehension"],
|
||||||
|
["2-6", "named_entity_recognition"],
|
||||||
|
["2-7", "opinion_summarization"],
|
||||||
|
["2-8", "argument_mining"],
|
||||||
|
["2-9", "event_detection"],
|
||||||
|
["2-10", "trigger_word_extraction"],
|
||||||
|
["3-1", "fact_based_article_prediction"],
|
||||||
|
["3-2", "scene_based_article_prediction"],
|
||||||
|
["3-3", "charge_prediction"],
|
||||||
|
["3-4", "prison_term_prediction_wo_article"],
|
||||||
|
["3-5", "prison_term_prediction_w_article"],
|
||||||
|
["3-6", "case_analysis"],
|
||||||
|
["3-7", "criminal_damages_calculation"],
|
||||||
|
["3-8", "consultation"],
|
||||||
|
]
|
||||||
|
|
||||||
|
lawbench_summary_groups = []
|
||||||
|
|
||||||
|
_lawbench_0_shot = ['lawbench-' + index + '-' + name + '-0-shot' for index, name in names]
|
||||||
|
lawbench_summary_groups.append({'name': 'lawbench-0-shot', 'subsets': _lawbench_0_shot})
|
||||||
|
_lawbench_1_shot = ['lawbench-' + index + '-' + name + '-1-shot' for index, name in names]
|
||||||
|
lawbench_summary_groups.append({'name': 'lawbench-1-shot', 'subsets': _lawbench_1_shot})
|
58
configs/summarizers/lawbench.py
Normal file
58
configs/summarizers/lawbench.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
with read_base():
|
||||||
|
from .groups.lawbench import lawbench_summary_groups
|
||||||
|
|
||||||
|
summarizer = dict(
|
||||||
|
dataset_abbrs = [
|
||||||
|
'--------- 0-shot ---------', # category
|
||||||
|
'lawbench-0-shot',
|
||||||
|
'lawbench-1-1-article_recitation-0-shot',
|
||||||
|
'lawbench-1-2-knowledge_question_answering-0-shot',
|
||||||
|
'lawbench-2-1-document_proofreading-0-shot',
|
||||||
|
'lawbench-2-2-dispute_focus_identification-0-shot',
|
||||||
|
'lawbench-2-3-marital_disputes_identification-0-shot',
|
||||||
|
'lawbench-2-4-issue_topic_identification-0-shot',
|
||||||
|
'lawbench-2-5-reading_comprehension-0-shot',
|
||||||
|
'lawbench-2-6-named_entity_recognition-0-shot',
|
||||||
|
'lawbench-2-7-opinion_summarization-0-shot',
|
||||||
|
'lawbench-2-8-argument_mining-0-shot',
|
||||||
|
'lawbench-2-9-event_detection-0-shot',
|
||||||
|
'lawbench-2-10-trigger_word_extraction-0-shot',
|
||||||
|
'lawbench-3-1-fact_based_article_prediction-0-shot',
|
||||||
|
'lawbench-3-2-scene_based_article_prediction-0-shot',
|
||||||
|
'lawbench-3-3-charge_prediction-0-shot',
|
||||||
|
'lawbench-3-4-prison_term_prediction_wo_article-0-shot',
|
||||||
|
'lawbench-3-5-prison_term_prediction_w_article-0-shot',
|
||||||
|
'lawbench-3-6-case_analysis-0-shot',
|
||||||
|
'lawbench-3-7-criminal_damages_calculation-0-shot',
|
||||||
|
'lawbench-3-8-consultation-0-shot',
|
||||||
|
'--------- 1-shot ---------', # category
|
||||||
|
'lawbench-1-shot',
|
||||||
|
'lawbench-1-1-article_recitation-1-shot',
|
||||||
|
'lawbench-1-2-knowledge_question_answering-1-shot',
|
||||||
|
'lawbench-2-1-document_proofreading-1-shot',
|
||||||
|
'lawbench-2-2-dispute_focus_identification-1-shot',
|
||||||
|
'lawbench-2-3-marital_disputes_identification-1-shot',
|
||||||
|
'lawbench-2-4-issue_topic_identification-1-shot',
|
||||||
|
'lawbench-2-5-reading_comprehension-1-shot',
|
||||||
|
'lawbench-2-6-named_entity_recognition-1-shot',
|
||||||
|
'lawbench-2-7-opinion_summarization-1-shot',
|
||||||
|
'lawbench-2-8-argument_mining-1-shot',
|
||||||
|
'lawbench-2-9-event_detection-1-shot',
|
||||||
|
'lawbench-2-10-trigger_word_extraction-1-shot',
|
||||||
|
'lawbench-3-1-fact_based_article_prediction-1-shot',
|
||||||
|
'lawbench-3-2-scene_based_article_prediction-1-shot',
|
||||||
|
'lawbench-3-3-charge_prediction-1-shot',
|
||||||
|
'lawbench-3-4-prison_term_prediction_wo_article-1-shot',
|
||||||
|
'lawbench-3-5-prison_term_prediction_w_article-1-shot',
|
||||||
|
'lawbench-3-6-case_analysis-1-shot',
|
||||||
|
'lawbench-3-7-criminal_damages_calculation-1-shot',
|
||||||
|
'lawbench-3-8-consultation-1-shot',
|
||||||
|
],
|
||||||
|
summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []),
|
||||||
|
prompt_db=dict(
|
||||||
|
database_path='configs/datasets/log.json',
|
||||||
|
config_dir='configs/datasets',
|
||||||
|
blacklist='.promptignore'),
|
||||||
|
)
|
@ -40,6 +40,7 @@ from .iwslt2017 import * # noqa: F401, F403
|
|||||||
from .jigsawmultilingual import * # noqa: F401, F403
|
from .jigsawmultilingual import * # noqa: F401, F403
|
||||||
from .kaoshi import KaoshiDataset, KaoshiEvaluator # noqa: F401, F403
|
from .kaoshi import KaoshiDataset, KaoshiEvaluator # noqa: F401, F403
|
||||||
from .lambada import * # noqa: F401, F403
|
from .lambada import * # noqa: F401, F403
|
||||||
|
from .lawbench import * # noqa: F401, F403
|
||||||
from .lcsts import * # noqa: F401, F403
|
from .lcsts import * # noqa: F401, F403
|
||||||
from .leval import * # noqa: F401, F403
|
from .leval import * # noqa: F401, F403
|
||||||
from .longbench import * # noqa: F401, F403
|
from .longbench import * # noqa: F401, F403
|
||||||
|
1
opencompass/datasets/lawbench/__init__.py
Normal file
1
opencompass/datasets/lawbench/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .lawbench import LawBenchDataset # noqa: F401
|
19
opencompass/datasets/lawbench/evaluation_functions/cjft.py
Normal file
19
opencompass/datasets/lawbench/evaluation_functions/cjft.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from ..utils.function_utils import compute_rouge
|
||||||
|
|
||||||
|
#情景法条识别
|
||||||
|
|
||||||
|
def compute_cjft(data_dict):
|
||||||
|
"""
|
||||||
|
Compute the ROUGE-L score between the prediction and the reference
|
||||||
|
"""
|
||||||
|
references, predictions = [], []
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
predictions.append(prediction)
|
||||||
|
references.append(answer)
|
||||||
|
|
||||||
|
# compute the accuracy of score_list
|
||||||
|
rouge_scores = compute_rouge(predictions, references)
|
||||||
|
rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores]
|
||||||
|
average_rouge_l = sum(rouge_ls) / len(rouge_ls)
|
||||||
|
return {"score": average_rouge_l}
|
18
opencompass/datasets/lawbench/evaluation_functions/flzx.py
Normal file
18
opencompass/datasets/lawbench/evaluation_functions/flzx.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from ..utils.function_utils import compute_rouge
|
||||||
|
|
||||||
|
#法律咨询
|
||||||
|
def compute_flzx(data_dict):
|
||||||
|
"""
|
||||||
|
Compute the ROUGE-L score between the prediction and the reference
|
||||||
|
"""
|
||||||
|
references, predictions = [], []
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
predictions.append(prediction)
|
||||||
|
references.append(answer)
|
||||||
|
|
||||||
|
# compute the accuracy of score_list
|
||||||
|
rouge_scores = compute_rouge(predictions, references)
|
||||||
|
rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores]
|
||||||
|
average_rouge_l = sum(rouge_ls) / len(rouge_ls)
|
||||||
|
return {"score": average_rouge_l}
|
19
opencompass/datasets/lawbench/evaluation_functions/ftcs.py
Normal file
19
opencompass/datasets/lawbench/evaluation_functions/ftcs.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from ..utils.function_utils import compute_rouge
|
||||||
|
|
||||||
|
#法条记忆问答
|
||||||
|
def compute_ftcs(data_dict):
|
||||||
|
"""
|
||||||
|
Compute the ROUGE-L score between the prediction and the reference
|
||||||
|
"""
|
||||||
|
references, predictions = [], []
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
answer = answer.replace("答案:", "")
|
||||||
|
predictions.append(prediction)
|
||||||
|
references.append(answer)
|
||||||
|
|
||||||
|
# compute the accuracy of score_list
|
||||||
|
rouge_scores = compute_rouge(predictions, references)
|
||||||
|
rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores]
|
||||||
|
average_rouge_l = sum(rouge_ls) / len(rouge_ls)
|
||||||
|
return {"score": average_rouge_l}
|
36
opencompass/datasets/lawbench/evaluation_functions/jdzy.py
Normal file
36
opencompass/datasets/lawbench/evaluation_functions/jdzy.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from ..utils.function_utils import multi_choice_judge
|
||||||
|
|
||||||
|
"""
|
||||||
|
multi-choice single-label selection
|
||||||
|
metric: accuracy
|
||||||
|
争议焦点:识别案件涉及的争议焦点
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compute_jdzy(data_dict):
|
||||||
|
"""
|
||||||
|
Compute the Accuracy
|
||||||
|
The JEC dataset has 16 possible answers for each question, stored in the option_list
|
||||||
|
A prediction is correct if
|
||||||
|
1. The correct answer appears in the prediction, and
|
||||||
|
2. Options other than the answer do not appear in the prediction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
score_list, abstentions = [], 0
|
||||||
|
option_list = ["诉讼主体", "租金情况", "利息", "本金争议", "责任认定", "责任划分", "损失认定及处理",
|
||||||
|
"原审判决是否适当", "合同效力", "财产分割", "责任承担", "鉴定结论采信问题", "诉讼时效", "违约", "合同解除", "肇事逃逸"]
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
if answer[7:-1] == "赔偿":
|
||||||
|
# todo: dataset imperfection
|
||||||
|
continue
|
||||||
|
assert answer.startswith("争议焦点类别:") and answer[7:-1] in option_list, \
|
||||||
|
f"answer: {answer} \n question: {question}"
|
||||||
|
|
||||||
|
answer_letter = answer[7:-1]
|
||||||
|
judge = multi_choice_judge(prediction, option_list, answer_letter)
|
||||||
|
score_list.append(judge["score"])
|
||||||
|
abstentions += judge["abstention"]
|
||||||
|
|
||||||
|
# compute the accuracy of score_list
|
||||||
|
accuracy = sum(score_list) / len(score_list)
|
||||||
|
return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)}
|
29
opencompass/datasets/lawbench/evaluation_functions/jec_ac.py
Normal file
29
opencompass/datasets/lawbench/evaluation_functions/jec_ac.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from ..utils.function_utils import multi_choice_judge
|
||||||
|
|
||||||
|
"""
|
||||||
|
Task: multi-choice selection
|
||||||
|
Metric: Accuracy
|
||||||
|
司法考试-案例分析
|
||||||
|
"""
|
||||||
|
def compute_jec_ac(data_dict):
|
||||||
|
"""
|
||||||
|
Compute the Accuracy
|
||||||
|
The JEC dataset has 4 options for each question: A, B, C, D
|
||||||
|
A prediction is correct if
|
||||||
|
1. The correct answer appears in the prediction, and
|
||||||
|
2. Options other than the answer do not appear in the prediction.
|
||||||
|
"""
|
||||||
|
score_list, abstentions = [], 0
|
||||||
|
option_list = ["A", "B", "C", "D"]
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
assert answer.startswith("正确答案:") and answer[5] in option_list, f"answer[5]: {answer}, question: {question}"
|
||||||
|
|
||||||
|
answer_letter = answer[5]
|
||||||
|
judge = multi_choice_judge(prediction, option_list, answer_letter)
|
||||||
|
score_list.append(judge["score"])
|
||||||
|
abstentions += judge["abstention"]
|
||||||
|
|
||||||
|
# compute the accuracy of score_list
|
||||||
|
accuracy = sum(score_list) / len(score_list)
|
||||||
|
return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)}
|
29
opencompass/datasets/lawbench/evaluation_functions/jec_kd.py
Normal file
29
opencompass/datasets/lawbench/evaluation_functions/jec_kd.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from ..utils.function_utils import multi_choice_judge
|
||||||
|
|
||||||
|
"""
|
||||||
|
Task: multi-choice selection
|
||||||
|
Metric: Accuracy
|
||||||
|
司法考试
|
||||||
|
"""
|
||||||
|
def compute_jec_kd(data_dict):
|
||||||
|
"""
|
||||||
|
Compute the Accuracy
|
||||||
|
The JEC_KD dataset has 4 options for each question: A, B, C, D
|
||||||
|
A prediction is correct if
|
||||||
|
1. The correct answer appears in the prediction, and
|
||||||
|
2. Options other than the answer do not appear in the prediction.
|
||||||
|
"""
|
||||||
|
score_list, abstentions = [], 0
|
||||||
|
option_list = ["A", "B", "C", "D"]
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
assert answer.startswith("正确答案:") and answer[5] in option_list, f"answer[5]: {answer}, question: {question}"
|
||||||
|
|
||||||
|
answer_letter = answer[5]
|
||||||
|
judge = multi_choice_judge(prediction, option_list, answer_letter)
|
||||||
|
score_list.append(judge["score"])
|
||||||
|
abstentions += judge["abstention"]
|
||||||
|
|
||||||
|
# compute the accuracy of score_list
|
||||||
|
accuracy = sum(score_list) / len(score_list)
|
||||||
|
return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)}
|
43
opencompass/datasets/lawbench/evaluation_functions/jetq.py
Normal file
43
opencompass/datasets/lawbench/evaluation_functions/jetq.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
"""
|
||||||
|
number prediction
|
||||||
|
metric: accuracy
|
||||||
|
金额提取
|
||||||
|
"""
|
||||||
|
def compute_jetq(data_dict):
|
||||||
|
"""
|
||||||
|
Compute the Accuracy
|
||||||
|
we extract the total amount of cost involved in the crime from the prediction and compare it with the reference
|
||||||
|
The prediction is correct if
|
||||||
|
the total amount of cost provided in the reference, appears in the prediction.
|
||||||
|
"""
|
||||||
|
score_list, abstentions = [], 0
|
||||||
|
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
assert answer.startswith("上文涉及到的犯罪金额:"), f"answer: {answer}, question: {question}"
|
||||||
|
assert answer.endswith("元。"), f"answer: {answer}, question: {question}"
|
||||||
|
answer = answer.replace("上文涉及到的犯罪金额:", "")
|
||||||
|
|
||||||
|
assert "千元" not in answer, f"answer: {answer}, question: {question}"
|
||||||
|
assert "万" not in answer, f"answer: {answer}, question: {question}"
|
||||||
|
|
||||||
|
# remove "元"
|
||||||
|
answer = answer.replace("元。", "")
|
||||||
|
answer = float(answer)
|
||||||
|
|
||||||
|
prediction_digits = re.findall(r"\d+\.?\d*", prediction)
|
||||||
|
prediction_digits = [float(digit) for digit in prediction_digits]
|
||||||
|
|
||||||
|
if len(prediction_digits) == 0:
|
||||||
|
abstentions += 1
|
||||||
|
if answer in prediction_digits:
|
||||||
|
score_list.append(1)
|
||||||
|
else:
|
||||||
|
score_list.append(0)
|
||||||
|
|
||||||
|
|
||||||
|
# compute the accuracy of score_list
|
||||||
|
accuracy = sum(score_list) / len(score_list)
|
||||||
|
return {"score": accuracy, "abstention_rate": abstentions/len(data_dict)}
|
29
opencompass/datasets/lawbench/evaluation_functions/lblj.py
Normal file
29
opencompass/datasets/lawbench/evaluation_functions/lblj.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from ..utils.function_utils import multi_choice_judge
|
||||||
|
|
||||||
|
"""
|
||||||
|
Task: multi-choice selection
|
||||||
|
Metric: Accuracy
|
||||||
|
论辩挖掘
|
||||||
|
"""
|
||||||
|
def compute_lblj(data_dict):
|
||||||
|
"""
|
||||||
|
Compute the Accuracy
|
||||||
|
The LBLJ dataset has 5 options for each question: A, B, C, D, E
|
||||||
|
A prediction is correct if
|
||||||
|
1. The correct answer appears in the prediction, and
|
||||||
|
2. Options other than the answer do not appear in the prediction.
|
||||||
|
"""
|
||||||
|
score_list, abstentions = [], 0
|
||||||
|
option_list = ["A", "B", "C", "D", "E"]
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
assert answer.startswith("[正确答案]") and answer[6] in option_list, f"answer[6]: {answer}, question: {question}"
|
||||||
|
|
||||||
|
answer_letter = answer[6]
|
||||||
|
judge = multi_choice_judge(prediction, option_list, answer_letter)
|
||||||
|
score_list.append(judge["score"])
|
||||||
|
abstentions += judge["abstention"]
|
||||||
|
|
||||||
|
# compute the accuracy of score_list
|
||||||
|
accuracy = sum(score_list) / len(score_list)
|
||||||
|
return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)}
|
@ -0,0 +1,76 @@
|
|||||||
|
from ..utils.function_utils import compute_f1_two_sets
|
||||||
|
"""
|
||||||
|
task: legal accusation prediction
|
||||||
|
metric: f1 score
|
||||||
|
法律判决预测-罪名预测
|
||||||
|
"""
|
||||||
|
|
||||||
|
option_list = ["侮辱", "违法发放贷款", "失火", "票据诈骗", "帮助犯罪分子逃避处罚", "重大责任事故", "对非国家工作人员行贿",
|
||||||
|
"非法制造、销售非法制造的注册商标标识", "非法制造、买卖、运输、邮寄、储存枪支、弹药、爆炸物", "非法获取公民个人信息",
|
||||||
|
"扰乱无线电通讯管理秩序", "非法持有、私藏枪支、弹药", "拒不执行判决、裁定", "虚开发票", "巨额财产来源不明",
|
||||||
|
"组织、领导、参加黑社会性质组织", "非法获取国家秘密", "以危险方法危害公共安全", "非法持有毒品",
|
||||||
|
"聚众扰乱公共场所秩序、交通秩序", "包庇毒品犯罪分子", "滥伐林木", "伪造公司、企业、事业单位、人民团体印章",
|
||||||
|
"非法占用农用地", "走私废物", "串通投标", "非法采伐、毁坏国家重点保护植物", "冒充军人招摇撞骗", "玩忽职守",
|
||||||
|
"重婚", "招收公务员、学生徇私舞弊", "组织、领导传销活动", "非法猎捕、杀害珍贵、濒危野生动物", "侵犯著作权",
|
||||||
|
"非法种植毒品原植物", "伪造、变造、买卖武装部队公文、证件、印章", "倒卖文物", "伪造、变造居民身份证", "滥用职权",
|
||||||
|
"诽谤", "猥亵儿童", "非法转让、倒卖土地使用权", "挪用公款", "污染环境", "出售、购买、运输假币", "敲诈勒索",
|
||||||
|
"高利转贷", "故意伤害", "持有、使用假币", "单位受贿", "强奸", "引诱、容留、介绍卖淫", "虐待",
|
||||||
|
"生产、销售伪劣农药、兽药、化肥、种子", "妨害公务", "容留他人吸毒", "拐骗儿童", "强制猥亵、侮辱妇女",
|
||||||
|
"非法处置查封、扣押、冻结的财产", "骗取贷款、票据承兑、金融票证", "强迫他人吸毒", "非法拘禁",
|
||||||
|
"非法携带枪支、弹药、管制刀具、危险物品危及公共安全", "绑架", "聚众斗殴", "破坏计算机信息系统",
|
||||||
|
"制造、贩卖、传播淫秽物品", "虐待被监管人", "贷款诈骗", "赌博", "徇私舞弊不征、少征税款",
|
||||||
|
"盗窃、抢夺枪支、弹药、爆炸物、危险物质", "故意杀人", "介绍贿赂", "提供侵入、非法控制计算机信息系统程序、工具",
|
||||||
|
"编造、故意传播虚假恐怖信息", "妨害作证", "强迫卖淫", "走私、贩卖、运输、制造毒品", "伪证", "拐卖妇女、儿童",
|
||||||
|
"过失损坏武器装备、军事设施、军事通信", "破坏广播电视设施、公用电信设施", "洗钱", "职务侵占", "倒卖车票、船票",
|
||||||
|
"抢劫", "侵占", "掩饰、隐瞒犯罪所得、犯罪所得收益", "徇私舞弊不移交刑事案件", "引诱、教唆、欺骗他人吸毒", "遗弃",
|
||||||
|
"生产、销售伪劣产品", "放火", "非法采矿", "对单位行贿", "盗窃、抢夺枪支、弹药、爆炸物", "破坏易燃易爆设备",
|
||||||
|
"妨害信用卡管理", "制作、复制、出版、贩卖、传播淫秽物品牟利", "金融凭证诈骗", "私分国有资产",
|
||||||
|
"走私国家禁止进出口的货物、物品", "假冒注册商标", "危险物品肇事", "走私普通货物、物品", "经济犯", "虚报注册资本",
|
||||||
|
"盗掘古文化遗址、古墓葬", "传播淫秽物品", "窝藏、包庇", "拒不支付劳动报酬", "行贿", "开设赌场", "传授犯罪方法",
|
||||||
|
"协助组织卖淫", "保险诈骗", "破坏生产经营", "破坏交通设施", "打击报复证人", "非法侵入住宅", "非国家工作人员受贿",
|
||||||
|
"过失致人重伤", "伪造、变造金融票证", "窝藏、转移、隐瞒毒品、毒赃", "帮助毁灭、伪造证据", "走私珍贵动物、珍贵动物制品",
|
||||||
|
"生产、销售假药", "逃税", "挪用特定款物", "聚众扰乱社会秩序", "组织、强迫、引诱、容留、介绍卖淫", "合同诈骗",
|
||||||
|
"非法生产、销售间谍专用器材", "破坏交通工具", "传播性病", "强迫交易", "隐匿、故意销毁会计凭证、会计帐簿、财务会计报告",
|
||||||
|
"非法组织卖血", "强迫劳动", "破坏电力设备", "销售假冒注册商标的商品", "收买被拐卖的妇女、儿童", "诬告陷害", "脱逃",
|
||||||
|
"非法经营", "徇私枉法", "信用卡诈骗", "生产、销售不符合安全标准的食品", "非法行医", "伪造货币", "动植物检疫徇私舞弊",
|
||||||
|
"单位行贿", "破坏监管秩序", "盗窃", "盗伐林木", "重大劳动安全事故", "非法吸收公众存款",
|
||||||
|
"非法制造、出售非法制造的发票", "非法狩猎", "组织卖淫", "非法买卖、运输、携带、持有毒品原植物种子、幼苗", "挪用资金",
|
||||||
|
"诈骗", "伪造、变造、买卖国家机关公文、证件、印章", "持有伪造的发票", "贪污", "非法生产、买卖警用装备",
|
||||||
|
"投放危险物质", "伪造、倒卖伪造的有价票证", "集资诈骗", "抢夺", "生产、销售有毒、有害食品", "非法捕捞水产品",
|
||||||
|
"过失致人死亡", "非法买卖制毒物品", "虚开增值税专用发票、用于骗取出口退税、抵扣税款发票", "寻衅滋事", "危险驾驶",
|
||||||
|
"故意毁坏财物", "招摇撞骗", "盗窃、侮辱尸体", "走私武器、弹药",
|
||||||
|
"非法收购、运输、加工、出售国家重点保护植物、国家重点保护植物制品", "非法出售发票", "劫持船只、汽车",
|
||||||
|
"受贿", "聚众哄抢", "交通肇事"]
|
||||||
|
|
||||||
|
|
||||||
|
def compute_ljp_accusation(data_dict):
|
||||||
|
"""
|
||||||
|
Compute the F1-score
|
||||||
|
The LJP_Accusation dataset a set of 189 different accusation types.
|
||||||
|
A question may involve one or more accusation types.
|
||||||
|
Given a list of accusation types from both the ground truth and the prediction, we compute the F1-score between
|
||||||
|
these two lists.
|
||||||
|
"""
|
||||||
|
score_list, abstentions = [], 0
|
||||||
|
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
|
||||||
|
assert answer.startswith("罪名:"), f"answer: {answer} \n question: {question}"
|
||||||
|
answer = answer.replace("罪名:", "")
|
||||||
|
answers = answer.split(";")
|
||||||
|
|
||||||
|
prediction_list =[]
|
||||||
|
for option in option_list:
|
||||||
|
if option in prediction:
|
||||||
|
prediction_list.append(option)
|
||||||
|
|
||||||
|
if len(prediction_list) == 0:
|
||||||
|
abstentions += 1
|
||||||
|
gt_set = set(answers)
|
||||||
|
pred_set = set(prediction_list)
|
||||||
|
score = compute_f1_two_sets(gt_set, pred_set)
|
||||||
|
score_list.append(score)
|
||||||
|
|
||||||
|
f1_score_average = sum(score_list) / len(score_list)
|
||||||
|
return {"score": f1_score_average, "abstention_rate": abstentions/len(data_dict)}
|
@ -0,0 +1,70 @@
|
|||||||
|
import re
|
||||||
|
import cn2an
|
||||||
|
|
||||||
|
"""
|
||||||
|
task: law article prediction
|
||||||
|
metric: F1 score
|
||||||
|
法律判决预测-法条预测
|
||||||
|
"""
|
||||||
|
def replace_match(match):
|
||||||
|
return match.group(1)
|
||||||
|
|
||||||
|
def compute_ljp_article(data_dict):
|
||||||
|
"""
|
||||||
|
Compute the F1-score
|
||||||
|
A reference contains a list of articles of the Criminal Law of the People's Republic of China.
|
||||||
|
We compute the F1-score between the prediction and the reference.
|
||||||
|
"""
|
||||||
|
|
||||||
|
score_list, abstentions = [], 0
|
||||||
|
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
assert answer.startswith("法条:刑法第"), f"answer: {answer}"
|
||||||
|
assert answer.endswith("条"), f"answer: {answer}"
|
||||||
|
|
||||||
|
answer = answer.replace("法条:刑法第", "")
|
||||||
|
answer = answer.replace("条", "")
|
||||||
|
|
||||||
|
answer_law_indices = answer.split("、")
|
||||||
|
answer_law_index_digit_list = []
|
||||||
|
for answer_law_index in answer_law_indices:
|
||||||
|
assert answer_law_index.isdigit(), f"answer_law_index: {answer_law_index}"
|
||||||
|
answer_law_index_digit = int(answer_law_index)
|
||||||
|
assert answer_law_index_digit <= 490, "刑法总共只有490条"
|
||||||
|
answer_law_index_digit_list.append(answer_law_index_digit)
|
||||||
|
|
||||||
|
prediction_law_chunks = prediction.split("、")
|
||||||
|
prediction_law_index_digit_list = []
|
||||||
|
|
||||||
|
for prediction_law_chunk in prediction_law_chunks:
|
||||||
|
prediction_law_chunk = prediction_law_chunk.replace("万元", "元")
|
||||||
|
|
||||||
|
# delete phrase starts with "第" and ends with "款", we don't care about it in the answer
|
||||||
|
prediction_law_chunk = re.sub(r'第(.*?)款', "", prediction_law_chunk)
|
||||||
|
# keep only the digits in the phrase starts with "第" and ends with "条", otherwise cn may fail to convert
|
||||||
|
prediction_law_chunk = re.sub(r'第(.*?)条', replace_match, prediction_law_chunk)
|
||||||
|
prediction_law_chunk = cn2an.transform(prediction_law_chunk, "cn2an")
|
||||||
|
# find digtis in prediction_law_chunk
|
||||||
|
prediction_law_section_numbers = re.findall(r"\d+", prediction_law_chunk)
|
||||||
|
if len(prediction_law_section_numbers) == 0:
|
||||||
|
continue
|
||||||
|
if len(prediction_law_section_numbers) != 1:
|
||||||
|
# in this case, we only take the first number, and reject the others
|
||||||
|
pass
|
||||||
|
|
||||||
|
prediction_law_index_digit = int(prediction_law_section_numbers[0])
|
||||||
|
prediction_law_index_digit_list.append(prediction_law_index_digit)
|
||||||
|
|
||||||
|
gt_set = set(answer_law_index_digit_list)
|
||||||
|
pred_set = set(prediction_law_index_digit_list)
|
||||||
|
if len(pred_set) == 0:
|
||||||
|
abstentions += 1
|
||||||
|
precision = len(gt_set.intersection(pred_set)) / len(pred_set) if len(pred_set) != 0 else 0
|
||||||
|
recall = len(gt_set.intersection(pred_set)) / len(gt_set) if len(gt_set) != 0 else 0
|
||||||
|
f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0
|
||||||
|
score_list.append(f1_score)
|
||||||
|
|
||||||
|
# compute the accuracy of score_list
|
||||||
|
average_f1 = sum(score_list) / len(score_list)
|
||||||
|
return {'score': average_f1, 'abstention_rate': abstentions/len(data_dict)}
|
@ -0,0 +1,49 @@
|
|||||||
|
import math
|
||||||
|
import cn2an
|
||||||
|
import re
|
||||||
|
|
||||||
|
#法律判决预测-刑期预测
|
||||||
|
def compute_ljp_imprison(data_dict):
|
||||||
|
score_list, abstentions = [], 0
|
||||||
|
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
# get answer digit, which is the number between "刑期:" and "个月"
|
||||||
|
if "死刑" in answer or "无期" in answer:
|
||||||
|
# TODO: data imperfection
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert answer.startswith("刑期:") and answer.endswith("个月"), f"answer: {answer}, question: {question}"
|
||||||
|
answer = answer.replace("刑期:", "")
|
||||||
|
answer = answer.replace("个月", "")
|
||||||
|
answer_digit = int(answer)
|
||||||
|
prediction = cn2an.transform(prediction, "cn2an")
|
||||||
|
|
||||||
|
# use regular expression to extract the digits from prediction, only consider digits before "个月" or "月"
|
||||||
|
prediction_digit_month_list = re.findall(r"\d+个月", prediction)
|
||||||
|
prediction_digit_month_list = [int(digit.replace("个月", "")) for digit in prediction_digit_month_list]
|
||||||
|
prediction_digit_month_list2 = re.findall(r"\d+月", prediction)
|
||||||
|
prediction_digit_month_list2 = [int(digit.replace("月", "")) for digit in prediction_digit_month_list2]
|
||||||
|
prediction_digit_month_list.extend(prediction_digit_month_list2)
|
||||||
|
# catches the digits before "年"
|
||||||
|
prediction_digit_year_list = re.findall(r"\d+年", prediction)
|
||||||
|
prediction_digit_year_list = [int(digit.replace("年", "")) for digit in prediction_digit_year_list]
|
||||||
|
|
||||||
|
if len(prediction_digit_month_list) > 0:
|
||||||
|
prediction_digit_month = int(prediction_digit_month_list[0])
|
||||||
|
elif len(prediction_digit_year_list) > 0:
|
||||||
|
prediction_digit_month = int(prediction_digit_year_list[0]) * 12
|
||||||
|
else:
|
||||||
|
abstentions += 1
|
||||||
|
prediction_digit_month = -1
|
||||||
|
|
||||||
|
if prediction_digit_month != -1:
|
||||||
|
score_list.append(abs(math.log(answer_digit + 1) - math.log(prediction_digit_month + 1)))
|
||||||
|
else:
|
||||||
|
score_list.append(math.log(216))
|
||||||
|
|
||||||
|
# compute the average of score_list (log distance)
|
||||||
|
log_distance = sum(score_list) / len(score_list)
|
||||||
|
# normalize the score to between 0 and 1
|
||||||
|
log_distance = (math.log(216) - log_distance)/math.log(216)
|
||||||
|
return {"score": log_distance, "abstention_rate": abstentions/len(data_dict)}
|
64
opencompass/datasets/lawbench/evaluation_functions/sjjc.py
Normal file
64
opencompass/datasets/lawbench/evaluation_functions/sjjc.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
from ..utils.function_utils import compute_f1_two_sets
|
||||||
|
from ..utils.rc_f1 import CJRCEvaluator
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
task: event detection
|
||||||
|
metric: F1 score
|
||||||
|
事件检测
|
||||||
|
"""
|
||||||
|
option_list = ["支付/给付", "欺骗", "搜查/扣押", "要求/请求", "卖出", "买入", "获利", "拘捕", "鉴定", "同意/接受", "供述", "联络", "帮助/救助", "租用/借用", "受伤", "伪造", "卖淫", "伤害人身", "赔偿", "归还/偿还"]
|
||||||
|
|
||||||
|
def compute_sjjc(data_dict):
|
||||||
|
"""
|
||||||
|
Compute the F1-score
|
||||||
|
The sjjc task covers 20 event types.
|
||||||
|
A question may involve one or more event types.
|
||||||
|
Given a list of event types from both the ground truth and the prediction, we compute the F1-score between
|
||||||
|
these two lists.
|
||||||
|
"""
|
||||||
|
score_list, abstentions = [], 0
|
||||||
|
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
|
||||||
|
answers = answer.split(";")
|
||||||
|
|
||||||
|
prediction_list =[]
|
||||||
|
for option in option_list:
|
||||||
|
if option in prediction:
|
||||||
|
prediction_list.append(option)
|
||||||
|
|
||||||
|
if len(prediction_list) == 0:
|
||||||
|
abstentions += 1
|
||||||
|
gt_set = set(answers)
|
||||||
|
pred_set = set(prediction_list)
|
||||||
|
score = compute_f1_two_sets(gt_set, pred_set)
|
||||||
|
score_list.append(score)
|
||||||
|
|
||||||
|
f1_score_average = sum(score_list) / len(score_list)
|
||||||
|
return {"score": f1_score_average, "abstention_rate": abstentions/len(data_dict)}
|
||||||
|
|
||||||
|
"""
|
||||||
|
task: trigger word extraction
|
||||||
|
metric: F1 score
|
||||||
|
触发词抽取
|
||||||
|
"""
|
||||||
|
def compute_cfcy(data_dict):
|
||||||
|
|
||||||
|
scores = 0
|
||||||
|
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
|
||||||
|
answers = answer.split(";")
|
||||||
|
predictions = prediction.split(";")
|
||||||
|
intersected = [CJRCEvaluator.compute_f1(r, h) for r, h in zip(answers, predictions)]
|
||||||
|
|
||||||
|
prec = sum(intersected) / len(predictions) if len(predictions) > 0 else 0
|
||||||
|
rec = sum(intersected) / len(answers) if len(answers) > 0 else 0
|
||||||
|
# print(prec, rec, intersected)
|
||||||
|
scores += 2 * prec * rec / (prec + rec + 1e-10)
|
||||||
|
|
||||||
|
f1_score_average = scores / len(data_dict)
|
||||||
|
return {"score": f1_score_average}
|
42
opencompass/datasets/lawbench/evaluation_functions/wbfl.py
Normal file
42
opencompass/datasets/lawbench/evaluation_functions/wbfl.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
"""
|
||||||
|
task: multiple choice classification
|
||||||
|
metric: F1 score
|
||||||
|
婚姻文本分类
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compute_wbfl(data_dict):
|
||||||
|
"""
|
||||||
|
A reference (R) contains a list of options, each option is from the option_list.
|
||||||
|
We will extract the options appearing in the prediction and convert them into a set (P).
|
||||||
|
We compute the F1 score between the prediction (P) and the reference (R).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
score_list, abstentions = [], 0
|
||||||
|
option_list = ["婚后有子女", "限制行为能力子女抚养", "有夫妻共同财产", "支付抚养费", "不动产分割", "婚后分局",
|
||||||
|
"二次起诉离婚", "按月给付抚养费", "准予离婚", "有夫妻共同债务", "婚前个人财产", "法定离婚", "不履行家庭义务",
|
||||||
|
"存在非婚生子", "适当帮助", "不履行离婚协议", "损害赔偿", "感情不和分居满二年", "子女随非抚养权人生活", "婚后个人财产"]
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
assert answer.startswith("类别:") and answer.endswith("。"), f"answer: {answer}, question: {question}"
|
||||||
|
|
||||||
|
gt_list = (answer[3:-1].split("、"))
|
||||||
|
for gt in gt_list:
|
||||||
|
assert gt in option_list, f"gt: {gt}, question: {question}"
|
||||||
|
gt_set = set(gt_list)
|
||||||
|
|
||||||
|
prediction_list = []
|
||||||
|
for option in option_list:
|
||||||
|
if option in prediction:
|
||||||
|
prediction_list.append(option)
|
||||||
|
if len(prediction_list) == 0:
|
||||||
|
abstentions += 1
|
||||||
|
predict_set = set(prediction_list)
|
||||||
|
precision = len(gt_set.intersection(predict_set)) / len(predict_set) if len(predict_set) != 0 else 0
|
||||||
|
recall = len(gt_set.intersection(predict_set)) / len(gt_set) if len(gt_set) != 0 else 0
|
||||||
|
f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0
|
||||||
|
score_list.append(f1_score)
|
||||||
|
|
||||||
|
# compute the accuracy of score_list
|
||||||
|
final_f1_score = sum(score_list) / len(score_list)
|
||||||
|
return {'score': final_f1_score, 'abstention_rate': abstentions / len(data_dict)}
|
50
opencompass/datasets/lawbench/evaluation_functions/wsjd.py
Normal file
50
opencompass/datasets/lawbench/evaluation_functions/wsjd.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import re
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
"""
|
||||||
|
Task: legal document grammar correction
|
||||||
|
Metric: F0.5 score
|
||||||
|
文书校对
|
||||||
|
"""
|
||||||
|
def compute_wsjd(data_dict):
|
||||||
|
origins, references, predictions = [], [], []
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
if isinstance(question, list):
|
||||||
|
question = question[0]['prompt']
|
||||||
|
start = question.index('句子:\n') + 4
|
||||||
|
origins.append(re.sub(r'\n|\t', '', question[start:].split('\n')[0]))
|
||||||
|
# truncate predictions >5 tokens longer than the reference
|
||||||
|
prediction = re.sub(r'\n|\t', '', prediction)
|
||||||
|
if len(prediction) - len(answer) > 5:
|
||||||
|
prediction = prediction[:len(answer) + 5]
|
||||||
|
if len(prediction) == 0:
|
||||||
|
prediction = "无内容"
|
||||||
|
predictions.append(prediction)
|
||||||
|
references.append(re.sub(r'\n|\t', '', answer))
|
||||||
|
|
||||||
|
#generate input files for ChERRANT
|
||||||
|
preds = [f'{i} \t {origin} \t {prediction} \n' for i, (origin, prediction) in enumerate(zip(origins, predictions))]
|
||||||
|
golds = [f'{i} \t {origin} \t {reference} \n' for i, (origin, reference) in enumerate(zip(origins, references))]
|
||||||
|
|
||||||
|
now_path = os.path.abspath(os.getcwd())
|
||||||
|
utils_path = os.path.abspath(os.path.join(__file__, '..', '..', 'utils'))
|
||||||
|
uid = os.getuid()
|
||||||
|
os.chdir(utils_path)
|
||||||
|
with open(f'/tmp/tmp_pred_{uid}.para', 'w') as f:
|
||||||
|
f.writelines(preds)
|
||||||
|
with open(f'/tmp/tmp_gold_{uid}.para', 'w') as f:
|
||||||
|
f.writelines(golds)
|
||||||
|
os.environ['KMP_DUPLICATE_LIB_OK']='True'
|
||||||
|
os.system(f'python3 parallel_to_m2.py -f /tmp/tmp_pred_{uid}.para -o /tmp/tmp_pred_{uid}.para.m2 -g char')
|
||||||
|
os.system(f'python3 parallel_to_m2.py -f /tmp/tmp_gold_{uid}.para -o /tmp/tmp_gold_{uid}.para.m2 -g char')
|
||||||
|
output = subprocess.check_output(f"python3 compare_m2_for_evaluation.py -hyp /tmp/tmp_pred_{uid}.para.m2 -ref /tmp/tmp_gold_{uid}.para.m2", shell = True)
|
||||||
|
score = float(output.decode().split('\t')[-1].split('\n')[0])
|
||||||
|
#remove prediction files
|
||||||
|
os.remove(f'/tmp/tmp_pred_{uid}.para')
|
||||||
|
os.remove(f'/tmp/tmp_gold_{uid}.para')
|
||||||
|
os.remove(f'/tmp/tmp_pred_{uid}.para.m2')
|
||||||
|
os.remove(f'/tmp/tmp_gold_{uid}.para.m2')
|
||||||
|
os.chdir(now_path)
|
||||||
|
return {"score": score}
|
17
opencompass/datasets/lawbench/evaluation_functions/xxcq.py
Normal file
17
opencompass/datasets/lawbench/evaluation_functions/xxcq.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from ..utils.comprehension_scores import compute_ie_f1
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
task: information extraction
|
||||||
|
metric: F1 score
|
||||||
|
信息抽取
|
||||||
|
"""
|
||||||
|
def compute_xxcq(data_dict):
|
||||||
|
references, predictions = [], []
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
predictions.append(prediction)
|
||||||
|
references.append(answer)
|
||||||
|
|
||||||
|
return compute_ie_f1(predictions, references, {"犯罪嫌疑人", "受害人", "被盗货币", "物品价值", "盗窃获利",
|
||||||
|
"被盗物品", "作案工具", "时间", "地点", "组织机构"})
|
17
opencompass/datasets/lawbench/evaluation_functions/ydlj.py
Normal file
17
opencompass/datasets/lawbench/evaluation_functions/ydlj.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from ..utils.comprehension_scores import compute_rc_f1
|
||||||
|
|
||||||
|
"""
|
||||||
|
Task: machine reading comprehension
|
||||||
|
Metric: F1 score
|
||||||
|
法律阅读理解
|
||||||
|
"""
|
||||||
|
def compute_ydlj(data_dict):
|
||||||
|
references, predictions = [], []
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
answer = answer.replace("回答:", "")
|
||||||
|
predictions.append(prediction)
|
||||||
|
references.append(answer)
|
||||||
|
|
||||||
|
f1_score = compute_rc_f1(predictions, references)
|
||||||
|
return f1_score
|
18
opencompass/datasets/lawbench/evaluation_functions/yqzy.py
Normal file
18
opencompass/datasets/lawbench/evaluation_functions/yqzy.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from ..utils.function_utils import compute_rouge
|
||||||
|
|
||||||
|
#舆情摘要
|
||||||
|
def compute_yqzy(data_dict):
|
||||||
|
"""
|
||||||
|
Compute the ROUGE-L score between the prediction and the reference
|
||||||
|
"""
|
||||||
|
references, predictions = [], []
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
predictions.append(prediction)
|
||||||
|
references.append(answer)
|
||||||
|
|
||||||
|
# compute the accuracy of score_list
|
||||||
|
rouge_scores = compute_rouge(predictions, references)
|
||||||
|
rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores]
|
||||||
|
average_rouge_l = sum(rouge_ls) / len(rouge_ls)
|
||||||
|
return {"score": average_rouge_l}
|
27
opencompass/datasets/lawbench/evaluation_functions/zxfl.py
Normal file
27
opencompass/datasets/lawbench/evaluation_functions/zxfl.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from ..utils.function_utils import multi_choice_judge
|
||||||
|
|
||||||
|
"""
|
||||||
|
task: multiple choice classification
|
||||||
|
metric: accuracy
|
||||||
|
咨询分类
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compute_zxfl(data_dict):
|
||||||
|
"""
|
||||||
|
A reference (R) contains a list of options, each option is from the option_list.
|
||||||
|
We will extract the options appearing in the prediction and convert them into a set (P).
|
||||||
|
We compute the accuracy between the prediction (P) and the reference (R).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
score_list, abstentions = [], 0
|
||||||
|
option_list = ['婚姻家庭', '劳动纠纷', '交通事故', '债权债务', '刑事辩护', '合同纠纷', '房产纠纷', '侵权', '公司法', '医疗纠纷', '拆迁安置', '行政诉讼', '建设工程', '知识产权', '综合咨询', '人身损害', '涉外法律', '海事海商', '消费权益', '抵押担保']
|
||||||
|
for example in data_dict:
|
||||||
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
||||||
|
judge = multi_choice_judge(prediction, option_list, answer)
|
||||||
|
score_list.append(judge["score"])
|
||||||
|
abstentions += judge["abstention"]
|
||||||
|
|
||||||
|
# compute the accuracy of score_list
|
||||||
|
final_accuracy_score = sum(score_list) / len(score_list)
|
||||||
|
return {'score': final_accuracy_score, 'abstention_rate': abstentions / len(data_dict)}
|
83
opencompass/datasets/lawbench/lawbench.py
Normal file
83
opencompass/datasets/lawbench/lawbench.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||||
|
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
|
||||||
|
|
||||||
|
from ..base import BaseDataset
|
||||||
|
from .evaluation_functions import (cjft, flzx, ftcs, jdzy, jec_ac, jec_kd,
|
||||||
|
jetq, lblj, ljp_accusation, ljp_article,
|
||||||
|
ljp_imprison, sjjc, wbfl, wsjd, xxcq, ydlj,
|
||||||
|
yqzy, zxfl)
|
||||||
|
|
||||||
|
|
||||||
|
@LOAD_DATASET.register_module()
|
||||||
|
class LawBenchDataset(BaseDataset):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(path: str, index: str) -> Dataset:
|
||||||
|
path = os.path.join(path, index + '.json')
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
return Dataset.from_list(data)
|
||||||
|
|
||||||
|
|
||||||
|
funct_dict = {
|
||||||
|
'1-1': ftcs.compute_ftcs,
|
||||||
|
'1-2': jec_kd.compute_jec_kd,
|
||||||
|
'2-1': wsjd.compute_wsjd,
|
||||||
|
'2-2': jdzy.compute_jdzy,
|
||||||
|
'2-3': wbfl.compute_wbfl,
|
||||||
|
'2-4': zxfl.compute_zxfl,
|
||||||
|
'2-5': ydlj.compute_ydlj,
|
||||||
|
'2-6': xxcq.compute_xxcq,
|
||||||
|
'2-7': yqzy.compute_yqzy,
|
||||||
|
'2-8': lblj.compute_lblj,
|
||||||
|
'2-9': sjjc.compute_sjjc,
|
||||||
|
'2-10': sjjc.compute_cfcy,
|
||||||
|
'3-1': ljp_article.compute_ljp_article,
|
||||||
|
'3-2': cjft.compute_cjft,
|
||||||
|
'3-3': ljp_accusation.compute_ljp_accusation,
|
||||||
|
'3-4': ljp_imprison.compute_ljp_imprison,
|
||||||
|
'3-5': ljp_imprison.compute_ljp_imprison,
|
||||||
|
'3-6': jec_ac.compute_jec_ac,
|
||||||
|
'3-7': jetq.compute_jetq,
|
||||||
|
'3-8': flzx.compute_flzx,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class LawBenchEvaluator(BaseEvaluator):
|
||||||
|
|
||||||
|
def __init__(self, index) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.index = index
|
||||||
|
|
||||||
|
def score(self, predictions, references, origin_prompt):
|
||||||
|
if len(predictions) != len(references):
|
||||||
|
return {
|
||||||
|
'error': 'predictions and references have different '
|
||||||
|
'length'
|
||||||
|
}
|
||||||
|
|
||||||
|
data_dict = [{
|
||||||
|
'origin_prompt': origin_prompt[i],
|
||||||
|
'prediction': predictions[i],
|
||||||
|
'refr': references[i],
|
||||||
|
} for i in range(len(predictions))]
|
||||||
|
scores = funct_dict[self.index](data_dict)
|
||||||
|
scores = {k: v * 100 for k, v in scores.items()}
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
for index in funct_dict:
|
||||||
|
# fix classic closure problem
|
||||||
|
def _register(index):
|
||||||
|
ICL_EVALUATORS.register_module(
|
||||||
|
name='LawBenchEvaluator_' + index.replace('-', '_'),
|
||||||
|
module=lambda *args, **kwargs: LawBenchEvaluator(
|
||||||
|
index=index, *args, **kwargs))
|
||||||
|
|
||||||
|
_register(index)
|
456
opencompass/datasets/lawbench/utils/char_smi.py
Normal file
456
opencompass/datasets/lawbench/utils/char_smi.py
Normal file
@ -0,0 +1,456 @@
|
|||||||
|
### Copy from https://github.com/iqiyi/FASPell ###
|
||||||
|
|
||||||
|
"""
|
||||||
|
Requirements:
|
||||||
|
- java (required only if tree edit distance is used)
|
||||||
|
- numpy
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
from subprocess import Popen, PIPE, STDOUT
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
IDCS = {'\u2ff0': 2, # 12 ideographic description characters and their capacity of son nodes
|
||||||
|
'\u2ff1': 2,
|
||||||
|
'\u2ff2': 3,
|
||||||
|
'\u2ff3': 3,
|
||||||
|
'\u2ff4': 2,
|
||||||
|
'\u2ff5': 2,
|
||||||
|
'\u2ff6': 2,
|
||||||
|
'\u2ff7': 2,
|
||||||
|
'\u2ff8': 2,
|
||||||
|
'\u2ff9': 2,
|
||||||
|
'\u2ffa': 2,
|
||||||
|
'\u2ffb': 2, }
|
||||||
|
|
||||||
|
PINYIN = {'ā': ['a', 1], 'á': ['a', 2], 'ǎ': ['a', 3], 'à': ['a', 4],
|
||||||
|
'ē': ['e', 1], 'é': ['e', 2], 'ě': ['e', 3], 'è': ['e', 4],
|
||||||
|
'ī': ['i', 1], 'í': ['i', 2], 'ǐ': ['i', 3], 'ì': ['i', 4],
|
||||||
|
'ō': ['o', 1], 'ó': ['o', 2], 'ǒ': ['o', 3], 'ò': ['o', 4],
|
||||||
|
'ū': ['u', 1], 'ú': ['u', 2], 'ǔ': ['u', 3], 'ù': ['u', 4],
|
||||||
|
'ǖ': ['ü', 1], 'ǘ': ['ü', 2], 'ǚ': ['ü', 3], 'ǜ': ['ü', 4],
|
||||||
|
'': ['m', 2], 'ń': ['n', 2], 'ň': ['n', 3], 'ǹ': ['n', 4],
|
||||||
|
}
|
||||||
|
|
||||||
|
# APTED_JAR_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'apted.jar')
|
||||||
|
APTED_JAR_PATH = 'apted.jar'
|
||||||
|
|
||||||
|
|
||||||
|
def tree_edit_distance(tree_a, tree_b):
|
||||||
|
"""
|
||||||
|
We use APTED algorithm proposed by M. Pawlik and N. Augsten
|
||||||
|
github link: https://github.com/DatabaseGroup/apted
|
||||||
|
"""
|
||||||
|
p = Popen(['java', '-jar', APTED_JAR_PATH, '-t', tree_a, tree_b], stdout=PIPE, stderr=STDOUT)
|
||||||
|
|
||||||
|
res = [line for line in p.stdout]
|
||||||
|
res = res[0]
|
||||||
|
res = res.strip()
|
||||||
|
res = float(res)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def edit_distance(string_a, string_b, name='Levenshtein'):
|
||||||
|
"""
|
||||||
|
>>> edit_distance('abcde', 'avbcude')
|
||||||
|
2
|
||||||
|
>>> edit_distance(['至', '刂'], ['亻', '至', '刂'])
|
||||||
|
1
|
||||||
|
>>> edit_distance('fang', 'qwe')
|
||||||
|
4
|
||||||
|
>>> edit_distance('fang', 'hen')
|
||||||
|
3
|
||||||
|
"""
|
||||||
|
size_x = len(string_a) + 1
|
||||||
|
size_y = len(string_b) + 1
|
||||||
|
matrix = np.zeros((size_x, size_y), dtype=int)
|
||||||
|
for x in range(size_x):
|
||||||
|
matrix[x, 0] = x
|
||||||
|
for y in range(size_y):
|
||||||
|
matrix[0, y] = y
|
||||||
|
|
||||||
|
for x in range(1, size_x):
|
||||||
|
for y in range(1, size_y):
|
||||||
|
if string_a[x - 1] == string_b[y - 1]:
|
||||||
|
matrix[x, y] = min(
|
||||||
|
matrix[x - 1, y] + 1,
|
||||||
|
matrix[x - 1, y - 1],
|
||||||
|
matrix[x, y - 1] + 1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if name == 'Levenshtein':
|
||||||
|
matrix[x, y] = min(
|
||||||
|
matrix[x - 1, y] + 1,
|
||||||
|
matrix[x - 1, y - 1] + 1,
|
||||||
|
matrix[x, y - 1] + 1
|
||||||
|
)
|
||||||
|
else: # Canonical
|
||||||
|
matrix[x, y] = min(
|
||||||
|
matrix[x - 1, y] + 1,
|
||||||
|
matrix[x - 1, y - 1] + 2,
|
||||||
|
matrix[x, y - 1] + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
return matrix[size_x - 1, size_y - 1]
|
||||||
|
|
||||||
|
|
||||||
|
class CharFuncs(object):
|
||||||
|
def __init__(self, char_meta_fname):
|
||||||
|
self.data = self.load_char_meta(char_meta_fname)
|
||||||
|
self.char_dict = dict([(c, 0) for c in self.data])
|
||||||
|
|
||||||
|
self.safe = {'\u2ff0': 'A',
|
||||||
|
# to eliminate the bug that, in Windows CMD, char ⿻ and ⿵ are encoded to be the same.
|
||||||
|
'\u2ff1': 'B',
|
||||||
|
'\u2ff2': 'C',
|
||||||
|
'\u2ff3': 'D',
|
||||||
|
'\u2ff4': 'E',
|
||||||
|
'\u2ff5': 'F',
|
||||||
|
'\u2ff6': 'G',
|
||||||
|
'\u2ff7': 'H',
|
||||||
|
'\u2ff8': 'I',
|
||||||
|
'\u2ff9': 'J',
|
||||||
|
'\u2ffa': 'L',
|
||||||
|
'\u2ffb': 'M', }
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_char_meta(fname):
|
||||||
|
data = {}
|
||||||
|
f = open(fname, 'r', encoding='utf-8')
|
||||||
|
for line in f:
|
||||||
|
items = line.strip().split('\t')
|
||||||
|
code_point = items[0]
|
||||||
|
char = items[1]
|
||||||
|
pronunciation = items[2]
|
||||||
|
decompositions = items[3:]
|
||||||
|
assert char not in data
|
||||||
|
data[char] = {"code_point": code_point, "pronunciation": pronunciation, "decompositions": decompositions}
|
||||||
|
return data
|
||||||
|
|
||||||
|
def shape_distance(self, char1, char2, safe=True, as_tree=False):
|
||||||
|
"""
|
||||||
|
>>> c = CharFuncs('data/char_meta.txt')
|
||||||
|
>>> c.shape_distance('田', '由')
|
||||||
|
1
|
||||||
|
>>> c.shape_distance('牛', '午')
|
||||||
|
1
|
||||||
|
"""
|
||||||
|
assert char1 in self.data
|
||||||
|
assert char2 in self.data
|
||||||
|
|
||||||
|
def safe_encode(decomp):
|
||||||
|
tree = ''
|
||||||
|
for c in string_to_tree(decomp):
|
||||||
|
if c not in self.safe:
|
||||||
|
tree += c
|
||||||
|
else:
|
||||||
|
tree += self.safe[c]
|
||||||
|
return tree
|
||||||
|
|
||||||
|
def safe_encode_string(decomp):
|
||||||
|
tree = ''
|
||||||
|
for c in decomp:
|
||||||
|
if c not in self.safe:
|
||||||
|
tree += c
|
||||||
|
else:
|
||||||
|
tree += self.safe[c]
|
||||||
|
return tree
|
||||||
|
|
||||||
|
decomps_1 = self.data[char1]["decompositions"]
|
||||||
|
decomps_2 = self.data[char2]["decompositions"]
|
||||||
|
|
||||||
|
distance = 1e5
|
||||||
|
if as_tree:
|
||||||
|
for decomp1 in decomps_1:
|
||||||
|
for decomp2 in decomps_2:
|
||||||
|
if not safe:
|
||||||
|
ted = tree_edit_distance(string_to_tree(decomp1), string_to_tree(decomp2))
|
||||||
|
else:
|
||||||
|
ted = tree_edit_distance(safe_encode(decomp1), safe_encode(decomp2))
|
||||||
|
distance = min(distance, ted)
|
||||||
|
else:
|
||||||
|
for decomp1 in decomps_1:
|
||||||
|
for decomp2 in decomps_2:
|
||||||
|
if not safe:
|
||||||
|
ed = edit_distance(decomp1, decomp2)
|
||||||
|
else:
|
||||||
|
ed = edit_distance(safe_encode_string(decomp1), safe_encode_string(decomp2))
|
||||||
|
distance = min(distance, ed)
|
||||||
|
|
||||||
|
return distance
|
||||||
|
|
||||||
|
def pronunciation_distance(self, char1, char2):
|
||||||
|
"""
|
||||||
|
>>> c = CharFuncs('data/char_meta.txt')
|
||||||
|
>>> c.pronunciation_distance('田', '由')
|
||||||
|
3.4
|
||||||
|
>>> c.pronunciation_distance('牛', '午')
|
||||||
|
2.6
|
||||||
|
"""
|
||||||
|
assert char1 in self.data
|
||||||
|
assert char2 in self.data
|
||||||
|
pronunciations1 = self.data[char1]["pronunciation"]
|
||||||
|
pronunciations2 = self.data[char2]["pronunciation"]
|
||||||
|
|
||||||
|
if pronunciations1[0] == 'null' or pronunciations2 == 'null':
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
|
||||||
|
pronunciations1 = pronunciations1.split(';') # separate by lan
|
||||||
|
pronunciations2 = pronunciations2.split(';') # separate by lan
|
||||||
|
|
||||||
|
distance = 0.0
|
||||||
|
count = 0
|
||||||
|
for pron_lan1, pron_lan2 in zip(pronunciations1, pronunciations2):
|
||||||
|
if (pron_lan1 == 'null') or (pron_lan2 == 'null'):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
distance_lan = 1e5
|
||||||
|
for p1 in pron_lan1.split(','):
|
||||||
|
for p2 in pron_lan2.split(','):
|
||||||
|
distance_lan = min(distance_lan, edit_distance(p1, p2))
|
||||||
|
distance += distance_lan
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return distance / count
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_dict(fname):
|
||||||
|
data = {}
|
||||||
|
f = open(fname, 'r', encoding='utf-8')
|
||||||
|
for line in f:
|
||||||
|
char, freq = line.strip().split('\t')
|
||||||
|
assert char not in data
|
||||||
|
data[char] = freq
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def similarity(self, char1, char2, weights=(0.8, 0.2, 0.0), as_tree=False):
|
||||||
|
"""
|
||||||
|
this function returns weighted similarity. When used in FASPell, each weight can only be 0 or 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# assert char1 in self.char_dict
|
||||||
|
# assert char2 in self.char_dict
|
||||||
|
shape_w, sound_w, freq_w = weights
|
||||||
|
|
||||||
|
if char1 in self.char_dict and char2 in self.char_dict:
|
||||||
|
|
||||||
|
shape_sim = self.shape_similarity(char1, char2, as_tree=as_tree)
|
||||||
|
sound_sim = self.pronunciation_similarity(char1, char2)
|
||||||
|
freq_sim = 1.0 - self.char_dict[char2] / len(self.char_dict)
|
||||||
|
|
||||||
|
return shape_sim * shape_w + sound_sim * sound_w + freq_sim * freq_w
|
||||||
|
else:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def shape_similarity(self, char1, char2, safe=True, as_tree=False):
|
||||||
|
"""
|
||||||
|
>>> c = CharFuncs('data/char_meta.txt')
|
||||||
|
>>> c.shape_similarity('牛', '午')
|
||||||
|
0.8571428571428572
|
||||||
|
>>> c.shape_similarity('田', '由')
|
||||||
|
0.8888888888888888
|
||||||
|
"""
|
||||||
|
assert char1 in self.data
|
||||||
|
assert char2 in self.data
|
||||||
|
|
||||||
|
def safe_encode(decomp):
|
||||||
|
tree = ''
|
||||||
|
for c in string_to_tree(decomp):
|
||||||
|
if c not in self.safe:
|
||||||
|
tree += c
|
||||||
|
else:
|
||||||
|
tree += self.safe[c]
|
||||||
|
return tree
|
||||||
|
|
||||||
|
def safe_encode_string(decomp):
|
||||||
|
tree = ''
|
||||||
|
for c in decomp:
|
||||||
|
if c not in self.safe:
|
||||||
|
tree += c
|
||||||
|
else:
|
||||||
|
tree += self.safe[c]
|
||||||
|
return tree
|
||||||
|
|
||||||
|
decomps_1 = self.data[char1]["decompositions"]
|
||||||
|
decomps_2 = self.data[char2]["decompositions"]
|
||||||
|
|
||||||
|
similarity = 0.0
|
||||||
|
if as_tree:
|
||||||
|
for decomp1 in decomps_1:
|
||||||
|
for decomp2 in decomps_2:
|
||||||
|
if not safe:
|
||||||
|
ted = tree_edit_distance(string_to_tree(decomp1), string_to_tree(decomp2))
|
||||||
|
else:
|
||||||
|
ted = tree_edit_distance(safe_encode(decomp1), safe_encode(decomp2))
|
||||||
|
normalized_ted = 2 * ted / (len(decomp1) + len(decomp2) + ted)
|
||||||
|
similarity = max(similarity, 1 - normalized_ted)
|
||||||
|
else:
|
||||||
|
for decomp1 in decomps_1:
|
||||||
|
for decomp2 in decomps_2:
|
||||||
|
if not safe:
|
||||||
|
ed = edit_distance(decomp1, decomp2)
|
||||||
|
else:
|
||||||
|
ed = edit_distance(safe_encode_string(decomp1), safe_encode_string(decomp2))
|
||||||
|
normalized_ed = ed / max(len(decomp1), len(decomp2))
|
||||||
|
similarity = max(similarity, 1 - normalized_ed)
|
||||||
|
|
||||||
|
return similarity
|
||||||
|
|
||||||
|
def pronunciation_similarity(self, char1, char2):
|
||||||
|
"""
|
||||||
|
>>> c = CharFuncs('data/char_meta.txt')
|
||||||
|
>>> c.pronunciation_similarity('牛', '午')
|
||||||
|
0.27999999999999997
|
||||||
|
>>> c.pronunciation_similarity('由', '田')
|
||||||
|
0.09
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert char1 in self.data
|
||||||
|
assert char2 in self.data
|
||||||
|
pronunciations1 = self.data[char1]["pronunciation"]
|
||||||
|
pronunciations2 = self.data[char2]["pronunciation"]
|
||||||
|
|
||||||
|
if pronunciations1[0] == 'null' or pronunciations2 == 'null':
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
|
||||||
|
pronunciations1 = pronunciations1.split(';') # separate by lan
|
||||||
|
pronunciations2 = pronunciations2.split(';') # separate by lan
|
||||||
|
|
||||||
|
similarity = 0.0
|
||||||
|
count = 0
|
||||||
|
for pron_lan1, pron_lan2 in zip(pronunciations1, pronunciations2):
|
||||||
|
if (pron_lan1 == 'null') or (pron_lan2 == 'null'):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
similarity_lan = 0.0
|
||||||
|
for p1 in pron_lan1.split(','):
|
||||||
|
for p2 in pron_lan2.split(','):
|
||||||
|
tmp_sim = 1 - edit_distance(p1, p2) / max(len(p1), len(p2))
|
||||||
|
similarity_lan = max(similarity_lan, tmp_sim)
|
||||||
|
similarity += similarity_lan
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return similarity / count if count else 0
|
||||||
|
|
||||||
|
|
||||||
|
def string_to_tree(string):
|
||||||
|
"""
|
||||||
|
This function converts ids string to a string that can be used as a tree input to APTED.
|
||||||
|
Any Error raised by this function implies that the input string is invalid.
|
||||||
|
>>> string_to_tree('⿱⿱⿰丿㇏⿰丿㇏⿱⿰丿㇏⿰丿㇏') # 炎
|
||||||
|
'{⿱{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}}'
|
||||||
|
>>> string_to_tree('⿱⿰丿㇏⿱一⿱⿻一丨一') # 全
|
||||||
|
'{⿱{⿰{丿}{㇏}}{⿱{一}{⿱{⿻{一}{丨}}{一}}}}'
|
||||||
|
>>> string_to_tree('⿱⿰丿㇏⿻⿱一⿱⿻一丨一丷') # 金
|
||||||
|
'{⿱{⿰{丿}{㇏}}{⿻{⿱{一}{⿱{⿻{一}{丨}}{一}}}{丷}}}'
|
||||||
|
>>> string_to_tree('⿻⿻⿻一丨一⿴⿱⿰丨𠃌一一') # 車
|
||||||
|
'{⿻{⿻{⿻{一}{丨}}{一}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}'
|
||||||
|
>>> string_to_tree('⿻⿻⿻一丨⿰丿㇏⿴⿱⿰丨𠃌一一') # 東
|
||||||
|
'{⿻{⿻{⿻{一}{丨}}{⿰{丿}{㇏}}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}'
|
||||||
|
>>> string_to_tree('丿') # 丿
|
||||||
|
'{丿}'
|
||||||
|
>>> string_to_tree('⿻') # ⿻
|
||||||
|
'{⿻}'
|
||||||
|
"""
|
||||||
|
if string[0] in IDCS and len(string) != 1:
|
||||||
|
bracket_stack = []
|
||||||
|
tree = []
|
||||||
|
|
||||||
|
def add_brackets(num):
|
||||||
|
if num == 2:
|
||||||
|
bracket_stack.extend(['}', '{', '}'])
|
||||||
|
else:
|
||||||
|
bracket_stack.extend(['}', '{', '}', '{', '}'])
|
||||||
|
tree.append('{')
|
||||||
|
|
||||||
|
global_just_put = '{'
|
||||||
|
|
||||||
|
for c in string:
|
||||||
|
tree.append(c)
|
||||||
|
if c in IDCS:
|
||||||
|
assert global_just_put != '}'
|
||||||
|
add_brackets(IDCS[c])
|
||||||
|
global_just_put = '{'
|
||||||
|
else:
|
||||||
|
just_put = ''
|
||||||
|
while just_put != '{' and bracket_stack:
|
||||||
|
just_put = bracket_stack.pop(-1)
|
||||||
|
tree.append(just_put)
|
||||||
|
global_just_put = just_put
|
||||||
|
|
||||||
|
res = ''.join(tree)
|
||||||
|
assert res[-1] == '}'
|
||||||
|
else:
|
||||||
|
assert len(string) == 1 or string == 'null'
|
||||||
|
res = string[0]
|
||||||
|
|
||||||
|
return '{' + res + '}'
|
||||||
|
|
||||||
|
|
||||||
|
def pinyin_map(standard_pinyin):
|
||||||
|
"""
|
||||||
|
>>> pinyin_map('xuě')
|
||||||
|
'xue3'
|
||||||
|
>>> pinyin_map('xue')
|
||||||
|
'xue'
|
||||||
|
>>> pinyin_map('lǜ')
|
||||||
|
'lü4'
|
||||||
|
>>> pinyin_map('fá')
|
||||||
|
'fa2'
|
||||||
|
"""
|
||||||
|
tone = ''
|
||||||
|
pinyin = ''
|
||||||
|
|
||||||
|
assert ' ' not in standard_pinyin
|
||||||
|
for c in standard_pinyin:
|
||||||
|
if c in PINYIN:
|
||||||
|
pinyin += PINYIN[c][0]
|
||||||
|
assert tone == ''
|
||||||
|
tone = str(PINYIN[c][1])
|
||||||
|
else:
|
||||||
|
pinyin += c
|
||||||
|
pinyin += tone
|
||||||
|
return pinyin
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
usage = '\n1. You can compute character similarity by:\n' \
|
||||||
|
'python char_sim.py 午 牛 年 千\n' \
|
||||||
|
'\n' \
|
||||||
|
'2. You can use ted in computing character similarity by:\n' \
|
||||||
|
'python char_sim.py 午 牛 年 千 -t\n' \
|
||||||
|
'\n'
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='A script to compute Chinese character (Kanji) similarity', usage=usage)
|
||||||
|
|
||||||
|
parser.add_argument('multiargs', nargs='*', type=str, default=None,
|
||||||
|
help='Chinese characters in question')
|
||||||
|
parser.add_argument('--ted', '-t', action="store_true", default=False,
|
||||||
|
help='True=to use tree edit distence (TED)'
|
||||||
|
'False=to use string edit distance')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
c = CharFuncs('data/char_meta.txt')
|
||||||
|
if not args.ted:
|
||||||
|
for i, c1 in enumerate(args.multiargs):
|
||||||
|
for c2 in args.multiargs[i:]:
|
||||||
|
if c1 != c2:
|
||||||
|
print(f'For character pair ({c1}, {c2}):')
|
||||||
|
print(f' v-sim = {c.shape_similarity(c1, c2)}')
|
||||||
|
print(f' p-sim = {c.pronunciation_similarity(c1, c2)}\n')
|
||||||
|
else:
|
||||||
|
for i, c1 in enumerate(args.multiargs):
|
||||||
|
for c2 in args.multiargs[i:]:
|
||||||
|
if c1 != c2:
|
||||||
|
print(f'For character pair ({c1}, {c2}):')
|
||||||
|
print(f' v-sim = {c.shape_similarity(c1, c2, as_tree=True)}')
|
||||||
|
print(f' p-sim = {c.pronunciation_similarity(c1, c2)}\n')
|
433
opencompass/datasets/lawbench/utils/compare_m2_for_evaluation.py
Normal file
433
opencompass/datasets/lawbench/utils/compare_m2_for_evaluation.py
Normal file
@ -0,0 +1,433 @@
|
|||||||
|
import argparse
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Parse command line args
|
||||||
|
args = parse_args()
|
||||||
|
# Open hypothesis and reference m2 files and split into chunks
|
||||||
|
hyp_m2 = open(args.hyp).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.hyp).read().strip().split("\n\n")
|
||||||
|
ref_m2 = open(args.ref).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.ref).read().strip().split("\n\n")
|
||||||
|
# Make sure they have the same number of sentences
|
||||||
|
assert len(hyp_m2) == len(ref_m2), print(len(hyp_m2), len(ref_m2))
|
||||||
|
|
||||||
|
# Store global corpus level best counts here
|
||||||
|
best_dict = Counter({"tp":0, "fp":0, "fn":0})
|
||||||
|
best_cats = {}
|
||||||
|
# Process each sentence
|
||||||
|
sents = zip(hyp_m2, ref_m2)
|
||||||
|
for sent_id, sent in enumerate(sents):
|
||||||
|
# Simplify the edits into lists of lists
|
||||||
|
# if "A1" in sent[0] or "A1" in sent[1] or sent_id in sent_id_cons:
|
||||||
|
# sent_id_cons.append(sent_id)
|
||||||
|
src = sent[0].split("\n")[0]
|
||||||
|
hyp_edits = simplify_edits(sent[0], args.max_answer_num)
|
||||||
|
ref_edits = simplify_edits(sent[1], args.max_answer_num)
|
||||||
|
# Process the edits for detection/correction based on args
|
||||||
|
hyp_dict = process_edits(hyp_edits, args)
|
||||||
|
ref_dict = process_edits(ref_edits, args)
|
||||||
|
if args.reference_num is None or len(ref_dict.keys()) == args.reference_num:
|
||||||
|
# Evaluate edits and get best TP, FP, FN hyp+ref combo.
|
||||||
|
count_dict, cat_dict = evaluate_edits(src,
|
||||||
|
hyp_dict, ref_dict, best_dict, sent_id, args)
|
||||||
|
# Merge these dicts with best_dict and best_cats
|
||||||
|
best_dict += Counter(count_dict)
|
||||||
|
best_cats = merge_dict(best_cats, cat_dict)
|
||||||
|
# Print results
|
||||||
|
print_results(best_dict, best_cats, args)
|
||||||
|
|
||||||
|
# Parse command line args
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Calculate F-scores for error detection and/or correction.\n"
|
||||||
|
"Flags let you evaluate at different levels of granularity.",
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
usage="%(prog)s [options] -hyp HYP -ref REF")
|
||||||
|
parser.add_argument(
|
||||||
|
"-hyp",
|
||||||
|
help="A hypothesis M2 file.",
|
||||||
|
required=True)
|
||||||
|
parser.add_argument(
|
||||||
|
"-ref",
|
||||||
|
help="A reference M2 file.",
|
||||||
|
required=True)
|
||||||
|
parser.add_argument(
|
||||||
|
"--start",
|
||||||
|
type=int,
|
||||||
|
default=None
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--end",
|
||||||
|
type=int,
|
||||||
|
default=None
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_answer_num",
|
||||||
|
type=int,
|
||||||
|
default=None
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--reference_num",
|
||||||
|
type=int,
|
||||||
|
default=None
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-b",
|
||||||
|
"--beta",
|
||||||
|
help="Value of beta in F-score. (default: 0.5)",
|
||||||
|
default=0.5,
|
||||||
|
type=float)
|
||||||
|
parser.add_argument(
|
||||||
|
"-v",
|
||||||
|
"--verbose",
|
||||||
|
help="Print verbose output.",
|
||||||
|
action="store_true")
|
||||||
|
eval_type = parser.add_mutually_exclusive_group()
|
||||||
|
eval_type.add_argument(
|
||||||
|
"-dt",
|
||||||
|
help="Evaluate Detection in terms of Tokens.",
|
||||||
|
action="store_true")
|
||||||
|
eval_type.add_argument(
|
||||||
|
"-ds",
|
||||||
|
help="Evaluate Detection in terms of Spans.",
|
||||||
|
action="store_true")
|
||||||
|
eval_type.add_argument(
|
||||||
|
"-cs",
|
||||||
|
help="Evaluate Correction in terms of Spans. (default)",
|
||||||
|
action="store_true")
|
||||||
|
eval_type.add_argument(
|
||||||
|
"-cse",
|
||||||
|
help="Evaluate Correction in terms of Spans and Error types.",
|
||||||
|
action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"-single",
|
||||||
|
help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1",
|
||||||
|
action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"-multi",
|
||||||
|
help="Only evaluate multi token edits; i.e. 2+:n or n:2+",
|
||||||
|
action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"-multi_hyp_avg",
|
||||||
|
help="When get multiple hypotheses for a sentence, calculate their average F-scores for this sentence.",
|
||||||
|
action="store_true") # For IAA calculation
|
||||||
|
parser.add_argument(
|
||||||
|
"-multi_hyp_max",
|
||||||
|
help="When get multiple hypotheses for a sentence, calculate their F-scores and select the max one for this sentence.",
|
||||||
|
action="store_true") # For multiple hypotheses system evaluation
|
||||||
|
parser.add_argument(
|
||||||
|
"-filt",
|
||||||
|
help="Do not evaluate the specified error types.",
|
||||||
|
nargs="+",
|
||||||
|
default=[])
|
||||||
|
parser.add_argument(
|
||||||
|
"-cat",
|
||||||
|
help="Show error category scores.\n"
|
||||||
|
"1: Only show operation tier scores; e.g. R.\n"
|
||||||
|
"2: Only show main tier scores; e.g. NOUN.\n"
|
||||||
|
"3: Show all category scores; e.g. R:NOUN.",
|
||||||
|
choices=[1, 2, 3],
|
||||||
|
type=int)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
# Input: An m2 format sentence with edits.
|
||||||
|
# Output: A list of lists. Each edit: [start, end, cat, cor, coder]
|
||||||
|
def simplify_edits(sent, max_answer_num):
|
||||||
|
out_edits = []
|
||||||
|
# Get the edit lines from an m2 block.
|
||||||
|
edits = sent.split("\n")
|
||||||
|
# Loop through the edits
|
||||||
|
for edit in edits:
|
||||||
|
# Preprocessing
|
||||||
|
if edit.startswith("A "):
|
||||||
|
edit = edit[2:].split("|||") # Ignore "A " then split.
|
||||||
|
span = edit[0].split()
|
||||||
|
start = int(span[0])
|
||||||
|
end = int(span[1])
|
||||||
|
cat = edit[1]
|
||||||
|
cor = edit[2].replace(" ", "")
|
||||||
|
coder = int(edit[-1])
|
||||||
|
out_edit = [start, end, cat, cor, coder]
|
||||||
|
out_edits.append(out_edit)
|
||||||
|
# return [edit for edit in out_edits if edit[-1] in [0,1]]
|
||||||
|
if max_answer_num is None:
|
||||||
|
return out_edits
|
||||||
|
elif max_answer_num == 1:
|
||||||
|
return [edit for edit in out_edits if edit[-1] == 0]
|
||||||
|
elif max_answer_num == 2:
|
||||||
|
return [edit for edit in out_edits if edit[-1] in [0, 1]]
|
||||||
|
elif max_answer_num == 3:
|
||||||
|
return [edit for edit in out_edits if edit[-1] in [0, 1, 2]]
|
||||||
|
|
||||||
|
# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
|
||||||
|
# Input 2: Command line args
|
||||||
|
# Output: A dict; key is coder, value is edit dict.
|
||||||
|
def process_edits(edits, args):
|
||||||
|
coder_dict = {}
|
||||||
|
# Add an explicit noop edit if there are no edits.
|
||||||
|
if not edits: edits = [[-1, -1, "noop", "-NONE-", 0]]
|
||||||
|
# Loop through the edits
|
||||||
|
for edit in edits:
|
||||||
|
# Name the edit elements for clarity
|
||||||
|
start = edit[0]
|
||||||
|
end = edit[1]
|
||||||
|
cat = edit[2]
|
||||||
|
cor = edit[3]
|
||||||
|
coder = edit[4]
|
||||||
|
# Add the coder to the coder_dict if necessary
|
||||||
|
if coder not in coder_dict: coder_dict[coder] = {}
|
||||||
|
|
||||||
|
# Optionally apply filters based on args
|
||||||
|
# 1. UNK type edits are only useful for detection, not correction.
|
||||||
|
if not args.dt and not args.ds and cat == "UNK": continue
|
||||||
|
# 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
|
||||||
|
if args.single and (end-start >= 2 or len(cor.split()) >= 2): continue
|
||||||
|
# 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
|
||||||
|
if args.multi and end-start < 2 and len(cor.split()) < 2: continue
|
||||||
|
# 4. If there is a filter, ignore the specified error types
|
||||||
|
if args.filt and cat in args.filt: continue
|
||||||
|
|
||||||
|
# Token Based Detection
|
||||||
|
if args.dt:
|
||||||
|
# Preserve noop edits.
|
||||||
|
if start == -1:
|
||||||
|
if (start, start) in coder_dict[coder].keys():
|
||||||
|
coder_dict[coder][(start, start)].append(cat)
|
||||||
|
else:
|
||||||
|
coder_dict[coder][(start, start)] = [cat]
|
||||||
|
# Insertions defined as affecting the token on the right
|
||||||
|
elif start == end and start >= 0:
|
||||||
|
if (start, start+1) in coder_dict[coder].keys():
|
||||||
|
coder_dict[coder][(start, start+1)].append(cat)
|
||||||
|
else:
|
||||||
|
coder_dict[coder][(start, start+1)] = [cat]
|
||||||
|
# Edit spans are split for each token in the range.
|
||||||
|
else:
|
||||||
|
for tok_id in range(start, end):
|
||||||
|
if (tok_id, tok_id+1) in coder_dict[coder].keys():
|
||||||
|
coder_dict[coder][(tok_id, tok_id+1)].append(cat)
|
||||||
|
else:
|
||||||
|
coder_dict[coder][(tok_id, tok_id+1)] = [cat]
|
||||||
|
|
||||||
|
# Span Based Detection
|
||||||
|
elif args.ds:
|
||||||
|
if (start, end) in coder_dict[coder].keys():
|
||||||
|
coder_dict[coder][(start, end)].append(cat)
|
||||||
|
else:
|
||||||
|
coder_dict[coder][(start, end)] = [cat]
|
||||||
|
|
||||||
|
# Span Based Correction
|
||||||
|
else:
|
||||||
|
# With error type classification
|
||||||
|
if args.cse:
|
||||||
|
if (start, end, cat, cor) in coder_dict[coder].keys():
|
||||||
|
coder_dict[coder][(start, end, cat, cor)].append(cat)
|
||||||
|
else:
|
||||||
|
coder_dict[coder][(start, end, cat, cor)] = [cat]
|
||||||
|
# Without error type classification
|
||||||
|
else:
|
||||||
|
if (start, end, cor) in coder_dict[coder].keys():
|
||||||
|
coder_dict[coder][(start, end, cor)].append(cat)
|
||||||
|
else:
|
||||||
|
coder_dict[coder][(start, end, cor)] = [cat]
|
||||||
|
return coder_dict
|
||||||
|
|
||||||
|
# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
|
||||||
|
# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
|
||||||
|
# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
|
||||||
|
# Input 4: Sentence ID (for verbose output only)
|
||||||
|
# Input 5: Command line args
|
||||||
|
# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
|
||||||
|
# Output 2: The corresponding error type dict for the above dict.
|
||||||
|
def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args):
|
||||||
|
# Store the best sentence level scores and hyp+ref combination IDs
|
||||||
|
# best_f is initialised as -1 cause 0 is a valid result.
|
||||||
|
best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
|
||||||
|
best_cat = {}
|
||||||
|
# skip not annotatable sentence
|
||||||
|
if len(ref_dict.keys()) == 1:
|
||||||
|
ref_id = list(ref_dict.keys())[0]
|
||||||
|
if len(ref_dict[ref_id].keys()) == 1:
|
||||||
|
cat = list(ref_dict[ref_id].values())[0][0]
|
||||||
|
if cat == "NA":
|
||||||
|
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
|
||||||
|
return best_dict, best_cat
|
||||||
|
|
||||||
|
# Compare each hyp and ref combination
|
||||||
|
for hyp_id in hyp_dict.keys():
|
||||||
|
for ref_id in ref_dict.keys():
|
||||||
|
# Get the local counts for the current combination.
|
||||||
|
tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
|
||||||
|
# Compute the local sentence scores (for verbose output only)
|
||||||
|
loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
|
||||||
|
# Compute the global sentence scores
|
||||||
|
p, r, f = computeFScore(
|
||||||
|
tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
|
||||||
|
# Save the scores if they are better in terms of:
|
||||||
|
# 1. Higher F-score
|
||||||
|
# 2. Same F-score, higher TP
|
||||||
|
# 3. Same F-score and TP, lower FP
|
||||||
|
# 4. Same F-score, TP and FP, lower FN
|
||||||
|
if (f > best_f) or \
|
||||||
|
(f == best_f and tp > best_tp) or \
|
||||||
|
(f == best_f and tp == best_tp and fp < best_fp) or \
|
||||||
|
(f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
|
||||||
|
best_tp, best_fp, best_fn = tp, fp, fn
|
||||||
|
best_f, best_hyp, best_ref = f, hyp_id, ref_id
|
||||||
|
best_cat = cat_dict
|
||||||
|
# Verbose output
|
||||||
|
if args.verbose:
|
||||||
|
# Prepare verbose output edits.
|
||||||
|
hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
|
||||||
|
ref_verb = list(sorted(ref_dict[ref_id].keys()))
|
||||||
|
# Ignore noop edits
|
||||||
|
if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
|
||||||
|
if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
|
||||||
|
# Print verbose info
|
||||||
|
print('{:-^40}'.format(""))
|
||||||
|
print("SENTENCE "+str(sent_id)+src[1:])
|
||||||
|
print('{:-^40}'.format(""))
|
||||||
|
print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id))
|
||||||
|
print("HYPOTHESIS EDITS :", hyp_verb)
|
||||||
|
print("REFERENCE EDITS :", ref_verb)
|
||||||
|
print("Local TP/FP/FN :", str(tp), str(fp), str(fn))
|
||||||
|
print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f))
|
||||||
|
print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"]))
|
||||||
|
print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f))
|
||||||
|
# Verbose output: display the best hyp+ref combination
|
||||||
|
if args.verbose:
|
||||||
|
print('{:-^40}'.format(""))
|
||||||
|
print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id))
|
||||||
|
# Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
|
||||||
|
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
|
||||||
|
return best_dict, best_cat
|
||||||
|
|
||||||
|
# Input 1: A dictionary of hypothesis edits for a single system.
|
||||||
|
# Input 2: A dictionary of reference edits for a single annotator.
|
||||||
|
# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
|
||||||
|
# Output 4: A dictionary of the error type counts.
|
||||||
|
def compareEdits(hyp_edits, ref_edits):
|
||||||
|
tp = 0 # True Positives
|
||||||
|
fp = 0 # False Positives
|
||||||
|
fn = 0 # False Negatives
|
||||||
|
cat_dict = {} # {cat: [tp, fp, fn], ...}
|
||||||
|
|
||||||
|
for h_edit, h_cats in hyp_edits.items():
|
||||||
|
# noop hyp edits cannot be TP or FP
|
||||||
|
if h_cats[0] == "noop": continue
|
||||||
|
# TRUE POSITIVES
|
||||||
|
if h_edit in ref_edits.keys():
|
||||||
|
# On occasion, multiple tokens at same span.
|
||||||
|
for h_cat in ref_edits[h_edit]: # Use ref dict for TP
|
||||||
|
tp += 1
|
||||||
|
# Each dict value [TP, FP, FN]
|
||||||
|
if h_cat in cat_dict.keys():
|
||||||
|
cat_dict[h_cat][0] += 1
|
||||||
|
else:
|
||||||
|
cat_dict[h_cat] = [1, 0, 0]
|
||||||
|
# FALSE POSITIVES
|
||||||
|
else:
|
||||||
|
# On occasion, multiple tokens at same span.
|
||||||
|
for h_cat in h_cats:
|
||||||
|
fp += 1
|
||||||
|
# Each dict value [TP, FP, FN]
|
||||||
|
if h_cat in cat_dict.keys():
|
||||||
|
cat_dict[h_cat][1] += 1
|
||||||
|
else:
|
||||||
|
cat_dict[h_cat] = [0, 1, 0]
|
||||||
|
for r_edit, r_cats in ref_edits.items():
|
||||||
|
# noop ref edits cannot be FN
|
||||||
|
if r_cats[0] == "noop": continue
|
||||||
|
# FALSE NEGATIVES
|
||||||
|
if r_edit not in hyp_edits.keys():
|
||||||
|
# On occasion, multiple tokens at same span.
|
||||||
|
for r_cat in r_cats:
|
||||||
|
fn += 1
|
||||||
|
# Each dict value [TP, FP, FN]
|
||||||
|
if r_cat in cat_dict.keys():
|
||||||
|
cat_dict[r_cat][2] += 1
|
||||||
|
else:
|
||||||
|
cat_dict[r_cat] = [0, 0, 1]
|
||||||
|
return tp, fp, fn, cat_dict
|
||||||
|
|
||||||
|
# Input 1-3: True positives, false positives, false negatives
|
||||||
|
# Input 4: Value of beta in F-score.
|
||||||
|
# Output 1-3: Precision, Recall and F-score rounded to 4dp.
|
||||||
|
def computeFScore(tp, fp, fn, beta):
|
||||||
|
p = float(tp)/(tp+fp) if fp else 1.0
|
||||||
|
r = float(tp)/(tp+fn) if fn else 1.0
|
||||||
|
f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0
|
||||||
|
return round(p, 4), round(r, 4), round(f, 4)
|
||||||
|
|
||||||
|
# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
|
||||||
|
# Output: The dictionaries combined with cumulative TP, FP, FN.
|
||||||
|
def merge_dict(dict1, dict2):
|
||||||
|
for cat, stats in dict2.items():
|
||||||
|
if cat in dict1.keys():
|
||||||
|
dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)]
|
||||||
|
else:
|
||||||
|
dict1[cat] = stats
|
||||||
|
return dict1
|
||||||
|
|
||||||
|
# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn]
|
||||||
|
# Input 2: Integer value denoting level of error category granularity.
|
||||||
|
# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything.
|
||||||
|
# Output: A dictionary of category TP, FP and FN based on Input 2.
|
||||||
|
def processCategories(cat_dict, setting):
|
||||||
|
# Otherwise, do some processing.
|
||||||
|
proc_cat_dict = {}
|
||||||
|
for cat, cnt in cat_dict.items():
|
||||||
|
if cat == "UNK":
|
||||||
|
proc_cat_dict[cat] = cnt
|
||||||
|
continue
|
||||||
|
# M, U, R or UNK combined only.
|
||||||
|
if setting == 1:
|
||||||
|
if cat[0] in proc_cat_dict.keys():
|
||||||
|
proc_cat_dict[cat[0]] = [x+y for x, y in zip(proc_cat_dict[cat[0]], cnt)]
|
||||||
|
else:
|
||||||
|
proc_cat_dict[cat[0]] = cnt
|
||||||
|
# Everything without M, U or R.
|
||||||
|
elif setting == 2:
|
||||||
|
if cat[2:] in proc_cat_dict.keys():
|
||||||
|
proc_cat_dict[cat[2:]] = [x+y for x, y in zip(proc_cat_dict[cat[2:]], cnt)]
|
||||||
|
else:
|
||||||
|
proc_cat_dict[cat[2:]] = cnt
|
||||||
|
# All error category combinations
|
||||||
|
else:
|
||||||
|
return cat_dict
|
||||||
|
return proc_cat_dict
|
||||||
|
|
||||||
|
# Input 1: A dict of global best TP, FP and FNs
|
||||||
|
# Input 2: A dict of error types and counts for those TP, FP and FNs
|
||||||
|
# Input 3: Command line args
|
||||||
|
def print_results(best, best_cats, args):
|
||||||
|
# Prepare output title.
|
||||||
|
if args.dt: title = " Token-Based Detection "
|
||||||
|
elif args.ds: title = " Span-Based Detection "
|
||||||
|
elif args.cse: title = " Span-Based Correction + Classification "
|
||||||
|
else: title = " Span-Based Correction "
|
||||||
|
|
||||||
|
# Category Scores
|
||||||
|
if args.cat:
|
||||||
|
best_cats = processCategories(best_cats, args.cat)
|
||||||
|
print("")
|
||||||
|
print('{:=^66}'.format(title))
|
||||||
|
print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
|
||||||
|
"P".ljust(8), "R".ljust(8), "F"+str(args.beta))
|
||||||
|
for cat, cnts in sorted(best_cats.items()):
|
||||||
|
cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
|
||||||
|
print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
|
||||||
|
str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)
|
||||||
|
|
||||||
|
# Print the overall results.
|
||||||
|
print("")
|
||||||
|
print('{:=^46}'.format(title))
|
||||||
|
print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
|
||||||
|
print("\t".join(map(str, [best["tp"], best["fp"],
|
||||||
|
best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
|
||||||
|
print('{:=^46}'.format(""))
|
||||||
|
print("")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Run the program
|
||||||
|
main()
|
82
opencompass/datasets/lawbench/utils/comprehension_scores.py
Normal file
82
opencompass/datasets/lawbench/utils/comprehension_scores.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import re
|
||||||
|
from ..utils.rc_f1 import CJRCEvaluator
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
given a target substring. find its all occurances in the string s
|
||||||
|
return the starting and ending index of every occurance
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def __find_substring_starts(s, target):
|
||||||
|
return [(m.start(), m.end()) for m in re.finditer(target, s)]
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
compute the reading comprehension F1 scores
|
||||||
|
hyps and refs are lists of hyposisis and reference strings
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def compute_rc_f1(hyps, refs):
|
||||||
|
scores = 0
|
||||||
|
for h, r in zip(hyps, refs):
|
||||||
|
scores += CJRCEvaluator.compute_f1(r, h)
|
||||||
|
return {'score': scores / len(hyps)}
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
compute the information extraction F1 scores
|
||||||
|
hyps and refs are lists of hyposisis and reference strings
|
||||||
|
entity_types: a set of all possible entity types
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def compute_ie_f1(hyps, refs, entity_types):
|
||||||
|
assert (len(hyps) == len(refs))
|
||||||
|
scores, abstentions = 0, 0
|
||||||
|
for h, r in zip(hyps, refs):
|
||||||
|
h = __extract_entities_pred(h, entity_types)
|
||||||
|
r = __extract_entities_ref(r)
|
||||||
|
if r == {}:
|
||||||
|
scores += 1 if h == {} else 0
|
||||||
|
continue
|
||||||
|
if h == {}:
|
||||||
|
abstentions += 1
|
||||||
|
intersected = [CJRCEvaluator.compute_f1(r[etype], einstance) for etype, einstance in h.items() if etype in r]
|
||||||
|
prec = sum(intersected) / len(h) if len(h) > 0 else 0
|
||||||
|
rec = sum(intersected) / len(r) if len(r) > 0 else 0
|
||||||
|
# print(prec, rec, intersected)
|
||||||
|
scores += 2 * prec * rec / (prec + rec + 1e-10)
|
||||||
|
return {'score': scores / len(hyps), "anstention_rate": abstentions / len(hyps)}
|
||||||
|
|
||||||
|
|
||||||
|
def __extract_entities_ref(ref):
|
||||||
|
outputs = {}
|
||||||
|
if ref.strip() == '':
|
||||||
|
return outputs
|
||||||
|
for seg in ref.split(';'):
|
||||||
|
seg = seg.split(':')
|
||||||
|
outputs[seg[0]] = seg[1]
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
extract entity type and instances from the model prediction
|
||||||
|
pred: string of model prediction
|
||||||
|
entity_types: a set of all possible entity types
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def __extract_entities_pred(pred, entity_types):
|
||||||
|
outputs = {}
|
||||||
|
for etype in entity_types:
|
||||||
|
occurances = __find_substring_starts(pred, etype)
|
||||||
|
for start, end in occurances:
|
||||||
|
if end >= (len(pred) - 2):
|
||||||
|
continue
|
||||||
|
if pred[end] == ":" or pred[end] == ":":
|
||||||
|
einstance = re.split("\n| ", pred[end + 1:].strip())[0].strip()
|
||||||
|
if einstance != '无' and einstance != '未提及':
|
||||||
|
outputs[etype] = einstance
|
||||||
|
return outputs
|
49
opencompass/datasets/lawbench/utils/function_utils.py
Normal file
49
opencompass/datasets/lawbench/utils/function_utils.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from rouge_chinese import Rouge
|
||||||
|
import jieba
|
||||||
|
from nltk.translate.gleu_score import corpus_gleu
|
||||||
|
|
||||||
|
def compute_f1_two_sets(pred_set, gt_set):
|
||||||
|
precision = len(pred_set.intersection(gt_set)) / len(pred_set) if len(pred_set) > 0 else 0
|
||||||
|
recall = len(pred_set.intersection(gt_set)) / len(gt_set) if len(gt_set) > 0 else 0
|
||||||
|
f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
|
||||||
|
return f1
|
||||||
|
|
||||||
|
def multi_choice_judge(prediction, option_list, answer_token):
|
||||||
|
# a dict, key: letters in the option list, value: count of the letter in the prediction
|
||||||
|
count_dict, abstention, accuracy = {}, 0, 0
|
||||||
|
for option in option_list:
|
||||||
|
option_count = prediction.count(option)
|
||||||
|
count_dict[option] = 1 if option_count > 0 else 0 # multiple occurrence of the same letter is counted as 1
|
||||||
|
|
||||||
|
if sum(count_dict.values()) == 0:
|
||||||
|
abstention = 1
|
||||||
|
# if the answer token is the only predicted token, the prediction is correct
|
||||||
|
elif count_dict[answer_token] == 1 and sum(count_dict.values()) == 1:
|
||||||
|
accuracy = 1
|
||||||
|
return {"score": accuracy, "abstention": abstention}
|
||||||
|
|
||||||
|
"""
|
||||||
|
compute the rouge score.
|
||||||
|
hyps and refs are lists of hyposisis and reference strings
|
||||||
|
empty predictions are replaces with 无内容
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def compute_rouge(hyps, refs):
|
||||||
|
assert(len(hyps) == len(refs))
|
||||||
|
hyps = [' '.join(jieba.cut(h)) for h in hyps]
|
||||||
|
hyps = [h if h.strip() != "" else "无内容" for h in hyps]
|
||||||
|
refs = [' '.join(jieba.cut(r)) for r in refs]
|
||||||
|
return Rouge().get_scores(hyps, refs)
|
||||||
|
|
||||||
|
"""
|
||||||
|
compute the gleu score.
|
||||||
|
hyps and refs are lists of hyposisis and reference strings
|
||||||
|
empty predictions are replaces with 无内容
|
||||||
|
"""
|
||||||
|
def compute_gleu(hyps, refs):
|
||||||
|
assert(len(hyps) == len(refs))
|
||||||
|
hyps = [' '.join(jieba.cut(h)) for h in hyps]
|
||||||
|
hyps = [h if h.strip() != "" else "无内容" for h in hyps]
|
||||||
|
refs = [[' '.join(jieba.cut(r))] for r in refs]
|
||||||
|
return corpus_gleu(refs, hyps)
|
334
opencompass/datasets/lawbench/utils/modules/alignment.py
Normal file
334
opencompass/datasets/lawbench/utils/modules/alignment.py
Normal file
@ -0,0 +1,334 @@
|
|||||||
|
import numpy as np
|
||||||
|
from typing import List, Tuple, Dict
|
||||||
|
from modules.tokenizer import Tokenizer
|
||||||
|
import os
|
||||||
|
from string import punctuation
|
||||||
|
|
||||||
|
REAL_PATH = os.path.split(os.path.realpath(__file__))[0]
|
||||||
|
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏"
|
||||||
|
english_punct = punctuation
|
||||||
|
punct = chinese_punct + english_punct
|
||||||
|
|
||||||
|
def check_all_chinese(word):
|
||||||
|
"""
|
||||||
|
判断一个单词是否全部由中文组成
|
||||||
|
:param word:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return all(['\u4e00' <= ch <= '\u9fff' for ch in word])
|
||||||
|
|
||||||
|
def read_cilin():
|
||||||
|
"""
|
||||||
|
Cilin 詞林 is a thesaurus with semantic information
|
||||||
|
"""
|
||||||
|
# TODO -- fix this path
|
||||||
|
project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path
|
||||||
|
lines = open(os.path.join(project_dir, "data", "cilin.txt"), "r", encoding="gbk").read().strip().split("\n")
|
||||||
|
semantic_dict = {}
|
||||||
|
semantic_classes = {}
|
||||||
|
for line in lines:
|
||||||
|
code, *words = line.split(" ")
|
||||||
|
for word in words:
|
||||||
|
semantic_dict[word] = code
|
||||||
|
# make reverse dict
|
||||||
|
if code in semantic_classes:
|
||||||
|
semantic_classes[code] += words
|
||||||
|
else:
|
||||||
|
semantic_classes[code] = words
|
||||||
|
return semantic_dict, semantic_classes
|
||||||
|
|
||||||
|
|
||||||
|
def read_confusion():
|
||||||
|
confusion_dict = {}
|
||||||
|
project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path
|
||||||
|
with open(os.path.join(project_dir, "data", "confusion_dict.txt"), "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
li = line.rstrip('\n').split(" ")
|
||||||
|
confusion_dict[li[0]] = li[1:]
|
||||||
|
return confusion_dict
|
||||||
|
|
||||||
|
class Alignment:
|
||||||
|
"""
|
||||||
|
对齐错误句子和正确句子,
|
||||||
|
使用编辑距离算法抽取编辑操作
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
semantic_dict: Dict,
|
||||||
|
confusion_dict: Dict,
|
||||||
|
granularity: str = "word",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
构造函数
|
||||||
|
:param semantic_dict: 语义词典(大词林)
|
||||||
|
:param confusion_dict: 字符混淆集
|
||||||
|
"""
|
||||||
|
self.insertion_cost = 1
|
||||||
|
self.deletion_cost = 1
|
||||||
|
self.semantic_dict = semantic_dict
|
||||||
|
self.confusion_dict = confusion_dict
|
||||||
|
# Because we use character level tokenization, this doesn't currently use POS
|
||||||
|
self._open_pos = {} # 如果是词级别,还可以利用词性是否相同来计算cost
|
||||||
|
self.granularity = granularity # word-level or character-level
|
||||||
|
self.align_seqs = []
|
||||||
|
|
||||||
|
def __call__(self,
|
||||||
|
src: List[Tuple],
|
||||||
|
tgt: List[Tuple],
|
||||||
|
verbose: bool = False):
|
||||||
|
cost_matrix, oper_matrix = self.align(src, tgt)
|
||||||
|
align_seq = self.get_cheapest_align_seq(oper_matrix)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print("========== Seg. and POS: ==========")
|
||||||
|
print(src)
|
||||||
|
print(tgt)
|
||||||
|
print("========== Cost Matrix ==========")
|
||||||
|
print(cost_matrix)
|
||||||
|
print("========== Oper Matrix ==========")
|
||||||
|
print(oper_matrix)
|
||||||
|
print("========== Alignment ==========")
|
||||||
|
print(align_seq)
|
||||||
|
print("========== Results ==========")
|
||||||
|
for a in align_seq:
|
||||||
|
print(a[0], src[a[1]: a[2]], tgt[a[3]: a[4]])
|
||||||
|
return align_seq
|
||||||
|
|
||||||
|
def _get_semantic_class(self, word):
|
||||||
|
"""
|
||||||
|
NOTE: Based on the paper:
|
||||||
|
Improved-Edit-Distance Kernel for Chinese Relation Extraction
|
||||||
|
获取每个词语的语义类别(基于大词林,有三个级别)
|
||||||
|
"""
|
||||||
|
if word in self.semantic_dict:
|
||||||
|
code = self.semantic_dict[word]
|
||||||
|
high, mid, low = code[0], code[1], code[2:4]
|
||||||
|
return high, mid, low
|
||||||
|
else: # unknown
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_class_diff(a_class, b_class):
|
||||||
|
"""
|
||||||
|
d == 3 for equivalent semantics
|
||||||
|
d == 0 for completely different semantics
|
||||||
|
根据大词林的信息,计算两个词的语义类别的差距
|
||||||
|
"""
|
||||||
|
d = sum([a == b for a, b in zip(a_class, b_class)])
|
||||||
|
return d
|
||||||
|
|
||||||
|
def _get_semantic_cost(self, a, b):
|
||||||
|
"""
|
||||||
|
计算基于语义信息的替换操作cost
|
||||||
|
:param a: 单词a的语义类别
|
||||||
|
:param b: 单词b的语义类别
|
||||||
|
:return: 替换编辑代价
|
||||||
|
"""
|
||||||
|
a_class = self._get_semantic_class(a)
|
||||||
|
b_class = self._get_semantic_class(b)
|
||||||
|
# unknown class, default to 1
|
||||||
|
if a_class is None or b_class is None:
|
||||||
|
return 4
|
||||||
|
elif a_class == b_class:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return 2 * (3 - self._get_class_diff(a_class, b_class))
|
||||||
|
|
||||||
|
def _get_pos_cost(self, a_pos, b_pos):
|
||||||
|
"""
|
||||||
|
计算基于词性信息的编辑距离cost
|
||||||
|
:param a_pos: 单词a的词性
|
||||||
|
:param b_pos: 单词b的词性
|
||||||
|
:return: 替换编辑代价
|
||||||
|
"""
|
||||||
|
if a_pos == b_pos:
|
||||||
|
return 0
|
||||||
|
elif a_pos in self._open_pos and b_pos in self._open_pos:
|
||||||
|
return 0.25
|
||||||
|
else:
|
||||||
|
return 0.499
|
||||||
|
|
||||||
|
def _get_char_cost(self, a, b, pinyin_a, pinyin_b):
|
||||||
|
"""
|
||||||
|
NOTE: This is a replacement of ERRANTS lemma cost for Chinese
|
||||||
|
计算基于字符相似度的编辑距离cost
|
||||||
|
"""
|
||||||
|
if not (check_all_chinese(a) and check_all_chinese(b)):
|
||||||
|
return 0.5
|
||||||
|
if len(a) > len(b):
|
||||||
|
a, b = b, a
|
||||||
|
pinyin_a, pinyin_b = pinyin_b, pinyin_a
|
||||||
|
if a == b:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return self._get_spell_cost(a, b, pinyin_a, pinyin_b)
|
||||||
|
|
||||||
|
def _get_spell_cost(self, a, b, pinyin_a, pinyin_b):
|
||||||
|
"""
|
||||||
|
计算两个单词拼写相似度,分别由字形相似度和字音相似度组成
|
||||||
|
:param a: 单词a
|
||||||
|
:param b: 单词b,且单词a的长度小于等于b
|
||||||
|
:param pinyin_a: 单词a的拼音
|
||||||
|
:param pinyin_b: 单词b的拼音
|
||||||
|
:return: 替换操作cost
|
||||||
|
"""
|
||||||
|
count = 0
|
||||||
|
for i in range(len(a)):
|
||||||
|
for j in range(len(b)):
|
||||||
|
if a[i] == b[j] or (set(pinyin_a) & set(pinyin_b)) or (b[j] in self.confusion_dict.keys() and a[i] in self.confusion_dict[b[j]]) or (a[i] in self.confusion_dict.keys() and b[j] in self.confusion_dict[a[i]]):
|
||||||
|
count += 1
|
||||||
|
break
|
||||||
|
return (len(a) - count) / (len(a) * 2)
|
||||||
|
|
||||||
|
def get_sub_cost(self, a_seg, b_seg):
|
||||||
|
"""
|
||||||
|
Calculate the substitution cost between words a and b
|
||||||
|
计算两个单词替换操作的编辑cost,最大为2,等于一次删除和一次添加
|
||||||
|
"""
|
||||||
|
if a_seg[0] == b_seg[0]:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if self.granularity == "word": # 词级别可以额外利用词性信息
|
||||||
|
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
|
||||||
|
pos_cost = self._get_pos_cost(a_seg[1], b_seg[1])
|
||||||
|
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
|
||||||
|
return semantic_cost + pos_cost + char_cost
|
||||||
|
else: # 字级别只能利用字义信息(从大词林中获取)和字面相似度信息
|
||||||
|
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
|
||||||
|
if a_seg[0] in punct and b_seg[0] in punct:
|
||||||
|
pos_cost = 0.0
|
||||||
|
elif a_seg[0] not in punct and b_seg[0] not in punct:
|
||||||
|
pos_cost = 0.25
|
||||||
|
else:
|
||||||
|
pos_cost = 0.499
|
||||||
|
# pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5
|
||||||
|
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
|
||||||
|
return semantic_cost + char_cost + pos_cost
|
||||||
|
|
||||||
|
def align(self,
|
||||||
|
src: List[Tuple],
|
||||||
|
tgt: List[Tuple]):
|
||||||
|
"""
|
||||||
|
Based on ERRANT's alignment
|
||||||
|
基于改进的动态规划算法,为原句子的每个字打上编辑标签,以便使它能够成功转换为目标句子。
|
||||||
|
编辑操作类别:
|
||||||
|
1) M:Match,即KEEP,即当前字保持不变
|
||||||
|
2) D:Delete,删除,即当前字需要被删除
|
||||||
|
3) I:Insert,插入,即当前字需要被插入
|
||||||
|
4) T:Transposition,移位操作,即涉及到词序问题
|
||||||
|
"""
|
||||||
|
cost_matrix = np.zeros((len(src) + 1, len(tgt) + 1)) # 编辑cost矩阵
|
||||||
|
oper_matrix = np.full(
|
||||||
|
(len(src) + 1, len(tgt) + 1), "O", dtype=object
|
||||||
|
) # 操作矩阵
|
||||||
|
# Fill in the edges
|
||||||
|
for i in range(1, len(src) + 1):
|
||||||
|
cost_matrix[i][0] = cost_matrix[i - 1][0] + 1
|
||||||
|
oper_matrix[i][0] = ["D"]
|
||||||
|
for j in range(1, len(tgt) + 1):
|
||||||
|
cost_matrix[0][j] = cost_matrix[0][j - 1] + 1
|
||||||
|
oper_matrix[0][j] = ["I"]
|
||||||
|
|
||||||
|
# Loop through the cost matrix
|
||||||
|
for i in range(len(src)):
|
||||||
|
for j in range(len(tgt)):
|
||||||
|
# Matches
|
||||||
|
if src[i][0] == tgt[j][0]: # 如果两个字相等,则匹配成功(Match),编辑距离为0
|
||||||
|
cost_matrix[i + 1][j + 1] = cost_matrix[i][j]
|
||||||
|
oper_matrix[i + 1][j + 1] = ["M"]
|
||||||
|
# Non-matches
|
||||||
|
else:
|
||||||
|
del_cost = cost_matrix[i][j + 1] + self.deletion_cost # 由删除动作得到的总cost
|
||||||
|
ins_cost = cost_matrix[i + 1][j] + self.insertion_cost # 由插入动作得到的总cost
|
||||||
|
sub_cost = cost_matrix[i][j] + self.get_sub_cost(
|
||||||
|
src[i], tgt[j]
|
||||||
|
) # 由替换动作得到的总cost
|
||||||
|
# Calculate transposition cost
|
||||||
|
# 计算移位操作的总cost
|
||||||
|
trans_cost = float("inf")
|
||||||
|
k = 1
|
||||||
|
while (
|
||||||
|
i - k >= 0
|
||||||
|
and j - k >= 0
|
||||||
|
and cost_matrix[i - k + 1][j - k + 1]
|
||||||
|
!= cost_matrix[i - k][j - k]
|
||||||
|
):
|
||||||
|
p1 = sorted([a[0] for a in src][i - k: i + 1])
|
||||||
|
p2 = sorted([b[0] for b in tgt][j - k: j + 1])
|
||||||
|
if p1 == p2:
|
||||||
|
trans_cost = cost_matrix[i - k][j - k] + k
|
||||||
|
break
|
||||||
|
k += 1
|
||||||
|
|
||||||
|
costs = [trans_cost, sub_cost, ins_cost, del_cost]
|
||||||
|
ind = costs.index(min(costs))
|
||||||
|
cost_matrix[i + 1][j + 1] = costs[ind]
|
||||||
|
# ind = costs.index(costs[ind], ind+1)
|
||||||
|
for idx, cost in enumerate(costs):
|
||||||
|
if cost == costs[ind]:
|
||||||
|
if idx == 0:
|
||||||
|
if oper_matrix[i + 1][j + 1] == "O":
|
||||||
|
oper_matrix[i + 1][j + 1] = ["T" + str(k + 1)]
|
||||||
|
else:
|
||||||
|
oper_matrix[i + 1][j + 1].append("T" + str(k + 1))
|
||||||
|
elif idx == 1:
|
||||||
|
if oper_matrix[i + 1][j + 1] == "O":
|
||||||
|
oper_matrix[i + 1][j + 1] = ["S"]
|
||||||
|
else:
|
||||||
|
oper_matrix[i + 1][j + 1].append("S")
|
||||||
|
elif idx == 2:
|
||||||
|
if oper_matrix[i + 1][j + 1] == "O":
|
||||||
|
oper_matrix[i + 1][j + 1] = ["I"]
|
||||||
|
else:
|
||||||
|
oper_matrix[i + 1][j + 1].append("I")
|
||||||
|
else:
|
||||||
|
if oper_matrix[i + 1][j + 1] == "O":
|
||||||
|
oper_matrix[i + 1][j + 1] = ["D"]
|
||||||
|
else:
|
||||||
|
oper_matrix[i + 1][j + 1].append("D")
|
||||||
|
return cost_matrix, oper_matrix
|
||||||
|
|
||||||
|
def _dfs(self, i, j, align_seq_now, oper_matrix, strategy="all"):
|
||||||
|
"""
|
||||||
|
深度优先遍历,获取最小编辑距离相同的所有序列
|
||||||
|
"""
|
||||||
|
if i + j == 0:
|
||||||
|
self.align_seqs.append(align_seq_now)
|
||||||
|
else:
|
||||||
|
ops = oper_matrix[i][j] # 可以类比成搜索一棵树从根结点到叶子结点的所有路径
|
||||||
|
if strategy != "all": ops = ops[:1]
|
||||||
|
for op in ops:
|
||||||
|
if op in {"M", "S"}:
|
||||||
|
self._dfs(i - 1, j - 1, align_seq_now + [(op, i - 1, i, j - 1, j)], oper_matrix, strategy)
|
||||||
|
elif op == "D":
|
||||||
|
self._dfs(i - 1, j, align_seq_now + [(op, i - 1, i, j, j)], oper_matrix, strategy)
|
||||||
|
elif op == "I":
|
||||||
|
self._dfs(i, j - 1, align_seq_now + [(op, i, i, j - 1, j)], oper_matrix, strategy)
|
||||||
|
else:
|
||||||
|
k = int(op[1:])
|
||||||
|
self._dfs(i - k, j - k, align_seq_now + [(op, i - k, i, j - k, j)], oper_matrix, strategy)
|
||||||
|
|
||||||
|
def get_cheapest_align_seq(self, oper_matrix):
|
||||||
|
"""
|
||||||
|
回溯获得编辑距离最小的编辑序列
|
||||||
|
"""
|
||||||
|
self.align_seqs = []
|
||||||
|
i = oper_matrix.shape[0] - 1
|
||||||
|
j = oper_matrix.shape[1] - 1
|
||||||
|
if abs(i - j) > 10:
|
||||||
|
self._dfs(i, j , [], oper_matrix, "first")
|
||||||
|
else:
|
||||||
|
self._dfs(i, j , [], oper_matrix, "all")
|
||||||
|
final_align_seqs = [seq[::-1] for seq in self.align_seqs]
|
||||||
|
return final_align_seqs
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tokenizer = Tokenizer("word")
|
||||||
|
semantic_dict, semantic_class = read_cilin()
|
||||||
|
confusion_dict = read_confusion()
|
||||||
|
alignment = Alignment(semantic_dict, confusion_dict)
|
||||||
|
sents = ["首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。".replace(" ", ""), "首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。".replace(" ", "")]
|
||||||
|
src, tgt = tokenizer(sents)
|
||||||
|
alignment(src, tgt, verbose=True)
|
76
opencompass/datasets/lawbench/utils/modules/annotator.py
Normal file
76
opencompass/datasets/lawbench/utils/modules/annotator.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
from typing import List, Tuple
|
||||||
|
from modules.alignment import read_cilin, read_confusion, Alignment
|
||||||
|
from modules.merger import Merger
|
||||||
|
from modules.classifier import Classifier
|
||||||
|
|
||||||
|
class Annotator:
|
||||||
|
def __init__(self,
|
||||||
|
align: Alignment,
|
||||||
|
merger: Merger,
|
||||||
|
classifier: Classifier,
|
||||||
|
granularity: str = "word",
|
||||||
|
strategy: str = "first"):
|
||||||
|
self.align = align
|
||||||
|
self.merger = merger
|
||||||
|
self.classifier = classifier
|
||||||
|
self.granularity = granularity
|
||||||
|
self.strategy = strategy
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_default(cls, granularity: str = "word", strategy: str = "first"):
|
||||||
|
"""
|
||||||
|
Default parameters used in the paper
|
||||||
|
"""
|
||||||
|
semantic_dict, semantic_class = read_cilin()
|
||||||
|
confusion_dict = read_confusion()
|
||||||
|
align = Alignment(semantic_dict, confusion_dict, granularity)
|
||||||
|
merger = Merger(granularity)
|
||||||
|
classifier = Classifier(granularity)
|
||||||
|
return cls(align, merger, classifier, granularity, strategy)
|
||||||
|
|
||||||
|
def __call__(self,
|
||||||
|
src: List[Tuple],
|
||||||
|
tgt: List[Tuple],
|
||||||
|
annotator_id: int = 0,
|
||||||
|
verbose: bool = False):
|
||||||
|
"""
|
||||||
|
Align sentences and annotate them with error type information
|
||||||
|
"""
|
||||||
|
src_tokens = [x[0] for x in src]
|
||||||
|
tgt_tokens = [x[0] for x in tgt]
|
||||||
|
src_str = "".join(src_tokens)
|
||||||
|
tgt_str = "".join(tgt_tokens)
|
||||||
|
# convert to text form
|
||||||
|
annotations_out = ["S " + " ".join(src_tokens) + "\n"]
|
||||||
|
if tgt_str == "没有错误" or src_str == tgt_str: # Error Free Case
|
||||||
|
annotations_out.append(f"T{annotator_id} 没有错误\n")
|
||||||
|
cors = [tgt_str]
|
||||||
|
op, toks, inds = "noop", "-NONE-", (-1, -1)
|
||||||
|
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
|
||||||
|
annotations_out.append(a_str)
|
||||||
|
elif tgt_str == "无法标注": # Not Annotatable Case
|
||||||
|
annotations_out.append(f"T{annotator_id} 无法标注\n")
|
||||||
|
cors = [tgt_str]
|
||||||
|
op, toks, inds = "NA", "-NONE-", (-1, -1)
|
||||||
|
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
|
||||||
|
annotations_out.append(a_str)
|
||||||
|
else: # Other
|
||||||
|
align_objs = self.align(src, tgt)
|
||||||
|
edit_objs = []
|
||||||
|
align_idx = 0
|
||||||
|
if self.strategy == "first":
|
||||||
|
align_objs = align_objs[:1]
|
||||||
|
for align_obj in align_objs:
|
||||||
|
edits = self.merger(align_obj, src, tgt, verbose)
|
||||||
|
if edits not in edit_objs:
|
||||||
|
edit_objs.append(edits)
|
||||||
|
annotations_out.append(f"T{annotator_id}-A{align_idx} " + " ".join(tgt_tokens) + "\n")
|
||||||
|
align_idx += 1
|
||||||
|
cors = self.classifier(src, tgt, edits, verbose)
|
||||||
|
# annotations_out = []
|
||||||
|
for cor in cors:
|
||||||
|
op, toks, inds = cor.op, cor.toks, cor.inds
|
||||||
|
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
|
||||||
|
annotations_out.append(a_str)
|
||||||
|
annotations_out.append("\n")
|
||||||
|
return annotations_out, cors
|
151
opencompass/datasets/lawbench/utils/modules/classifier.py
Normal file
151
opencompass/datasets/lawbench/utils/modules/classifier.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
from char_smi import CharFuncs
|
||||||
|
from collections import namedtuple
|
||||||
|
from pypinyin import pinyin, Style
|
||||||
|
import os
|
||||||
|
Correction = namedtuple(
|
||||||
|
"Correction",
|
||||||
|
[
|
||||||
|
"op",
|
||||||
|
"toks",
|
||||||
|
"inds",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
file_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
char_smi = CharFuncs(os.path.join(file_path.replace("modules", ""), 'data/char_meta.txt'))
|
||||||
|
|
||||||
|
def check_spell_error(src_span: str,
|
||||||
|
tgt_span: str,
|
||||||
|
threshold: float = 0.8) -> bool:
|
||||||
|
if len(src_span) != len(tgt_span):
|
||||||
|
return False
|
||||||
|
src_chars = [ch for ch in src_span]
|
||||||
|
tgt_chars = [ch for ch in tgt_span]
|
||||||
|
if sorted(src_chars) == sorted(tgt_chars): # 词内部字符异位
|
||||||
|
return True
|
||||||
|
for src_char, tgt_char in zip(src_chars, tgt_chars):
|
||||||
|
if src_char != tgt_char:
|
||||||
|
if src_char not in char_smi.data or tgt_char not in char_smi.data:
|
||||||
|
return False
|
||||||
|
v_sim = char_smi.shape_similarity(src_char, tgt_char)
|
||||||
|
p_sim = char_smi.pronunciation_similarity(src_char, tgt_char)
|
||||||
|
if v_sim + p_sim < threshold and not (
|
||||||
|
set(pinyin(src_char, style=Style.NORMAL, heteronym=True)[0]) & set(pinyin(tgt_char, style=Style.NORMAL, heteronym=True)[0])):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
class Classifier:
|
||||||
|
"""
|
||||||
|
错误类型分类器
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
granularity: str = "word"):
|
||||||
|
|
||||||
|
self.granularity = granularity
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_pos_type(pos):
|
||||||
|
if pos in {"n", "nd"}:
|
||||||
|
return "NOUN"
|
||||||
|
if pos in {"nh", "ni", "nl", "ns", "nt", "nz"}:
|
||||||
|
return "NOUN-NE"
|
||||||
|
if pos in {"v"}:
|
||||||
|
return "VERB"
|
||||||
|
if pos in {"a", "b"}:
|
||||||
|
return "ADJ"
|
||||||
|
if pos in {"c"}:
|
||||||
|
return "CONJ"
|
||||||
|
if pos in {"r"}:
|
||||||
|
return "PRON"
|
||||||
|
if pos in {"d"}:
|
||||||
|
return "ADV"
|
||||||
|
if pos in {"u"}:
|
||||||
|
return "AUX"
|
||||||
|
# if pos in {"k"}: # TODO 后缀词比例太少,暂且分入其它
|
||||||
|
# return "SUFFIX"
|
||||||
|
if pos in {"m"}:
|
||||||
|
return "NUM"
|
||||||
|
if pos in {"p"}:
|
||||||
|
return "PREP"
|
||||||
|
if pos in {"q"}:
|
||||||
|
return "QUAN"
|
||||||
|
if pos in {"wp"}:
|
||||||
|
return "PUNCT"
|
||||||
|
return "OTHER"
|
||||||
|
|
||||||
|
def __call__(self,
|
||||||
|
src,
|
||||||
|
tgt,
|
||||||
|
edits,
|
||||||
|
verbose: bool = False):
|
||||||
|
"""
|
||||||
|
为编辑操作划分错误类型
|
||||||
|
:param src: 错误句子信息
|
||||||
|
:param tgt: 正确句子信息
|
||||||
|
:param edits: 编辑操作
|
||||||
|
:param verbose: 是否打印信息
|
||||||
|
:return: 划分完错误类型后的编辑操作
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
src_tokens = [x[0] for x in src]
|
||||||
|
tgt_tokens = [x[0] for x in tgt]
|
||||||
|
for edit in edits:
|
||||||
|
error_type = edit[0]
|
||||||
|
src_span = " ".join(src_tokens[edit[1]: edit[2]])
|
||||||
|
tgt_span = " ".join(tgt_tokens[edit[3]: edit[4]])
|
||||||
|
# print(tgt_span)
|
||||||
|
cor = None
|
||||||
|
if error_type[0] == "T":
|
||||||
|
cor = Correction("W", tgt_span, (edit[1], edit[2]))
|
||||||
|
elif error_type[0] == "D":
|
||||||
|
if self.granularity == "word": # 词级别可以细分错误类型
|
||||||
|
if edit[2] - edit[1] > 1: # 词组冗余暂时分为OTHER
|
||||||
|
cor = Correction("R:OTHER", "-NONE-", (edit[1], edit[2]))
|
||||||
|
else:
|
||||||
|
pos = self.get_pos_type(src[edit[1]][1])
|
||||||
|
pos = "NOUN" if pos == "NOUN-NE" else pos
|
||||||
|
pos = "MC" if tgt_span == "[缺失成分]" else pos
|
||||||
|
cor = Correction("R:{:s}".format(pos), "-NONE-", (edit[1], edit[2]))
|
||||||
|
else: # 字级别可以只需要根据操作划分类型即可
|
||||||
|
cor = Correction("R", "-NONE-", (edit[1], edit[2]))
|
||||||
|
elif error_type[0] == "I":
|
||||||
|
if self.granularity == "word": # 词级别可以细分错误类型
|
||||||
|
if edit[4] - edit[3] > 1: # 词组丢失暂时分为OTHER
|
||||||
|
cor = Correction("M:OTHER", tgt_span, (edit[1], edit[2]))
|
||||||
|
else:
|
||||||
|
pos = self.get_pos_type(tgt[edit[3]][1])
|
||||||
|
pos = "NOUN" if pos == "NOUN-NE" else pos
|
||||||
|
pos = "MC" if tgt_span == "[缺失成分]" else pos
|
||||||
|
cor = Correction("M:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
|
||||||
|
else: # 字级别可以只需要根据操作划分类型即可
|
||||||
|
cor = Correction("M", tgt_span, (edit[1], edit[2]))
|
||||||
|
elif error_type[0] == "S":
|
||||||
|
if self.granularity == "word": # 词级别可以细分错误类型
|
||||||
|
if check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")):
|
||||||
|
cor = Correction("S:SPELL", tgt_span, (edit[1], edit[2]))
|
||||||
|
# Todo 暂且不单独区分命名实体拼写错误
|
||||||
|
# if edit[4] - edit[3] > 1:
|
||||||
|
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
|
||||||
|
# else:
|
||||||
|
# pos = self.get_pos_type(tgt[edit[3]][1])
|
||||||
|
# if pos == "NOUN-NE": # 命名实体拼写有误
|
||||||
|
# cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2]))
|
||||||
|
# else: # 普通词语拼写有误
|
||||||
|
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
|
||||||
|
else:
|
||||||
|
if edit[4] - edit[3] > 1: # 词组被替换暂时分为OTHER
|
||||||
|
cor = Correction("S:OTHER", tgt_span, (edit[1], edit[2]))
|
||||||
|
else:
|
||||||
|
pos = self.get_pos_type(tgt[edit[3]][1])
|
||||||
|
pos = "NOUN" if pos == "NOUN-NE" else pos
|
||||||
|
pos = "MC" if tgt_span == "[缺失成分]" else pos
|
||||||
|
cor = Correction("S:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
|
||||||
|
else: # 字级别可以只需要根据操作划分类型即可
|
||||||
|
cor = Correction("S", tgt_span, (edit[1], edit[2]))
|
||||||
|
results.append(cor)
|
||||||
|
if verbose:
|
||||||
|
print("========== Corrections ==========")
|
||||||
|
for cor in results:
|
||||||
|
print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks))
|
||||||
|
return results
|
||||||
|
|
||||||
|
# print(pinyin("朝", style=Style.NORMAL))
|
273
opencompass/datasets/lawbench/utils/modules/merger.py
Normal file
273
opencompass/datasets/lawbench/utils/modules/merger.py
Normal file
@ -0,0 +1,273 @@
|
|||||||
|
from itertools import groupby
|
||||||
|
from string import punctuation
|
||||||
|
from typing import List
|
||||||
|
from modules.tokenizer import Tokenizer
|
||||||
|
from modules.alignment import Alignment, read_cilin, read_confusion
|
||||||
|
import Levenshtein
|
||||||
|
|
||||||
|
class Merger:
|
||||||
|
"""
|
||||||
|
合并编辑操作,从Token-Level转换为Span-Level
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
granularity: str = "word",
|
||||||
|
merge: bool = False):
|
||||||
|
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧."
|
||||||
|
self.punctuation = punctuation + chinese_punct
|
||||||
|
self.not_merge_token = [punct for punct in self.punctuation]
|
||||||
|
self.granularity = granularity
|
||||||
|
self.merge = merge
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_edits(seq, tag="X"):
|
||||||
|
if seq:
|
||||||
|
return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
|
||||||
|
else:
|
||||||
|
return seq
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _check_revolve(span_a, span_b):
|
||||||
|
span_a = span_a + span_a
|
||||||
|
return span_b in span_a
|
||||||
|
|
||||||
|
def _process_seq(self, seq, src_tokens, tgt_tokens):
|
||||||
|
if len(seq) <= 1:
|
||||||
|
return seq
|
||||||
|
|
||||||
|
ops = [op[0] for op in seq]
|
||||||
|
if set(ops) == {"D"} or set(ops) == {"I"}:
|
||||||
|
return self._merge_edits(seq, set(ops).pop())
|
||||||
|
|
||||||
|
if set(ops) == {"D", "I"} or set(ops) == {"I", "D"}:
|
||||||
|
# do not merge this pattern_from_qua.txt
|
||||||
|
return seq
|
||||||
|
|
||||||
|
if set(ops) == {"S"}:
|
||||||
|
if self.granularity == "word":
|
||||||
|
return seq
|
||||||
|
else:
|
||||||
|
return self._merge_edits(seq, "S")
|
||||||
|
|
||||||
|
if set(ops) == {"M"}:
|
||||||
|
return self._merge_edits(seq, "M")
|
||||||
|
|
||||||
|
return self._merge_edits(seq, "S")
|
||||||
|
|
||||||
|
def __call__(self,
|
||||||
|
align_obj,
|
||||||
|
src: List,
|
||||||
|
tgt: List,
|
||||||
|
verbose: bool = False):
|
||||||
|
"""
|
||||||
|
Based on ERRANT's merge, adapted for Chinese
|
||||||
|
"""
|
||||||
|
src_tokens = [x[0] for x in src]
|
||||||
|
tgt_tokens = [x[0] for x in tgt]
|
||||||
|
edits = []
|
||||||
|
# Split alignment into groups of M, T and rest. (T has a number after it)
|
||||||
|
# Todo 一旦插入、删除、替换的对象中含有标点,那么不与其它编辑合并
|
||||||
|
# Todo 缺失成分标签也不与其它编辑合并
|
||||||
|
for op, group in groupby(
|
||||||
|
align_obj,
|
||||||
|
lambda x: x[0][0] if x[0][0] in {"M", "T"} else False,
|
||||||
|
):
|
||||||
|
group = list(group)
|
||||||
|
# T is always split TODO: Evaluate this
|
||||||
|
if op == "T":
|
||||||
|
for seq in group:
|
||||||
|
edits.append(seq)
|
||||||
|
# Process D, I and S subsequence
|
||||||
|
else:
|
||||||
|
# Turn the processed sequence into edits
|
||||||
|
processed = self._process_seq(group, src_tokens, tgt_tokens)
|
||||||
|
for seq in processed:
|
||||||
|
edits.append(seq)
|
||||||
|
|
||||||
|
filtered_edits = []
|
||||||
|
i = 0
|
||||||
|
while i < len(edits):
|
||||||
|
e1 = edits[i][0][0]
|
||||||
|
|
||||||
|
if i < len(edits) - 2:
|
||||||
|
e2 = edits[i + 1][0][0]
|
||||||
|
e3 = edits[i + 2][0][0]
|
||||||
|
|
||||||
|
# Find "S M S" patterns
|
||||||
|
# Ex:
|
||||||
|
# S M S
|
||||||
|
# 冬阴功 对 外国人
|
||||||
|
# 外国人 对 冬阴功
|
||||||
|
if e1 == "S" and e2 == "M" and e3 == "S":
|
||||||
|
w1 = "".join(src_tokens[edits[i][1]: edits[i][2]])
|
||||||
|
w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]])
|
||||||
|
w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]])
|
||||||
|
w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]])
|
||||||
|
if min([len(w1), len(w2), len(w3), len(w4)]) == 1:
|
||||||
|
if w1 == w4 and w2 == w3:
|
||||||
|
group = [edits[i], edits[i + 1], edits[i + 2]]
|
||||||
|
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
|
||||||
|
for seq in processed:
|
||||||
|
filtered_edits.append(seq)
|
||||||
|
i += 3
|
||||||
|
else:
|
||||||
|
filtered_edits.append(edits[i])
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1:
|
||||||
|
group = [edits[i], edits[i + 1], edits[i + 2]]
|
||||||
|
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
|
||||||
|
for seq in processed:
|
||||||
|
filtered_edits.append(seq)
|
||||||
|
i += 3
|
||||||
|
else:
|
||||||
|
filtered_edits.append(edits[i])
|
||||||
|
i += 1
|
||||||
|
# Find "D M I" or "I M D" patterns
|
||||||
|
# Ex:
|
||||||
|
# D M I
|
||||||
|
# 旅游 去 陌生 的 地方
|
||||||
|
# 去 陌生 的 地方 旅游
|
||||||
|
elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"):
|
||||||
|
if e1 == "D":
|
||||||
|
delete_token = src_tokens[edits[i][1]: edits[i][2]]
|
||||||
|
insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]
|
||||||
|
else:
|
||||||
|
delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]]
|
||||||
|
insert_token = tgt_tokens[edits[i][3]: edits[i][4]]
|
||||||
|
a, b = "".join(delete_token), "".join(insert_token)
|
||||||
|
if len(a) < len(b):
|
||||||
|
a, b = b, a
|
||||||
|
if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1:
|
||||||
|
if len(b) == 1:
|
||||||
|
if a == b:
|
||||||
|
group = [edits[i], edits[i + 1], edits[i + 2]]
|
||||||
|
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
|
||||||
|
for seq in processed:
|
||||||
|
filtered_edits.append(seq)
|
||||||
|
i += 3
|
||||||
|
else:
|
||||||
|
filtered_edits.append(edits[i])
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)):
|
||||||
|
group = [edits[i], edits[i + 1], edits[i + 2]]
|
||||||
|
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
|
||||||
|
for seq in processed:
|
||||||
|
filtered_edits.append(seq)
|
||||||
|
i += 3
|
||||||
|
else:
|
||||||
|
filtered_edits.append(edits[i])
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
filtered_edits.append(edits[i])
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
if e1 != "M":
|
||||||
|
filtered_edits.append(edits[i])
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
if e1 != "M":
|
||||||
|
filtered_edits.append(edits[i])
|
||||||
|
i += 1
|
||||||
|
# In rare cases with word-level tokenization, the following error can occur:
|
||||||
|
# M D S M
|
||||||
|
# 有 時 住 上層
|
||||||
|
# 有 時住 上層
|
||||||
|
# Which results in S: 時住 --> 時住
|
||||||
|
# We need to filter this case out
|
||||||
|
second_filter = []
|
||||||
|
for edit in filtered_edits: # 避免因为分词错误导致的mismatch现象
|
||||||
|
span1 = "".join(src_tokens[edit[1] : edit[2]])
|
||||||
|
span2 = "".join(tgt_tokens[edit[3] : edit[4]])
|
||||||
|
|
||||||
|
if span1 != span2:
|
||||||
|
if edit[0] == "S":
|
||||||
|
b = True
|
||||||
|
# In rare cases with word-level tokenization, the following error can occur:
|
||||||
|
# S I I M
|
||||||
|
# 负责任 老师
|
||||||
|
# 负 责任 的 老师
|
||||||
|
# Which results in S: 负责任 --> 负 责任 的
|
||||||
|
# We need to convert this edit to I: --> 的
|
||||||
|
|
||||||
|
# 首部有重叠
|
||||||
|
common_str = ""
|
||||||
|
tmp_new_start_1 = edit[1]
|
||||||
|
for i in range(edit[1], edit[2]):
|
||||||
|
if not span2.startswith(common_str + src_tokens[i]):
|
||||||
|
break
|
||||||
|
common_str += src_tokens[i]
|
||||||
|
tmp_new_start_1 = i + 1
|
||||||
|
new_start_1, new_start_2 = edit[1], edit[3]
|
||||||
|
if common_str:
|
||||||
|
tmp_str = ""
|
||||||
|
for i in range(edit[3], edit[4]):
|
||||||
|
tmp_str += tgt_tokens[i]
|
||||||
|
if tmp_str == common_str:
|
||||||
|
new_start_1, new_start_2 = tmp_new_start_1, i + 1
|
||||||
|
# second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4]))
|
||||||
|
b = False
|
||||||
|
break
|
||||||
|
elif len(tmp_str) > len(common_str):
|
||||||
|
break
|
||||||
|
# 尾部有重叠
|
||||||
|
common_str = ""
|
||||||
|
new_end_1, new_end_2 = edit[2], edit[4]
|
||||||
|
tmp_new_end_1 = edit[2]
|
||||||
|
for i in reversed(range(new_start_1, edit[2])):
|
||||||
|
if not span2.endswith(src_tokens[i] + common_str):
|
||||||
|
break
|
||||||
|
common_str = src_tokens[i] + common_str
|
||||||
|
tmp_new_end_1 = i
|
||||||
|
if common_str:
|
||||||
|
tmp_str = ""
|
||||||
|
for i in reversed(range(new_start_2, edit[4])):
|
||||||
|
tmp_str = tgt_tokens[i] + tmp_str
|
||||||
|
if tmp_str == common_str:
|
||||||
|
new_end_1, new_end_2 = tmp_new_end_1, i
|
||||||
|
b = False
|
||||||
|
break
|
||||||
|
elif len(tmp_str) > len(common_str):
|
||||||
|
break
|
||||||
|
if b:
|
||||||
|
second_filter.append(edit)
|
||||||
|
else:
|
||||||
|
if new_start_1 == new_end_1:
|
||||||
|
new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2)
|
||||||
|
elif new_start_2 == new_end_2:
|
||||||
|
new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2)
|
||||||
|
else:
|
||||||
|
new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2)
|
||||||
|
second_filter.append(new_edit)
|
||||||
|
else:
|
||||||
|
second_filter.append(edit)
|
||||||
|
if verbose:
|
||||||
|
print("========== Parallels ==========")
|
||||||
|
print("".join(src_tokens))
|
||||||
|
print("".join(tgt_tokens))
|
||||||
|
print("========== Results ==========")
|
||||||
|
for edit in second_filter:
|
||||||
|
op = edit[0]
|
||||||
|
s = " ".join(src_tokens[edit[1]: edit[2]])
|
||||||
|
t = " ".join(tgt_tokens[edit[3]: edit[4]])
|
||||||
|
print(f"{op}:\t{s}\t-->\t{t}")
|
||||||
|
print("========== Infos ==========")
|
||||||
|
print(str(src))
|
||||||
|
print(str(tgt))
|
||||||
|
return second_filter
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tokenizer = Tokenizer("char")
|
||||||
|
semantic_dict, semantic_class = read_cilin()
|
||||||
|
confusion_dict = read_confusion()
|
||||||
|
alignment = Alignment(semantic_dict, confusion_dict)
|
||||||
|
sents = [
|
||||||
|
"所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace(
|
||||||
|
" ", ""),
|
||||||
|
"所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace(
|
||||||
|
" ", "")]
|
||||||
|
src, tgt = tokenizer(sents)
|
||||||
|
align_obj = alignment(src, tgt)
|
||||||
|
m = Merger()
|
||||||
|
m(align_obj, src, tgt, verbose=True)
|
346
opencompass/datasets/lawbench/utils/modules/tokenization.py
Normal file
346
opencompass/datasets/lawbench/utils/modules/tokenization.py
Normal file
@ -0,0 +1,346 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tokenization classes."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import unicodedata
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_unicode(text):
|
||||||
|
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||||
|
if six.PY3:
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text
|
||||||
|
elif isinstance(text, bytes):
|
||||||
|
return text.decode("utf-8", "ignore")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||||
|
elif six.PY2:
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text.decode("utf-8", "ignore")
|
||||||
|
elif isinstance(text, unicode):
|
||||||
|
return text
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||||
|
else:
|
||||||
|
raise ValueError("Not running on Python2 or Python 3?")
|
||||||
|
|
||||||
|
|
||||||
|
def printable_text(text):
|
||||||
|
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||||
|
|
||||||
|
# These functions want `str` for both Python2 and Python3, but in one case
|
||||||
|
# it's a Unicode string and in the other it's a byte string.
|
||||||
|
if six.PY3:
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text
|
||||||
|
elif isinstance(text, bytes):
|
||||||
|
return text.decode("utf-8", "ignore")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||||
|
elif six.PY2:
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text
|
||||||
|
elif isinstance(text, unicode):
|
||||||
|
return text.encode("utf-8")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||||
|
else:
|
||||||
|
raise ValueError("Not running on Python2 or Python 3?")
|
||||||
|
|
||||||
|
|
||||||
|
def load_vocab(vocab_file):
|
||||||
|
"""Loads a vocabulary file into a dictionary."""
|
||||||
|
vocab = collections.OrderedDict()
|
||||||
|
index = 0
|
||||||
|
with open(vocab_file, "r") as reader:
|
||||||
|
while True:
|
||||||
|
token = convert_to_unicode(reader.readline())
|
||||||
|
if not token:
|
||||||
|
break
|
||||||
|
token = token.strip()
|
||||||
|
vocab[token] = index
|
||||||
|
index += 1
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
|
||||||
|
def convert_by_vocab(vocab, items):
|
||||||
|
"""Converts a sequence of [tokens|ids] using the vocab."""
|
||||||
|
output = []
|
||||||
|
for item in items:
|
||||||
|
if item not in vocab:
|
||||||
|
print("warning: %s not in vocab" % item)
|
||||||
|
item = "[UNK]"
|
||||||
|
output.append(vocab[item])
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def convert_tokens_to_ids(vocab, tokens):
|
||||||
|
return convert_by_vocab(vocab, tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ids_to_tokens(inv_vocab, ids):
|
||||||
|
return convert_by_vocab(inv_vocab, ids)
|
||||||
|
|
||||||
|
|
||||||
|
def whitespace_tokenize(text):
|
||||||
|
"""Runs basic whitespace cleaning and splitting on a peice of text."""
|
||||||
|
text = text.strip()
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
tokens = text.split()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
class FullTokenizer(object):
|
||||||
|
"""Runs end-to-end tokenziation."""
|
||||||
|
|
||||||
|
def __init__(self, vocab_file, do_lower_case=True):
|
||||||
|
self.vocab = load_vocab(vocab_file)
|
||||||
|
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||||
|
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||||
|
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||||
|
|
||||||
|
def tokenize(self, text):
|
||||||
|
split_tokens = []
|
||||||
|
for token in self.basic_tokenizer.tokenize(text):
|
||||||
|
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||||
|
split_tokens.append(sub_token)
|
||||||
|
|
||||||
|
return split_tokens
|
||||||
|
|
||||||
|
def convert_tokens_to_ids(self, tokens):
|
||||||
|
return convert_by_vocab(self.vocab, tokens)
|
||||||
|
|
||||||
|
def convert_ids_to_tokens(self, ids):
|
||||||
|
return convert_by_vocab(self.inv_vocab, ids)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTokenizer(object):
|
||||||
|
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||||
|
|
||||||
|
def __init__(self, do_lower_case=True):
|
||||||
|
"""Constructs a BasicTokenizer.
|
||||||
|
Args:
|
||||||
|
do_lower_case: Whether to lower case the input.
|
||||||
|
"""
|
||||||
|
self.do_lower_case = do_lower_case
|
||||||
|
|
||||||
|
def tokenize(self, text):
|
||||||
|
"""Tokenizes a piece of text."""
|
||||||
|
text = convert_to_unicode(text)
|
||||||
|
text = self._clean_text(text)
|
||||||
|
|
||||||
|
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||||
|
# models. This is also applied to the English models now, but it doesn't
|
||||||
|
# matter since the English models were not trained on any Chinese data
|
||||||
|
# and generally don't have any Chinese data in them (there are Chinese
|
||||||
|
# characters in the vocabulary because Wikipedia does have some Chinese
|
||||||
|
# words in the English Wikipedia.).
|
||||||
|
text = self._tokenize_chinese_chars(text)
|
||||||
|
|
||||||
|
orig_tokens = whitespace_tokenize(text)
|
||||||
|
split_tokens = []
|
||||||
|
for token in orig_tokens:
|
||||||
|
if self.do_lower_case:
|
||||||
|
token = token.lower()
|
||||||
|
token = self._run_strip_accents(token)
|
||||||
|
split_tokens.extend(self._run_split_on_punc(token))
|
||||||
|
|
||||||
|
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||||
|
return output_tokens
|
||||||
|
|
||||||
|
def _run_strip_accents(self, text):
|
||||||
|
"""Strips accents from a piece of text."""
|
||||||
|
text = unicodedata.normalize("NFD", text)
|
||||||
|
output = []
|
||||||
|
for char in text:
|
||||||
|
cat = unicodedata.category(char)
|
||||||
|
if cat == "Mn":
|
||||||
|
continue
|
||||||
|
output.append(char)
|
||||||
|
return "".join(output)
|
||||||
|
|
||||||
|
def _run_split_on_punc(self, text):
|
||||||
|
"""Splits punctuation on a piece of text."""
|
||||||
|
chars = list(text)
|
||||||
|
i = 0
|
||||||
|
start_new_word = True
|
||||||
|
output = []
|
||||||
|
while i < len(chars):
|
||||||
|
char = chars[i]
|
||||||
|
if _is_punctuation(char):
|
||||||
|
output.append([char])
|
||||||
|
start_new_word = True
|
||||||
|
else:
|
||||||
|
if start_new_word:
|
||||||
|
output.append([])
|
||||||
|
start_new_word = False
|
||||||
|
output[-1].append(char)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return ["".join(x) for x in output]
|
||||||
|
|
||||||
|
def _tokenize_chinese_chars(self, text):
|
||||||
|
"""Adds whitespace around any CJK character."""
|
||||||
|
output = []
|
||||||
|
for char in text:
|
||||||
|
cp = ord(char)
|
||||||
|
if self._is_chinese_char(cp):
|
||||||
|
output.append(" ")
|
||||||
|
output.append(char)
|
||||||
|
output.append(" ")
|
||||||
|
else:
|
||||||
|
output.append(char)
|
||||||
|
return "".join(output)
|
||||||
|
|
||||||
|
def _is_chinese_char(self, cp):
|
||||||
|
"""Checks whether CP is the codepoint of a CJK character."""
|
||||||
|
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||||
|
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||||
|
#
|
||||||
|
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||||
|
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||||
|
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||||
|
# space-separated words, so they are not treated specially and handled
|
||||||
|
# like the all of the other languages.
|
||||||
|
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
||||||
|
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
||||||
|
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
||||||
|
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
||||||
|
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
||||||
|
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
||||||
|
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
||||||
|
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _clean_text(self, text):
|
||||||
|
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||||
|
output = []
|
||||||
|
for char in text:
|
||||||
|
cp = ord(char)
|
||||||
|
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||||
|
continue
|
||||||
|
if _is_whitespace(char):
|
||||||
|
output.append(" ")
|
||||||
|
else:
|
||||||
|
output.append(char)
|
||||||
|
return "".join(output)
|
||||||
|
|
||||||
|
|
||||||
|
class WordpieceTokenizer(object):
|
||||||
|
"""Runs WordPiece tokenziation."""
|
||||||
|
|
||||||
|
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
|
||||||
|
self.vocab = vocab
|
||||||
|
self.unk_token = unk_token
|
||||||
|
self.max_input_chars_per_word = max_input_chars_per_word
|
||||||
|
|
||||||
|
def tokenize(self, text):
|
||||||
|
"""Tokenizes a piece of text into its word pieces.
|
||||||
|
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||||
|
using the given vocabulary.
|
||||||
|
For example:
|
||||||
|
input = "unaffable"
|
||||||
|
output = ["un", "##aff", "##able"]
|
||||||
|
Args:
|
||||||
|
text: A single token or whitespace separated tokens. This should have
|
||||||
|
already been passed through `BasicTokenizer.
|
||||||
|
Returns:
|
||||||
|
A list of wordpiece tokens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
text = convert_to_unicode(text)
|
||||||
|
|
||||||
|
output_tokens = []
|
||||||
|
for token in whitespace_tokenize(text):
|
||||||
|
chars = list(token)
|
||||||
|
if len(chars) > self.max_input_chars_per_word:
|
||||||
|
output_tokens.append(self.unk_token)
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_bad = False
|
||||||
|
start = 0
|
||||||
|
sub_tokens = []
|
||||||
|
while start < len(chars):
|
||||||
|
end = len(chars)
|
||||||
|
cur_substr = None
|
||||||
|
while start < end:
|
||||||
|
substr = "".join(chars[start:end])
|
||||||
|
if start > 0:
|
||||||
|
substr = "##" + substr
|
||||||
|
if substr in self.vocab:
|
||||||
|
cur_substr = substr
|
||||||
|
break
|
||||||
|
end -= 1
|
||||||
|
if cur_substr is None:
|
||||||
|
is_bad = True
|
||||||
|
break
|
||||||
|
sub_tokens.append(cur_substr)
|
||||||
|
start = end
|
||||||
|
|
||||||
|
if is_bad:
|
||||||
|
# output_tokens.append(self.unk_token)
|
||||||
|
output_tokens.append(token) # keep the UNK token
|
||||||
|
else:
|
||||||
|
output_tokens.extend(sub_tokens)
|
||||||
|
return output_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def _is_whitespace(char):
|
||||||
|
"""Checks whether `chars` is a whitespace character."""
|
||||||
|
# \t, \n, and \r are technically contorl characters but we treat them
|
||||||
|
# as whitespace since they are generally considered as such.
|
||||||
|
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||||
|
return True
|
||||||
|
cat = unicodedata.category(char)
|
||||||
|
if cat == "Zs":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_control(char):
|
||||||
|
"""Checks whether `chars` is a control character."""
|
||||||
|
# These are technically control characters but we count them as whitespace
|
||||||
|
# characters.
|
||||||
|
if char == "\t" or char == "\n" or char == "\r":
|
||||||
|
return False
|
||||||
|
cat = unicodedata.category(char)
|
||||||
|
if cat.startswith("C"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_punctuation(char):
|
||||||
|
"""Checks whether `chars` is a punctuation character."""
|
||||||
|
cp = ord(char)
|
||||||
|
# We treat all non-letter/number ASCII as punctuation.
|
||||||
|
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||||
|
# Punctuation class but we treat them as punctuation anyways, for
|
||||||
|
# consistency.
|
||||||
|
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||||
|
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||||
|
return True
|
||||||
|
cat = unicodedata.category(char)
|
||||||
|
if cat.startswith("P"):
|
||||||
|
return True
|
||||||
|
return False
|
92
opencompass/datasets/lawbench/utils/modules/tokenizer.py
Normal file
92
opencompass/datasets/lawbench/utils/modules/tokenizer.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
from ltp import LTP
|
||||||
|
from typing import List
|
||||||
|
from pypinyin import pinyin, Style, lazy_pinyin
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import functools
|
||||||
|
|
||||||
|
class Tokenizer:
|
||||||
|
"""
|
||||||
|
分词器
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
granularity: str = "word",
|
||||||
|
device: str = "cpu",
|
||||||
|
segmented: bool = False,
|
||||||
|
bpe: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
构造函数
|
||||||
|
:param mode: 分词模式,可选级别:字级别(char)、词级别(word)
|
||||||
|
"""
|
||||||
|
self.ltp = None
|
||||||
|
if granularity == "word":
|
||||||
|
self.ltp = LTP(device=torch.device(device) if torch.cuda.is_available() else torch.device("cpu"))
|
||||||
|
self.ltp.add_words(words=["[缺失成分]"], max_window=6)
|
||||||
|
self.segmented = segmented
|
||||||
|
self.granularity = granularity
|
||||||
|
if self.granularity == "word":
|
||||||
|
self.tokenizer = self.split_word
|
||||||
|
elif self.granularity == "char":
|
||||||
|
self.tokenizer = functools.partial(self.split_char, bpe=bpe)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return "{:s}\nMode:{:s}\n}".format(str(self.__class__.__name__), self.mode)
|
||||||
|
|
||||||
|
def __call__(self,
|
||||||
|
input_strings: List[str]
|
||||||
|
) -> List:
|
||||||
|
"""
|
||||||
|
分词函数
|
||||||
|
:param input_strings: 需要分词的字符串列表
|
||||||
|
:return: 分词后的结果列表,由元组组成,元组为(token,pos_tag,pinyin)的形式
|
||||||
|
"""
|
||||||
|
if not self.segmented:
|
||||||
|
input_strings = ["".join(s.split(" ")) for s in input_strings]
|
||||||
|
results = self.tokenizer(input_strings)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def split_char(self, input_strings: List[str], bpe=False) -> List:
|
||||||
|
"""
|
||||||
|
分字函数
|
||||||
|
:param input_strings: 需要分字的字符串
|
||||||
|
:return: 分字结果
|
||||||
|
"""
|
||||||
|
if bpe:
|
||||||
|
from . import tokenization
|
||||||
|
project_dir = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(project_dir,"data","chinese_vocab.txt"), do_lower_case=False)
|
||||||
|
results = []
|
||||||
|
for input_string in input_strings:
|
||||||
|
if not self.segmented: # 如果没有被分字,就按照每个字符隔开(不考虑英文标点的特殊处理,也不考虑BPE),否则遵循原分字结果
|
||||||
|
segment_string = " ".join([char for char in input_string] if not bpe else tokenizer.tokenize(input_string))
|
||||||
|
else:
|
||||||
|
segment_string = input_string
|
||||||
|
# print(segment_string)
|
||||||
|
segment_string = segment_string.replace("[ 缺 失 成 分 ]", "[缺失成分]").split(" ") # 缺失成分当成一个单独的token
|
||||||
|
results.append([(char, "unk", pinyin(char, style=Style.NORMAL, heteronym=True)[0]) for char in segment_string])
|
||||||
|
return results
|
||||||
|
|
||||||
|
def split_word(self, input_strings: List[str]) -> List:
|
||||||
|
"""
|
||||||
|
分词函数
|
||||||
|
:param input_strings: 需要分词的字符串
|
||||||
|
:return: 分词结果
|
||||||
|
"""
|
||||||
|
if self.segmented:
|
||||||
|
seg, hidden = self.ltp.seg([input_string.split(" ") for input_string in input_strings], is_preseged=True)
|
||||||
|
else:
|
||||||
|
seg, hidden = self.ltp.seg(input_strings)
|
||||||
|
pos = self.ltp.pos(hidden)
|
||||||
|
result = []
|
||||||
|
for s, p in zip(seg, pos):
|
||||||
|
pinyin = [lazy_pinyin(word) for word in s]
|
||||||
|
result.append(list(zip(s, p, pinyin)))
|
||||||
|
return result
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tokenizer = Tokenizer("word")
|
||||||
|
print(tokenizer(["LAC是个优秀的分词工具", "百度是一家高科技公司"]))
|
221
opencompass/datasets/lawbench/utils/parallel_to_m2.py
Normal file
221
opencompass/datasets/lawbench/utils/parallel_to_m2.py
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
import os
|
||||||
|
from modules.annotator import Annotator
|
||||||
|
from modules.tokenizer import Tokenizer
|
||||||
|
import argparse
|
||||||
|
from collections import Counter
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
from collections import defaultdict
|
||||||
|
from multiprocessing import Pool
|
||||||
|
from opencc import OpenCC
|
||||||
|
import timeout_decorator
|
||||||
|
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
annotator, sentence_to_tokenized = None, None
|
||||||
|
cc = OpenCC("t2s")
|
||||||
|
|
||||||
|
@timeout_decorator.timeout(10)
|
||||||
|
def annotate_with_time_out(line):
|
||||||
|
"""
|
||||||
|
:param line:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
sent_list = line.split("\t")[1:]
|
||||||
|
source = sent_list[0]
|
||||||
|
if args.segmented:
|
||||||
|
source = source.strip()
|
||||||
|
else:
|
||||||
|
source = "".join(source.strip().split())
|
||||||
|
output_str = ""
|
||||||
|
for idx, target in enumerate(sent_list[1:]):
|
||||||
|
try:
|
||||||
|
if args.segmented:
|
||||||
|
target = target.strip()
|
||||||
|
else:
|
||||||
|
target = "".join(target.strip().split())
|
||||||
|
if not args.no_simplified:
|
||||||
|
target = cc.convert(target)
|
||||||
|
source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
|
||||||
|
out, cors = annotator(source_tokenized, target_tokenized, idx)
|
||||||
|
if idx == 0:
|
||||||
|
output_str += "".join(out[:-1])
|
||||||
|
else:
|
||||||
|
output_str += "".join(out[1:-1])
|
||||||
|
except Exception:
|
||||||
|
raise Exception
|
||||||
|
return output_str
|
||||||
|
|
||||||
|
|
||||||
|
def annotate(line):
|
||||||
|
"""
|
||||||
|
:param line:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
sent_list = line.split("\t")[1:]
|
||||||
|
source = sent_list[0]
|
||||||
|
if args.segmented:
|
||||||
|
source = source.strip()
|
||||||
|
else:
|
||||||
|
source = "".join(source.strip().split())
|
||||||
|
output_str = ""
|
||||||
|
for idx, target in enumerate(sent_list[1:]):
|
||||||
|
try:
|
||||||
|
if args.segmented:
|
||||||
|
target = target.strip()
|
||||||
|
else:
|
||||||
|
target = "".join(target.strip().split())
|
||||||
|
if not args.no_simplified:
|
||||||
|
target = cc.convert(target)
|
||||||
|
source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
|
||||||
|
out, cors = annotator(source_tokenized, target_tokenized, idx)
|
||||||
|
if idx == 0:
|
||||||
|
output_str += "".join(out[:-1])
|
||||||
|
else:
|
||||||
|
output_str += "".join(out[1:-1])
|
||||||
|
except Exception:
|
||||||
|
raise Exception
|
||||||
|
return output_str
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def firsttime_process(args):
|
||||||
|
tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
|
||||||
|
global annotator, sentence_to_tokenized
|
||||||
|
annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
|
||||||
|
lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n") # format: id src tgt1 tgt2...
|
||||||
|
# error_types = []
|
||||||
|
|
||||||
|
with open(args.output, "w", encoding="utf-8") as f:
|
||||||
|
count = 0
|
||||||
|
sentence_set = set()
|
||||||
|
sentence_to_tokenized = {}
|
||||||
|
for line in lines:
|
||||||
|
sent_list = line.split("\t")[1:]
|
||||||
|
for idx, sent in enumerate(sent_list):
|
||||||
|
if args.segmented:
|
||||||
|
# print(sent)
|
||||||
|
sent = sent.strip()
|
||||||
|
else:
|
||||||
|
sent = "".join(sent.split()).strip()
|
||||||
|
if idx >= 1:
|
||||||
|
if not args.no_simplified:
|
||||||
|
sentence_set.add(cc.convert(sent))
|
||||||
|
else:
|
||||||
|
sentence_set.add(sent)
|
||||||
|
else:
|
||||||
|
sentence_set.add(sent)
|
||||||
|
batch = []
|
||||||
|
for sent in tqdm(sentence_set):
|
||||||
|
count += 1
|
||||||
|
if sent:
|
||||||
|
batch.append(sent)
|
||||||
|
if count % args.batch_size == 0:
|
||||||
|
results = tokenizer(batch)
|
||||||
|
for s, r in zip(batch, results):
|
||||||
|
sentence_to_tokenized[s] = r # Get tokenization map.
|
||||||
|
batch = []
|
||||||
|
if batch:
|
||||||
|
results = tokenizer(batch)
|
||||||
|
for s, r in zip(batch, results):
|
||||||
|
sentence_to_tokenized[s] = r # Get tokenization map.
|
||||||
|
|
||||||
|
timeout_indices = []
|
||||||
|
|
||||||
|
# 单进程模式
|
||||||
|
for idx, line in enumerate(tqdm(lines)):
|
||||||
|
try:
|
||||||
|
ret = annotate_with_time_out(line)
|
||||||
|
except Exception:
|
||||||
|
timeout_indices.append(idx)
|
||||||
|
return timeout_indices
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
timeout_indices = firsttime_process(args)
|
||||||
|
tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
|
||||||
|
global annotator, sentence_to_tokenized
|
||||||
|
annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
|
||||||
|
lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n")
|
||||||
|
new_lines = []# format: id src tgt1 tgt2...
|
||||||
|
|
||||||
|
with open(args.output, "w", encoding="utf-8") as f:
|
||||||
|
count = 0
|
||||||
|
sentence_set = set()
|
||||||
|
sentence_to_tokenized = {}
|
||||||
|
for line_idx, line in enumerate(lines):
|
||||||
|
|
||||||
|
if line_idx in timeout_indices:
|
||||||
|
# print(f"line before split: {line}")
|
||||||
|
line_split = line.split("\t")
|
||||||
|
line_number, sent_list = line_split[0], line_split[1:]
|
||||||
|
assert len(sent_list) == 2
|
||||||
|
sent_list[-1] = " 无"
|
||||||
|
line = line_number + "\t" + "\t".join(sent_list)
|
||||||
|
# print(f"line time out: {line}")
|
||||||
|
new_lines.append(line)
|
||||||
|
else:
|
||||||
|
new_lines.append(line)
|
||||||
|
|
||||||
|
sent_list = line.split("\t")[1:]
|
||||||
|
for idx, sent in enumerate(sent_list):
|
||||||
|
if args.segmented:
|
||||||
|
# print(sent)
|
||||||
|
sent = sent.strip()
|
||||||
|
else:
|
||||||
|
sent = "".join(sent.split()).strip()
|
||||||
|
if idx >= 1:
|
||||||
|
if not args.no_simplified:
|
||||||
|
sentence_set.add(cc.convert(sent))
|
||||||
|
else:
|
||||||
|
sentence_set.add(sent)
|
||||||
|
else:
|
||||||
|
sentence_set.add(sent)
|
||||||
|
batch = []
|
||||||
|
for sent in tqdm(sentence_set):
|
||||||
|
count += 1
|
||||||
|
if sent:
|
||||||
|
batch.append(sent)
|
||||||
|
if count % args.batch_size == 0:
|
||||||
|
results = tokenizer(batch)
|
||||||
|
for s, r in zip(batch, results):
|
||||||
|
sentence_to_tokenized[s] = r # Get tokenization map.
|
||||||
|
batch = []
|
||||||
|
if batch:
|
||||||
|
results = tokenizer(batch)
|
||||||
|
for s, r in zip(batch, results):
|
||||||
|
sentence_to_tokenized[s] = r # Get tokenization map.
|
||||||
|
|
||||||
|
# 单进程模式
|
||||||
|
lines = new_lines
|
||||||
|
for idx, line in enumerate(tqdm(lines)):
|
||||||
|
ret = annotate(line)
|
||||||
|
f.write(ret)
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
# 多进程模式:仅在Linux环境下测试,建议在linux服务器上使用
|
||||||
|
# with Pool(args.worker_num) as pool:
|
||||||
|
# for ret in pool.imap(annotate, tqdm(lines), chunksize=8):
|
||||||
|
# if ret:
|
||||||
|
# f.write(ret)
|
||||||
|
# f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Choose input file to annotate")
|
||||||
|
parser.add_argument("-f", "--file", type=str, required=True, help="Input parallel file")
|
||||||
|
parser.add_argument("-o", "--output", type=str, help="Output file", required=True)
|
||||||
|
parser.add_argument("-b", "--batch_size", type=int, help="The size of batch", default=128)
|
||||||
|
parser.add_argument("-d", "--device", type=int, help="The ID of GPU", default=0)
|
||||||
|
parser.add_argument("-w", "--worker_num", type=int, help="The number of workers", default=16)
|
||||||
|
parser.add_argument("-g", "--granularity", type=str, help="Choose char-level or word-level evaluation", default="char")
|
||||||
|
parser.add_argument("-m", "--merge", help="Whether merge continuous replacement/deletion/insertion", action="store_true")
|
||||||
|
parser.add_argument("-s", "--multi_cheapest_strategy", type=str, choices=["first", "all"], default="all")
|
||||||
|
parser.add_argument("--segmented", help="Whether tokens have been segmented", action="store_true") # 支持提前token化,用空格隔开
|
||||||
|
parser.add_argument("--no_simplified", help="Whether simplifying chinese", action="store_true") # 将所有corrections转换为简体中文
|
||||||
|
parser.add_argument("--bpe", help="Whether to use bpe", action="store_true") # 支持 bpe 切分英文单词
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
158
opencompass/datasets/lawbench/utils/rc_f1.py
Normal file
158
opencompass/datasets/lawbench/utils/rc_f1.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
"""Official evaluation script for CAIL-2021.
|
||||||
|
|
||||||
|
The code is based partially on CoQA evaluation script.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
|
||||||
|
class CJRCEvaluator:
|
||||||
|
def __init__(self, gold_file):
|
||||||
|
self.gold_data = CJRCEvaluator.gold_answers_to_dict(gold_file)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def gold_answers_to_dict(gold_file):
|
||||||
|
dataset = json.load(open(gold_file, mode="r", encoding="utf-8"))
|
||||||
|
gold_dict = {}
|
||||||
|
# id_to_domain = {}
|
||||||
|
for story in dataset['data']:
|
||||||
|
qas = story["paragraphs"][0]["qas"]
|
||||||
|
for qa in qas:
|
||||||
|
qid = qa['id']
|
||||||
|
gold_answers = []
|
||||||
|
answers = qa["answers"]
|
||||||
|
if len(answers) == 0:
|
||||||
|
gold_answers = ['']
|
||||||
|
else:
|
||||||
|
for answer in qa["answers"]:
|
||||||
|
if type(answer) == dict:
|
||||||
|
gold_answers.append(answer["text"])
|
||||||
|
elif type(answer) == list:
|
||||||
|
gold_answers.append("".join([a["text"] for a in answer]))
|
||||||
|
if qid in gold_dict:
|
||||||
|
sys.stderr.write("Gold file has duplicate stories: {}".format(qid))
|
||||||
|
gold_dict[qid] = gold_answers
|
||||||
|
return gold_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def preds_to_dict(pred_file):
|
||||||
|
preds = json.load(open(pred_file, mode="r", encoding="utf-8"))
|
||||||
|
pred_dict = {}
|
||||||
|
for pred in preds:
|
||||||
|
pred_dict[pred['id']] = "".join(pred['answer'])
|
||||||
|
return pred_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize_answer(s):
|
||||||
|
"""Lower text and remove punctuation, storys and extra whitespace."""
|
||||||
|
|
||||||
|
def remove_punc(text):
|
||||||
|
return "".join(ch for ch in text if ch.isdigit() or ch.isalpha())
|
||||||
|
|
||||||
|
def lower(text):
|
||||||
|
return text.lower()
|
||||||
|
|
||||||
|
return remove_punc(lower(s))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tokens(s):
|
||||||
|
if not s: return []
|
||||||
|
return list(CJRCEvaluator.normalize_answer(s))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_exact(a_gold, a_pred):
|
||||||
|
return int(CJRCEvaluator.normalize_answer(a_gold) == CJRCEvaluator.normalize_answer(a_pred))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_f1(a_gold, a_pred):
|
||||||
|
gold_toks = CJRCEvaluator.get_tokens(a_gold)
|
||||||
|
pred_toks = CJRCEvaluator.get_tokens(a_pred)
|
||||||
|
common = Counter(gold_toks) & Counter(pred_toks)
|
||||||
|
num_same = sum(common.values())
|
||||||
|
if len(gold_toks) == 0 or len(pred_toks) == 0:
|
||||||
|
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
|
||||||
|
return int(gold_toks == pred_toks)
|
||||||
|
if num_same == 0:
|
||||||
|
return 0
|
||||||
|
precision = 1.0 * num_same / len(pred_toks)
|
||||||
|
recall = 1.0 * num_same / len(gold_toks)
|
||||||
|
f1 = (2 * precision * recall) / (precision + recall)
|
||||||
|
return f1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_turn_score(a_gold_list, a_pred):
|
||||||
|
f1_sum = 0.0
|
||||||
|
em_sum = 0.0
|
||||||
|
if len(a_gold_list) > 1:
|
||||||
|
for i in range(len(a_gold_list)):
|
||||||
|
# exclude the current answer
|
||||||
|
gold_answers = a_gold_list[0:i] + a_gold_list[i + 1:]
|
||||||
|
em_sum += max(CJRCEvaluator.compute_exact(a, a_pred) for a in gold_answers)
|
||||||
|
f1_sum += max(CJRCEvaluator.compute_f1(a, a_pred) for a in gold_answers)
|
||||||
|
else:
|
||||||
|
em_sum += max(CJRCEvaluator.compute_exact(a, a_pred) for a in a_gold_list)
|
||||||
|
f1_sum += max(CJRCEvaluator.compute_f1(a, a_pred) for a in a_gold_list)
|
||||||
|
if f1_sum != 1:
|
||||||
|
a = 1 + 1
|
||||||
|
return {'em': em_sum / max(1, len(a_gold_list)), 'f1': f1_sum / max(1, len(a_gold_list))}
|
||||||
|
|
||||||
|
def compute_turn_score(self, qid, a_pred):
|
||||||
|
''' This is the function what you are probably looking for. a_pred is the answer string your model predicted. '''
|
||||||
|
a_gold_list = self.gold_data[qid]
|
||||||
|
return CJRCEvaluator._compute_turn_score(a_gold_list, a_pred)
|
||||||
|
|
||||||
|
def get_raw_scores(self, pred_data):
|
||||||
|
''''Returns a dict with score'''
|
||||||
|
exact_scores = {}
|
||||||
|
f1_scores = {}
|
||||||
|
for qid in self.gold_data:
|
||||||
|
if qid not in pred_data:
|
||||||
|
sys.stderr.write('Missing prediction for {}\n'.format(qid))
|
||||||
|
continue
|
||||||
|
a_pred = pred_data[qid]
|
||||||
|
scores = self.compute_turn_score(qid, a_pred)
|
||||||
|
# Take max over all gold answers
|
||||||
|
exact_scores[qid] = scores['em']
|
||||||
|
f1_scores[qid] = scores['f1']
|
||||||
|
return exact_scores, f1_scores
|
||||||
|
|
||||||
|
def get_raw_scores_human(self):
|
||||||
|
'''
|
||||||
|
Returns a dict with score
|
||||||
|
'''
|
||||||
|
exact_scores = {}
|
||||||
|
f1_scores = {}
|
||||||
|
for qid in self.gold_data:
|
||||||
|
f1_sum = 0.0
|
||||||
|
em_sum = 0.0
|
||||||
|
if len(self.gold_data[qid]) > 1:
|
||||||
|
for i in range(len(self.gold_data[qid])):
|
||||||
|
# exclude the current answer
|
||||||
|
gold_answers = self.gold_data[qid][0:i] + self.gold_data[qid][i + 1:]
|
||||||
|
em_sum += max(CJRCEvaluator.compute_exact(a, self.gold_data[qid][i]) for a in gold_answers)
|
||||||
|
f1_sum += max(CJRCEvaluator.compute_f1(a, self.gold_data[qid][i]) for a in gold_answers)
|
||||||
|
else:
|
||||||
|
exit("Gold answers should be multiple: {}={}".format(qid, self.gold_data[qid]))
|
||||||
|
exact_scores[qid] = em_sum / len(self.gold_data[qid])
|
||||||
|
f1_scores[qid] = f1_sum / len(self.gold_data[qid])
|
||||||
|
return exact_scores, f1_scores
|
||||||
|
|
||||||
|
def human_performance(self):
|
||||||
|
exact_scores, f1_scores = self.get_raw_scores_human()
|
||||||
|
return self.get_total_scores(exact_scores, f1_scores)
|
||||||
|
|
||||||
|
def model_performance(self, pred_data):
|
||||||
|
exact_scores, f1_scores = self.get_raw_scores(pred_data)
|
||||||
|
return self.get_total_scores(exact_scores, f1_scores)
|
||||||
|
|
||||||
|
def get_total_scores(self, exact_scores, f1_scores):
|
||||||
|
em_total, f1_total, turn_count = 0, 0, 0
|
||||||
|
scores = {}
|
||||||
|
for qid in self.gold_data:
|
||||||
|
em_total += exact_scores.get(qid, 0)
|
||||||
|
f1_total += f1_scores.get(qid, 0)
|
||||||
|
turn_count += 1
|
||||||
|
scores["F1"] = round(f1_total / max(1, turn_count) * 100, 1)
|
||||||
|
return scores
|
@ -1,6 +1,7 @@
|
|||||||
absl-py
|
absl-py
|
||||||
accelerate>=0.19.0
|
accelerate>=0.19.0
|
||||||
boto3
|
boto3
|
||||||
|
cn2an
|
||||||
colossalai
|
colossalai
|
||||||
cpm_kernels
|
cpm_kernels
|
||||||
datasets>=2.12.0
|
datasets>=2.12.0
|
||||||
@ -9,11 +10,15 @@ fairscale
|
|||||||
faiss_gpu==1.7.2
|
faiss_gpu==1.7.2
|
||||||
fuzzywuzzy
|
fuzzywuzzy
|
||||||
jieba
|
jieba
|
||||||
|
ltp
|
||||||
mmengine>=0.8.2
|
mmengine>=0.8.2
|
||||||
nltk==3.8
|
nltk==3.8
|
||||||
numpy==1.23.4
|
numpy==1.23.4
|
||||||
openai
|
openai
|
||||||
|
OpenCC
|
||||||
pandas<2.0.0
|
pandas<2.0.0
|
||||||
|
pypinyin
|
||||||
|
python-Levenshtein
|
||||||
rank_bm25==0.2.2
|
rank_bm25==0.2.2
|
||||||
rapidfuzz
|
rapidfuzz
|
||||||
requests==2.31.0
|
requests==2.31.0
|
||||||
@ -25,6 +30,7 @@ seaborn
|
|||||||
sentence_transformers==2.2.2
|
sentence_transformers==2.2.2
|
||||||
tabulate
|
tabulate
|
||||||
tiktoken
|
tiktoken
|
||||||
|
timeout_decorator
|
||||||
tokenizers>=0.13.3
|
tokenizers>=0.13.3
|
||||||
torch>=1.13.1
|
torch>=1.13.1
|
||||||
tqdm==4.64.1
|
tqdm==4.64.1
|
||||||
|
Loading…
Reference in New Issue
Block a user