from itertools import groupby
from string import punctuation
from typing import List

from modules.tokenizer import Tokenizer
from modules.alignment import Alignment, read_cilin, read_confusion
import Levenshtein

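# modules.tokenizer and modules.alignment are project-local modules;
# Levenshtein is assumed to be the python-Levenshtein package, which
# provides Levenshtein.distance.
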
class Merger:
    """
    Merge edit operations, converting them from token level to span level.
    """

    def __init__(self,
                 granularity: str = "word",
                 merge: bool = False):
        chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧."
        self.punctuation = punctuation + chinese_punct
        self.not_merge_token = [punct for punct in self.punctuation]
        self.granularity = granularity
        self.merge = merge

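    # Edits and alignment entries are assumed to be 5-tuples of the form
    # (op, src_start, src_end, tgt_start, tgt_end), where op is "M" (match),
    # "S" (substitution), "D" (deletion), "I" (insertion) or "T<n>"
    # (transposition) and the offsets index into the source/target tokens.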
    @staticmethod
    def _merge_edits(seq, tag="X"):
        if seq:
            return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
        else:
            return seq

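    # _check_revolve tests whether span_b is a rotation of span_a: a string
    # concatenated with itself contains every rotation of the original,
    # e.g. "AB" + "AB" = "ABAB" contains "BA".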
    @staticmethod
    def _check_revolve(span_a, span_b):
        span_a = span_a + span_a
        return span_b in span_a

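    # Merge rules for a run of token-level edits: pure deletion or pure
    # insertion runs collapse into one edit; mixed D/I runs are left as-is;
    # substitution runs are merged only at character granularity; match runs
    # collapse into one "M"; any other mix collapses into a single "S".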
    def _process_seq(self, seq, src_tokens, tgt_tokens):
        if len(seq) <= 1:
            return seq

        ops = [op[0] for op in seq]
        if set(ops) == {"D"} or set(ops) == {"I"}:
            return self._merge_edits(seq, set(ops).pop())

        if set(ops) == {"D", "I"}:
            # Do not merge mixed deletion/insertion runs
            return seq

        if set(ops) == {"S"}:
            if self.granularity == "word":
                return seq
            else:
                return self._merge_edits(seq, "S")

        if set(ops) == {"M"}:
            return self._merge_edits(seq, "M")

        return self._merge_edits(seq, "S")

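    # __call__ performs three passes over the token-level alignment:
    #   1. group consecutive operations and merge D/I/S runs via _process_seq;
    #   2. merge "S M S" and "D M I"/"I M D" patterns into word-order (T) edits;
    #   3. trim spans that only differ because of word segmentation, re-typing
    #      the trimmed edits as I/D/S where needed.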
    def __call__(self,
                 align_obj,
                 src: List,
                 tgt: List,
                 verbose: bool = False):
        """
        Based on ERRANT's merge, adapted for Chinese.
        """
        src_tokens = [x[0] for x in src]
        tgt_tokens = [x[0] for x in tgt]
        edits = []
        # Split alignment into groups of M, T and rest. (T has a number after it)
        # TODO: if an insertion, deletion or substitution involves punctuation, do not merge it with other edits
        # TODO: missing-component labels are likewise not merged with other edits
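        # Each element of align_obj is assumed to be a 5-tuple
        # (op, src_start, src_end, tgt_start, tgt_end). The grouping key keeps
        # "M" and "T..." operations apart and lumps the rest (D/I/S) together
        # so that _process_seq can merge those runs.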
        for op, group in groupby(
                align_obj,
                lambda x: x[0][0] if x[0][0] in {"M", "T"} else False,
        ):
            group = list(group)
            # T is always split TODO: Evaluate this
            if op == "T":
                for seq in group:
                    edits.append(seq)
            # Process D, I and S subsequences
            else:
                # Turn the processed sequence into edits
                processed = self._process_seq(group, src_tokens, tgt_tokens)
                for seq in processed:
                    edits.append(seq)

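        # Second pass: merge transposition patterns around an unchanged middle
        # span into a single word-order (T) edit.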
        filtered_edits = []
        i = 0
        while i < len(edits):
            e1 = edits[i][0][0]

            if i < len(edits) - 2:
                e2 = edits[i + 1][0][0]
                e3 = edits[i + 2][0][0]

# Find "S M S" patterns
|
|||
|
# Ex:
|
|||
|
# S M S
|
|||
|
# 冬阴功 对 外国人
|
|||
|
# 外国人 对 冬阴功
|
|||
|
if e1 == "S" and e2 == "M" and e3 == "S":
|
|||
|
w1 = "".join(src_tokens[edits[i][1]: edits[i][2]])
|
|||
|
w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]])
|
|||
|
w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]])
|
|||
|
w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]])
|
|||
|
if min([len(w1), len(w2), len(w3), len(w4)]) == 1:
|
|||
|
if w1 == w4 and w2 == w3:
|
|||
|
group = [edits[i], edits[i + 1], edits[i + 2]]
|
|||
|
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
|
|||
|
for seq in processed:
|
|||
|
filtered_edits.append(seq)
|
|||
|
i += 3
|
|||
|
else:
|
|||
|
filtered_edits.append(edits[i])
|
|||
|
i += 1
|
|||
|
else:
|
|||
|
if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1:
|
|||
|
group = [edits[i], edits[i + 1], edits[i + 2]]
|
|||
|
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
|
|||
|
for seq in processed:
|
|||
|
filtered_edits.append(seq)
|
|||
|
i += 3
|
|||
|
else:
|
|||
|
filtered_edits.append(edits[i])
|
|||
|
i += 1
|
|||
|
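                # The merged edit is tagged "T" followed by the number of source
                # tokens it spans (edits[i + 2][2] - edits[i][1]), matching the
                # aligner's transposition labels.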
# Find "D M I" or "I M D" patterns
|
|||
|
# Ex:
|
|||
|
# D M I
|
|||
|
# 旅游 去 陌生 的 地方
|
|||
|
# 去 陌生 的 地方 旅游
|
|||
|
elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"):
|
|||
|
if e1 == "D":
|
|||
|
delete_token = src_tokens[edits[i][1]: edits[i][2]]
|
|||
|
insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]
|
|||
|
else:
|
|||
|
delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]]
|
|||
|
insert_token = tgt_tokens[edits[i][3]: edits[i][4]]
|
|||
|
a, b = "".join(delete_token), "".join(insert_token)
|
|||
|
if len(a) < len(b):
|
|||
|
a, b = b, a
|
|||
|
if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1:
|
|||
|
if len(b) == 1:
|
|||
|
if a == b:
|
|||
|
group = [edits[i], edits[i + 1], edits[i + 2]]
|
|||
|
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
|
|||
|
for seq in processed:
|
|||
|
filtered_edits.append(seq)
|
|||
|
i += 3
|
|||
|
else:
|
|||
|
filtered_edits.append(edits[i])
|
|||
|
i += 1
|
|||
|
else:
|
|||
|
if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)):
|
|||
|
group = [edits[i], edits[i + 1], edits[i + 2]]
|
|||
|
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
|
|||
|
for seq in processed:
|
|||
|
filtered_edits.append(seq)
|
|||
|
i += 3
|
|||
|
else:
|
|||
|
filtered_edits.append(edits[i])
|
|||
|
i += 1
|
|||
|
else:
|
|||
|
filtered_edits.append(edits[i])
|
|||
|
i += 1
|
|||
|
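                # No transposition pattern starts at this edit: keep it unless
                # it is an unchanged match ("M"), then advance by one.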
                else:
                    if e1 != "M":
                        filtered_edits.append(edits[i])
                    i += 1
            else:
                if e1 != "M":
                    filtered_edits.append(edits[i])
                i += 1
        # In rare cases with word-level tokenization, the following error can occur:
        # M    D    S     M
        # 有   時   住    上層
        # 有        時住  上層
        # Which results in S: 時住 --> 時住
        # We need to filter this case out
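        # Third pass: for substitution edits, strip any prefix/suffix shared by
        # the source and target spans and re-type what remains as an insertion,
        # deletion or substitution.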
        second_filter = []
        for edit in filtered_edits:  # Avoid mismatches caused by word-segmentation errors
            span1 = "".join(src_tokens[edit[1]: edit[2]])
            span2 = "".join(tgt_tokens[edit[3]: edit[4]])

            if span1 != span2:
                if edit[0] == "S":
                    b = True
                    # In rare cases with word-level tokenization, the following error can occur:
                    # S        I     I    M
                    # 负责任               老师
                    # 负      责任   的   老师
                    # Which results in S: 负责任 --> 负 责任 的
                    # We need to convert this edit to I: --> 的

                    # Overlap at the beginning (shared prefix)
                    common_str = ""
                    tmp_new_start_1 = edit[1]
                    for i in range(edit[1], edit[2]):
                        if not span2.startswith(common_str + src_tokens[i]):
                            break
                        common_str += src_tokens[i]
                        tmp_new_start_1 = i + 1
                    new_start_1, new_start_2 = edit[1], edit[3]
                    if common_str:
                        tmp_str = ""
                        for i in range(edit[3], edit[4]):
                            tmp_str += tgt_tokens[i]
                            if tmp_str == common_str:
                                new_start_1, new_start_2 = tmp_new_start_1, i + 1
                                # second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4]))
                                b = False
                                break
                            elif len(tmp_str) > len(common_str):
                                break
                    # Overlap at the end (shared suffix)
                    common_str = ""
                    new_end_1, new_end_2 = edit[2], edit[4]
                    tmp_new_end_1 = edit[2]
                    for i in reversed(range(new_start_1, edit[2])):
                        if not span2.endswith(src_tokens[i] + common_str):
                            break
                        common_str = src_tokens[i] + common_str
                        tmp_new_end_1 = i
                    if common_str:
                        tmp_str = ""
                        for i in reversed(range(new_start_2, edit[4])):
                            tmp_str = tgt_tokens[i] + tmp_str
                            if tmp_str == common_str:
                                new_end_1, new_end_2 = tmp_new_end_1, i
                                b = False
                                break
                            elif len(tmp_str) > len(common_str):
                                break
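                    # If nothing was trimmed (b is still True), keep the edit
                    # unchanged; otherwise rebuild it from the trimmed offsets.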
                    if b:
                        second_filter.append(edit)
                    else:
                        if new_start_1 == new_end_1:
                            new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2)
                        elif new_start_2 == new_end_2:
                            new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2)
                        else:
                            new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2)
                        second_filter.append(new_edit)
                else:
                    second_filter.append(edit)
        if verbose:
            print("========== Parallels ==========")
            print("".join(src_tokens))
            print("".join(tgt_tokens))
            print("========== Results ==========")
            for edit in second_filter:
                op = edit[0]
                s = " ".join(src_tokens[edit[1]: edit[2]])
                t = " ".join(tgt_tokens[edit[3]: edit[4]])
                print(f"{op}:\t{s}\t-->\t{t}")
            print("========== Infos ==========")
            print(str(src))
            print(str(tgt))
        return second_filter


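# Minimal smoke test. It assumes the project-local modules.tokenizer and
# modules.alignment packages (and the cilin/confusion resources they load)
# are available on the import path.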
if __name__ == "__main__":
    tokenizer = Tokenizer("char")
    semantic_dict, semantic_class = read_cilin()
    confusion_dict = read_confusion()
    alignment = Alignment(semantic_dict, confusion_dict)
    sents = [
        "所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace(
            " ", ""),
        "所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace(
            " ", "")]
    src, tgt = tokenizer(sents)
    align_obj = alignment(src, tgt)
    m = Merger()
    m(align_obj, src, tgt, verbose=True)