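"""Convert a parallel correction file into edit annotations.

Each input line is tab-separated: an id, a source sentence, and one or more
target (corrected) sentences. The script runs in two passes: a first pass
(firsttime_process) tokenizes every unique sentence and records which lines
exceed a 10-second annotation timeout, and a second pass (main) replaces the
targets of those timed-out lines with the placeholder "无" before annotating
every line with Annotator and writing the results to the output file.
"""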

import argparse
import os
from collections import Counter, defaultdict
from multiprocessing import Pool

import timeout_decorator
import torch
from opencc import OpenCC
from tqdm import tqdm

from modules.annotator import Annotator
from modules.tokenizer import Tokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"
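
# Setting TOKENIZERS_PARALLELISM above silences the fork-parallelism warning
# emitted by the HuggingFace tokenizers library when worker processes are used.

# Shared module-level state: `annotator` and `sentence_to_tokenized` are
# populated in firsttime_process()/main(); `cc` converts Traditional Chinese
# to Simplified Chinese ("t2s").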
annotator, sentence_to_tokenized = None, None
cc = OpenCC("t2s")


@timeout_decorator.timeout(10)
def annotate_with_time_out(line):
    """Annotate one parallel line; the decorator aborts the call after 10 seconds.

    :param line: tab-separated line: id, source sentence, then one or more target sentences
    :return: concatenated edit annotations for all targets
    """
    sent_list = line.split("\t")[1:]
    source = sent_list[0]
    if args.segmented:
        source = source.strip()
    else:
        source = "".join(source.strip().split())
    output_str = ""
    for idx, target in enumerate(sent_list[1:]):
        try:
            if args.segmented:
                target = target.strip()
            else:
                target = "".join(target.strip().split())
            if not args.no_simplified:
                target = cc.convert(target)
            source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
            out, cors = annotator(source_tokenized, target_tokenized, idx)
            if idx == 0:
                output_str += "".join(out[:-1])
            else:
                output_str += "".join(out[1:-1])
        except Exception:
            # Propagate the error (including timeouts) to the caller.
            raise
    return output_str
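

# annotate() duplicates annotate_with_time_out() without the @timeout_decorator
# guard; it is used in the second pass (main), after timed-out lines have had
# their target replaced.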
def annotate(line):
    """Annotate one parallel line without a timeout guard.

    :param line: tab-separated line: id, source sentence, then one or more target sentences
    :return: concatenated edit annotations for all targets
    """
    sent_list = line.split("\t")[1:]
    source = sent_list[0]
    if args.segmented:
        source = source.strip()
    else:
        source = "".join(source.strip().split())
    output_str = ""
    for idx, target in enumerate(sent_list[1:]):
        try:
            if args.segmented:
                target = target.strip()
            else:
                target = "".join(target.strip().split())
            if not args.no_simplified:
                target = cc.convert(target)
            source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
            out, cors = annotator(source_tokenized, target_tokenized, idx)
            if idx == 0:
                output_str += "".join(out[:-1])
            else:
                output_str += "".join(out[1:-1])
        except Exception:
            # Surface any annotation error to the caller.
            raise
    return output_str
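

# First pass: tokenize every unique sentence in batches of --batch_size, then
# run the timeout-guarded annotator over all lines and collect the indices of
# lines whose annotation does not finish within 10 seconds.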
def firsttime_process(args):
    tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
    global annotator, sentence_to_tokenized
    annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
    lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n")  # format: id src tgt1 tgt2...
    # error_types = []

    with open(args.output, "w", encoding="utf-8") as f:
        count = 0
        sentence_set = set()
        sentence_to_tokenized = {}
        for line in lines:
            sent_list = line.split("\t")[1:]
            for idx, sent in enumerate(sent_list):
                if args.segmented:
                    # print(sent)
                    sent = sent.strip()
                else:
                    sent = "".join(sent.split()).strip()
                if idx >= 1:
                    if not args.no_simplified:
                        sentence_set.add(cc.convert(sent))
                    else:
                        sentence_set.add(sent)
                else:
                    sentence_set.add(sent)
        batch = []
        for sent in tqdm(sentence_set):
            count += 1
            if sent:
                batch.append(sent)
            if count % args.batch_size == 0:
                results = tokenizer(batch)
                for s, r in zip(batch, results):
                    sentence_to_tokenized[s] = r  # Get tokenization map.
                batch = []
        if batch:
            results = tokenizer(batch)
            for s, r in zip(batch, results):
                sentence_to_tokenized[s] = r  # Get tokenization map.

        timeout_indices = []

        # Single-process mode
        for idx, line in enumerate(tqdm(lines)):
            try:
                annotate_with_time_out(line)  # The result is discarded; only timeouts matter here.
            except Exception:
                timeout_indices.append(idx)
        return timeout_indices
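

# Second pass: re-tokenize, replace the target of every line that timed out in
# the first pass with the placeholder "无", annotate all lines, and write the
# results to --output.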
def main(args):
    timeout_indices = firsttime_process(args)
    tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
    global annotator, sentence_to_tokenized
    annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
    lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n")
    new_lines = []  # format: id src tgt1 tgt2...

    with open(args.output, "w", encoding="utf-8") as f:
        count = 0
        sentence_set = set()
        sentence_to_tokenized = {}
        for line_idx, line in enumerate(lines):
            if line_idx in timeout_indices:
                # print(f"line before split: {line}")
                line_split = line.split("\t")
                line_number, sent_list = line_split[0], line_split[1:]
                assert len(sent_list) == 2
                sent_list[-1] = " 无"  # Replace the target of a timed-out line with the placeholder "无" ("none").
                line = line_number + "\t" + "\t".join(sent_list)
                # print(f"line time out: {line}")
                new_lines.append(line)
            else:
                new_lines.append(line)

            sent_list = line.split("\t")[1:]
            for idx, sent in enumerate(sent_list):
                if args.segmented:
                    # print(sent)
                    sent = sent.strip()
                else:
                    sent = "".join(sent.split()).strip()
                if idx >= 1:
                    if not args.no_simplified:
                        sentence_set.add(cc.convert(sent))
                    else:
                        sentence_set.add(sent)
                else:
                    sentence_set.add(sent)
        batch = []
        for sent in tqdm(sentence_set):
            count += 1
            if sent:
                batch.append(sent)
            if count % args.batch_size == 0:
                results = tokenizer(batch)
                for s, r in zip(batch, results):
                    sentence_to_tokenized[s] = r  # Get tokenization map.
                batch = []
        if batch:
            results = tokenizer(batch)
            for s, r in zip(batch, results):
                sentence_to_tokenized[s] = r  # Get tokenization map.

        # Single-process mode
        lines = new_lines
        for idx, line in enumerate(tqdm(lines)):
            ret = annotate(line)
            f.write(ret)
            f.write("\n")

        # Multi-process mode: only tested on Linux; recommended on a Linux server.
        # with Pool(args.worker_num) as pool:
        #     for ret in pool.imap(annotate, tqdm(lines), chunksize=8):
        #         if ret:
        #             f.write(ret)
        #             f.write("\n")
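

# Example invocation (a sketch; the script and file names are placeholders, not
# taken from this repository):
#   python this_script.py -f input.para -o output.txt -g char -s all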
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Choose input file to annotate")
    parser.add_argument("-f", "--file", type=str, required=True, help="Input parallel file")
    parser.add_argument("-o", "--output", type=str, help="Output file", required=True)
    parser.add_argument("-b", "--batch_size", type=int, help="Batch size", default=128)
    parser.add_argument("-d", "--device", type=int, help="The ID of the GPU", default=0)
    parser.add_argument("-w", "--worker_num", type=int, help="The number of workers", default=16)
    parser.add_argument("-g", "--granularity", type=str, help="Choose char-level or word-level evaluation", default="char")
    parser.add_argument("-m", "--merge", help="Whether to merge continuous replacement/deletion/insertion edits", action="store_true")
    parser.add_argument("-s", "--multi_cheapest_strategy", type=str, choices=["first", "all"], default="all")
    parser.add_argument("--segmented", help="Whether the input is already tokenized", action="store_true")  # Supports pre-tokenized input, with tokens separated by spaces.
    parser.add_argument("--no_simplified", help="Do not convert corrections to Simplified Chinese", action="store_true")  # By default all corrections are converted to Simplified Chinese.
    parser.add_argument("--bpe", help="Whether to use BPE", action="store_true")  # Supports BPE segmentation of English words.
    args = parser.parse_args()
    main(args)