from itertools import groupby
from string import punctuation
from typing import List
from modules.tokenizer import Tokenizer
from modules.alignment import Alignment, read_cilin, read_confusion
import Levenshtein


class Merger:
    """
    Merge edit operations, converting them from token level to span level.
    """

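    # Note on granularity: with word-level granularity a run of consecutive
    # "S" (substitution) edits is kept as separate edits in _process_seq,
    # while with char-level granularity the run is merged into a single span.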
    def __init__(self,
                 granularity: str = "word",
                 merge: bool = False):
        chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧."
        self.punctuation = punctuation + chinese_punct
        self.not_merge_token = [punct for punct in self.punctuation]
        self.granularity = granularity
        self.merge = merge

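    # _merge_edits collapses a run of token-level edits into one span-level
    # edit. Edits are 5-tuples (tag, src_start, src_end, tgt_start, tgt_end),
    # where the start/end pairs index into the source and target token lists.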
    @staticmethod
    def _merge_edits(seq, tag="X"):
        if seq:
            return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
        else:
            return seq

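    # _check_revolve: for spans of equal length, span_b occurring in
    # span_a + span_a means span_b is a rotation of span_a, which indicates a
    # word-order (transposition) error rather than a substitution.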
    @staticmethod
    def _check_revolve(span_a, span_b):
        span_a = span_a + span_a
        return span_b in span_a

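    # _process_seq merges a run of token-level edits that are neither matches
    # nor transpositions. Tags follow ERRANT's convention: M = match,
    # S = substitution, D = deletion, I = insertion, T = transposition.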
    def _process_seq(self, seq, src_tokens, tgt_tokens):
        if len(seq) <= 1:
            return seq

        ops = [op[0] for op in seq]
        if set(ops) == {"D"} or set(ops) == {"I"}:
            return self._merge_edits(seq, set(ops).pop())

        if set(ops) == {"D", "I"}:
            # do not merge this pattern_from_qua.txt
            return seq

        if set(ops) == {"S"}:
            if self.granularity == "word":
                return seq
            else:
                return self._merge_edits(seq, "S")

        if set(ops) == {"M"}:
            return self._merge_edits(seq, "M")

        return self._merge_edits(seq, "S")

    def __call__(self,
                 align_obj,
                 src: List,
                 tgt: List,
                 verbose: bool = False):
        """
        Based on ERRANT's merge, adapted for Chinese
        """
        src_tokens = [x[0] for x in src]
        tgt_tokens = [x[0] for x in tgt]
        edits = []
        # Split alignment into groups of M, T and rest. (T has a number after it)
        # TODO: once an insertion, deletion or substitution involves punctuation, do not merge it with other edits
        # TODO: missing-component labels are likewise not merged with other edits
        for op, group in groupby(
            align_obj,
            lambda x: x[0][0] if x[0][0] in {"M", "T"} else False,
        ):
            group = list(group)
            # T is always split TODO: Evaluate this
            if op == "T":
                for seq in group:
                    edits.append(seq)
            # Process D, I and S subsequence
            else:
                # Turn the processed sequence into edits
                processed = self._process_seq(group, src_tokens, tgt_tokens)
                for seq in processed:
                    edits.append(seq)

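        # Scan for three-edit patterns ("S M S", "D M I", "I M D") that really
        # encode word-order errors and merge each of them into one transposition
        # edit, tagged "T" plus the number of source tokens it covers (e.g. "T3").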
        filtered_edits = []
        i = 0
        while i < len(edits):
            e1 = edits[i][0][0]

            if i < len(edits) - 2:
                e2 = edits[i + 1][0][0]
                e3 = edits[i + 2][0][0]

                # Find "S M S" patterns
                # Ex:
                #   S    M    S
                # 冬阴功  对  外国人
                # 外国人  对  冬阴功
                if e1 == "S" and e2 == "M" and e3 == "S":
                    w1 = "".join(src_tokens[edits[i][1]: edits[i][2]])
                    w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]])
                    w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]])
                    w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]])
                    if min([len(w1), len(w2), len(w3), len(w4)]) == 1:
                        if w1 == w4 and w2 == w3:
                            group = [edits[i], edits[i + 1], edits[i + 2]]
                            processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
                            for seq in processed:
                                filtered_edits.append(seq)
                            i += 3
                        else:
                            filtered_edits.append(edits[i])
                            i += 1
                    else:
                        if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1:
                            group = [edits[i], edits[i + 1], edits[i + 2]]
                            processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
                            for seq in processed:
                                filtered_edits.append(seq)
                            i += 3
                        else:
                            filtered_edits.append(edits[i])
                            i += 1
# Find "D M I" or "I M D" patterns
|
||
# Ex:
|
||
# D M I
|
||
# 旅游 去 陌生 的 地方
|
||
# 去 陌生 的 地方 旅游
|
||
elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"):
|
||
if e1 == "D":
|
||
delete_token = src_tokens[edits[i][1]: edits[i][2]]
|
||
insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]
|
||
else:
|
||
delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]]
|
||
insert_token = tgt_tokens[edits[i][3]: edits[i][4]]
|
||
a, b = "".join(delete_token), "".join(insert_token)
|
||
if len(a) < len(b):
|
||
a, b = b, a
|
||
if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1:
|
||
if len(b) == 1:
|
||
if a == b:
|
||
group = [edits[i], edits[i + 1], edits[i + 2]]
|
||
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
|
||
for seq in processed:
|
||
filtered_edits.append(seq)
|
||
i += 3
|
||
else:
|
||
filtered_edits.append(edits[i])
|
||
i += 1
|
||
else:
|
||
if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)):
|
||
group = [edits[i], edits[i + 1], edits[i + 2]]
|
||
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
|
||
for seq in processed:
|
||
filtered_edits.append(seq)
|
||
i += 3
|
||
else:
|
||
filtered_edits.append(edits[i])
|
||
i += 1
|
||
else:
|
||
filtered_edits.append(edits[i])
|
||
i += 1
|
||
else:
|
||
if e1 != "M":
|
||
filtered_edits.append(edits[i])
|
||
i += 1
|
||
else:
|
||
if e1 != "M":
|
||
filtered_edits.append(edits[i])
|
||
i += 1
|
||
        # In rare cases with word-level tokenization, the following error can occur:
        # M     D   S      M
        # 有    時  住     上層
        # 有        時住   上層
        # Which results in S: 時住 --> 時住
        # We need to filter this case out
        second_filter = []
        for edit in filtered_edits:  # avoid mismatches caused by word-segmentation errors
            span1 = "".join(src_tokens[edit[1]: edit[2]])
            span2 = "".join(tgt_tokens[edit[3]: edit[4]])

            if span1 != span2:
                if edit[0] == "S":
                    b = True
                    # In rare cases with word-level tokenization, the following error can occur:
                    # S      I     I    M
                    # 负责任             老师
                    # 负     责任   的   老师
                    # Which results in S: 负责任 --> 负 责任 的
                    # We need to convert this edit to I: --> 的

                    # overlap at the beginning of the spans
                    common_str = ""
                    tmp_new_start_1 = edit[1]
                    for i in range(edit[1], edit[2]):
                        if not span2.startswith(common_str + src_tokens[i]):
                            break
                        common_str += src_tokens[i]
                        tmp_new_start_1 = i + 1
                    new_start_1, new_start_2 = edit[1], edit[3]
                    if common_str:
                        tmp_str = ""
                        for i in range(edit[3], edit[4]):
                            tmp_str += tgt_tokens[i]
                            if tmp_str == common_str:
                                new_start_1, new_start_2 = tmp_new_start_1, i + 1
                                # second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4]))
                                b = False
                                break
                            elif len(tmp_str) > len(common_str):
                                break
                    # overlap at the end of the spans
                    common_str = ""
                    new_end_1, new_end_2 = edit[2], edit[4]
                    tmp_new_end_1 = edit[2]
                    for i in reversed(range(new_start_1, edit[2])):
                        if not span2.endswith(src_tokens[i] + common_str):
                            break
                        common_str = src_tokens[i] + common_str
                        tmp_new_end_1 = i
                    if common_str:
                        tmp_str = ""
                        for i in reversed(range(new_start_2, edit[4])):
                            tmp_str = tgt_tokens[i] + tmp_str
                            if tmp_str == common_str:
                                new_end_1, new_end_2 = tmp_new_end_1, i
                                b = False
                                break
                            elif len(tmp_str) > len(common_str):
                                break
                    if b:
                        second_filter.append(edit)
                    else:
                        if new_start_1 == new_end_1:
                            new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2)
                        elif new_start_2 == new_end_2:
                            new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2)
                        else:
                            new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2)
                        second_filter.append(new_edit)
                else:
                    second_filter.append(edit)
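        # Optionally print the parallel sentences and the resulting span-level edits.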
        if verbose:
            print("========== Parallels ==========")
            print("".join(src_tokens))
            print("".join(tgt_tokens))
            print("========== Results ==========")
            for edit in second_filter:
                op = edit[0]
                s = " ".join(src_tokens[edit[1]: edit[2]])
                t = " ".join(tgt_tokens[edit[3]: edit[4]])
                print(f"{op}:\t{s}\t-->\t{t}")
            print("========== Infos ==========")
            print(str(src))
            print(str(tgt))
        return second_filter


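# Minimal usage demo (assumes the cilin and confusion resources read by
# read_cilin/read_confusion are available locally): align a source/target
# sentence pair at the character level, then merge the token-level edits
# into span-level edits and print them.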
if __name__ == "__main__":
    tokenizer = Tokenizer("char")
    semantic_dict, semantic_class = read_cilin()
    confusion_dict = read_confusion()
    alignment = Alignment(semantic_dict, confusion_dict)
    sents = [
        "所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace(" ", ""),
        "所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace(" ", "")]
    src, tgt = tokenizer(sents)
    align_obj = alignment(src, tgt)
    m = Merger()
    m(align_obj, src, tgt, verbose=True)