# OpenCompass/opencompass/datasets/lawbench/utils/modules/merger.py


from itertools import groupby
from string import punctuation
from typing import List
from modules.tokenizer import Tokenizer
from modules.alignment import Alignment, read_cilin, read_confusion
import Levenshtein
class Merger:
"""
    Merge edit operations, converting them from token level to span level.
"""
def __init__(self,
granularity: str = "word",
merge: bool = False):
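        # granularity: "word" or "char"; at "char" granularity adjacent substitutions are
        # collapsed into a single span in _process_seq, at "word" granularity they are kept apart.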
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧."
self.punctuation = punctuation + chinese_punct
self.not_merge_token = [punct for punct in self.punctuation]
self.granularity = granularity
self.merge = merge
@staticmethod
def _merge_edits(seq, tag="X"):
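        # Collapse a run of edits into one edit labelled `tag`, spanning from the first
        # edit's source/target start to the last edit's source/target end.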
if seq:
return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
else:
return seq
@staticmethod
def _check_revolve(span_a, span_b):
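        # True if span_b is a rotation of span_a, checked by searching for span_b
        # inside span_a concatenated with itself.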
span_a = span_a + span_a
return span_b in span_a
def _process_seq(self, seq, src_tokens, tgt_tokens):
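        # Merge a run of D/I/S edits according to its composition:
        #   - a pure run of deletions or insertions becomes one edit of that type;
        #   - runs that mix deletions and insertions are left unmerged;
        #   - substitutions are merged only at character granularity;
        #   - a run of matches becomes one match; any other mix becomes one substitution.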
if len(seq) <= 1:
return seq
ops = [op[0] for op in seq]
if set(ops) == {"D"} or set(ops) == {"I"}:
return self._merge_edits(seq, set(ops).pop())
        if set(ops) == {"D", "I"}:
            # do not merge runs that mix deletions and insertions
return seq
if set(ops) == {"S"}:
if self.granularity == "word":
return seq
else:
return self._merge_edits(seq, "S")
if set(ops) == {"M"}:
return self._merge_edits(seq, "M")
return self._merge_edits(seq, "S")
def __call__(self,
align_obj,
src: List,
tgt: List,
verbose: bool = False):
"""
Based on ERRANT's merge, adapted for Chinese
"""
src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt]
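        # align_obj is assumed to be a sequence of (op, src_start, src_end, tgt_start, tgt_end)
        # tuples from Alignment; src and tgt are tokenizer outputs whose first field is the token.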
edits = []
# Split alignment into groups of M, T and rest. (T has a number after it)
        # TODO: if an inserted, deleted, or substituted span contains punctuation, do not merge it with other edits
        # TODO: edits labeled as missing components should likewise not be merged with other edits
for op, group in groupby(
align_obj,
lambda x: x[0][0] if x[0][0] in {"M", "T"} else False,
):
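            # The key is "M", "T", or False, so consecutive D/I/S edits end up in the same group.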
group = list(group)
# T is always split TODO: Evaluate this
if op == "T":
for seq in group:
edits.append(seq)
# Process D, I and S subsequence
else:
# Turn the processed sequence into edits
processed = self._process_seq(group, src_tokens, tgt_tokens)
for seq in processed:
edits.append(seq)
filtered_edits = []
i = 0
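        # Second pass: scan 3-edit windows for "S M S", "D M I" and "I M D" patterns that
        # actually encode a transposition, relabelling them "T<span length>"; plain matches
        # ("M") are dropped from the edit list along the way.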
while i < len(edits):
e1 = edits[i][0][0]
if i < len(edits) - 2:
e2 = edits[i + 1][0][0]
e3 = edits[i + 2][0][0]
# Find "S M S" patterns
# Ex:
# S M S
# 冬阴功 对 外国人
# 外国人 对 冬阴功
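                # If the substituted spans on either side of the match are (near-)swaps
                # of each other, relabel the three edits as a single transposition.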
if e1 == "S" and e2 == "M" and e3 == "S":
w1 = "".join(src_tokens[edits[i][1]: edits[i][2]])
w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]])
w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]])
w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]])
if min([len(w1), len(w2), len(w3), len(w4)]) == 1:
if w1 == w4 and w2 == w3:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
# Find "D M I" or "I M D" patterns
# Ex:
# D M I
# 旅游 去 陌生 的 地方
# 去 陌生 的 地方 旅游
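                # If the deleted and inserted material around the match are (nearly) the
                # same string, the D/I pair encodes a word-order change, not two edits.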
elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"):
if e1 == "D":
delete_token = src_tokens[edits[i][1]: edits[i][2]]
insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]
else:
delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]]
insert_token = tgt_tokens[edits[i][3]: edits[i][4]]
a, b = "".join(delete_token), "".join(insert_token)
if len(a) < len(b):
a, b = b, a
if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1:
if len(b) == 1:
if a == b:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)):
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
filtered_edits.append(edits[i])
i += 1
else:
if e1 != "M":
filtered_edits.append(edits[i])
i += 1
else:
if e1 != "M":
filtered_edits.append(edits[i])
i += 1
# In rare cases with word-level tokenization, the following error can occur:
# M D S M
# 有 時 住 上層
# 有 時住 上層
# Which results in S: 時住 --> 時住
# We need to filter this case out
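        # Final pass: trim substitutions whose source and target spans share a common
        # prefix or suffix, re-typing what remains as an insertion, deletion, or substitution.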
second_filter = []
        for edit in filtered_edits:  # avoid mismatches caused by word-segmentation errors
span1 = "".join(src_tokens[edit[1] : edit[2]])
span2 = "".join(tgt_tokens[edit[3] : edit[4]])
if span1 != span2:
if edit[0] == "S":
b = True
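                    # b stays True if the span cannot be trimmed; the edit is then kept as-is.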
# In rare cases with word-level tokenization, the following error can occur:
# S I I M
# 负责任 老师
# 负 责任 的 老师
# Which results in S: 负责任 --> 负 责任 的
# We need to convert this edit to I: --> 的
                    # check for overlap at the head (shared prefix)
common_str = ""
tmp_new_start_1 = edit[1]
for i in range(edit[1], edit[2]):
if not span2.startswith(common_str + src_tokens[i]):
break
common_str += src_tokens[i]
tmp_new_start_1 = i + 1
new_start_1, new_start_2 = edit[1], edit[3]
if common_str:
tmp_str = ""
for i in range(edit[3], edit[4]):
tmp_str += tgt_tokens[i]
if tmp_str == common_str:
new_start_1, new_start_2 = tmp_new_start_1, i + 1
# second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4]))
b = False
break
elif len(tmp_str) > len(common_str):
break
                    # check for overlap at the tail (shared suffix)
common_str = ""
new_end_1, new_end_2 = edit[2], edit[4]
tmp_new_end_1 = edit[2]
for i in reversed(range(new_start_1, edit[2])):
if not span2.endswith(src_tokens[i] + common_str):
break
common_str = src_tokens[i] + common_str
tmp_new_end_1 = i
if common_str:
tmp_str = ""
for i in reversed(range(new_start_2, edit[4])):
tmp_str = tgt_tokens[i] + tmp_str
if tmp_str == common_str:
new_end_1, new_end_2 = tmp_new_end_1, i
b = False
break
elif len(tmp_str) > len(common_str):
break
if b:
second_filter.append(edit)
else:
if new_start_1 == new_end_1:
new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2)
elif new_start_2 == new_end_2:
new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2)
else:
new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2)
second_filter.append(new_edit)
else:
second_filter.append(edit)
if verbose:
print("========== Parallels ==========")
print("".join(src_tokens))
print("".join(tgt_tokens))
print("========== Results ==========")
for edit in second_filter:
op = edit[0]
s = " ".join(src_tokens[edit[1]: edit[2]])
t = " ".join(tgt_tokens[edit[3]: edit[4]])
print(f"{op}:\t{s}\t-->\t{t}")
print("========== Infos ==========")
print(str(src))
print(str(tgt))
return second_filter
if __name__ == "__main__":
tokenizer = Tokenizer("char")
semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion()
alignment = Alignment(semantic_dict, confusion_dict)
sents = [
"所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace(
" ", ""),
"所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace(
" ", "")]
src, tgt = tokenizer(sents)
align_obj = alignment(src, tgt)
m = Merger()
m(align_obj, src, tgt, verbose=True)