# Mirror of https://github.com/open-compass/opencompass.git
# Synced 2025-05-30 16:03:24 +08:00
import argparse
|
|
from collections import Counter
|
|
|
|
def main():
    """Score a hypothesis M2 file against a reference M2 file and print F-scores."""
    # Parse command line args
    args = parse_args()
    # Open hypothesis and reference m2 files and split into sentence chunks.
    # Use context managers so the file handles are closed deterministically,
    # and read each file exactly once.
    with open(args.hyp) as hyp_file:
        hyp_m2 = hyp_file.read().strip().split("\n\n")
    with open(args.ref) as ref_file:
        ref_m2 = ref_file.read().strip().split("\n\n")
    # Optionally evaluate only a [start:end] slice of the sentences.
    if args.start is not None and args.end is not None:
        hyp_m2 = hyp_m2[args.start:args.end]
        ref_m2 = ref_m2[args.start:args.end]
    # Make sure they have the same number of sentences. Put the counts in the
    # exception message (the original used `assert cond, print(...)`, which
    # raises with a message of None).
    assert len(hyp_m2) == len(ref_m2), \
        "hyp/ref sentence count mismatch: {} vs {}".format(len(hyp_m2), len(ref_m2))

    # Store global corpus level best counts here
    best_dict = Counter({"tp": 0, "fp": 0, "fn": 0})
    best_cats = {}
    # Process each hyp/ref sentence pair
    for sent_id, sent in enumerate(zip(hyp_m2, ref_m2)):
        # First line of the hyp block is the source sentence ("S ...").
        src = sent[0].split("\n")[0]
        # Simplify the edits into lists of lists
        hyp_edits = simplify_edits(sent[0], args.max_answer_num)
        ref_edits = simplify_edits(sent[1], args.max_answer_num)
        # Process the edits for detection/correction based on args
        hyp_dict = process_edits(hyp_edits, args)
        ref_dict = process_edits(ref_edits, args)
        # Only score sentences that have the expected number of reference coders
        # (or all sentences when -reference_num is unset).
        if args.reference_num is None or len(ref_dict.keys()) == args.reference_num:
            # Evaluate edits and get best TP, FP, FN hyp+ref combo.
            count_dict, cat_dict = evaluate_edits(
                src, hyp_dict, ref_dict, best_dict, sent_id, args)
            # Merge these dicts with best_dict and best_cats
            best_dict += Counter(count_dict)
            best_cats = merge_dict(best_cats, cat_dict)
    # Print results
    print_results(best_dict, best_cats, args)
# Parse command line args
def parse_args():
    """Build and parse the command line arguments for the scorer."""
    parser = argparse.ArgumentParser(
        description="Calculate F-scores for error detection and/or correction.\n"
                    "Flags let you evaluate at different levels of granularity.",
        formatter_class=argparse.RawTextHelpFormatter,
        usage="%(prog)s [options] -hyp HYP -ref REF")
    # Required input files.
    parser.add_argument("-hyp", required=True, help="A hypothesis M2 file.")
    parser.add_argument("-ref", required=True, help="A reference M2 file.")
    # Optional integer settings; all default to None (i.e. unset).
    for int_opt in ("--start", "--end", "--max_answer_num", "--reference_num"):
        parser.add_argument(int_opt, type=int, default=None)
    parser.add_argument("-b", "--beta",
                        type=float, default=0.5,
                        help="Value of beta in F-score. (default: 0.5)")
    parser.add_argument("-v", "--verbose",
                        action="store_true",
                        help="Print verbose output.")
    # At most one evaluation granularity may be chosen.
    eval_type = parser.add_mutually_exclusive_group()
    eval_type.add_argument("-dt", action="store_true",
                           help="Evaluate Detection in terms of Tokens.")
    eval_type.add_argument("-ds", action="store_true",
                           help="Evaluate Detection in terms of Spans.")
    eval_type.add_argument("-cs", action="store_true",
                           help="Evaluate Correction in terms of Spans. (default)")
    eval_type.add_argument("-cse", action="store_true",
                           help="Evaluate Correction in terms of Spans and Error types.")
    parser.add_argument("-single", action="store_true",
                        help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1")
    parser.add_argument("-multi", action="store_true",
                        help="Only evaluate multi token edits; i.e. 2+:n or n:2+")
    # For IAA calculation
    parser.add_argument("-multi_hyp_avg", action="store_true",
                        help="When get multiple hypotheses for a sentence, calculate their average F-scores for this sentence.")
    # For multiple hypotheses system evaluation
    parser.add_argument("-multi_hyp_max", action="store_true",
                        help="When get multiple hypotheses for a sentence, calculate their F-scores and select the max one for this sentence.")
    parser.add_argument("-filt", nargs="+", default=[],
                        help="Do not evaluate the specified error types.")
    parser.add_argument("-cat", type=int, choices=[1, 2, 3],
                        help="Show error category scores.\n"
                             "1: Only show operation tier scores; e.g. R.\n"
                             "2: Only show main tier scores; e.g. NOUN.\n"
                             "3: Show all category scores; e.g. R:NOUN.")
    return parser.parse_args()
|
|
# Input: An m2 format sentence with edits.
# Output: A list of lists. Each edit: [start, end, cat, cor, coder]
def simplify_edits(sent, max_answer_num):
    """Extract the edits from an M2 block as [start, end, cat, cor, coder] lists.

    If max_answer_num is an int, only edits from the first max_answer_num
    annotators (coder ids 0..max_answer_num-1) are kept; None keeps all.
    (Generalized from the original, which only handled 1/2/3 and silently
    returned None for any other value.)
    """
    out_edits = []
    # Loop over the lines of the m2 block; edit lines start with "A ".
    for edit in sent.split("\n"):
        if not edit.startswith("A "):
            continue
        fields = edit[2:].split("|||")  # Ignore "A " then split.
        span = fields[0].split()
        start = int(span[0])
        end = int(span[1])
        cat = fields[1]
        # Correction string with spaces removed (matches spans character-wise).
        cor = fields[2].replace(" ", "")
        coder = int(fields[-1])
        out_edits.append([start, end, cat, cor, coder])
    if max_answer_num is None:
        return out_edits
    # Keep only the first max_answer_num annotators; for 1/2/3 this is
    # identical to the original coder-id whitelists [0], [0,1], [0,1,2].
    return [edit for edit in out_edits if edit[-1] < max_answer_num]
|
# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
# Input 2: Command line args
# Output: A dict; key is coder, value is edit dict.
def process_edits(edits, args):
    """Group edits by coder and key them by span/correction per the eval mode.

    The value dict's key shape depends on args: (start, end) token pairs for
    -dt, (start, end) spans for -ds, (start, end, cat, cor) for -cse, and
    (start, end, cor) otherwise. Values are lists of categories.
    """
    coder_dict = {}
    # Add an explicit noop edit if there are no edits.
    if not edits:
        edits = [[-1, -1, "noop", "-NONE-", 0]]
    # Loop through the edits
    for start, end, cat, cor, coder in edits:
        # Register the coder first so it appears even if every one of its
        # edits is filtered out below (preserves the original behavior).
        if coder not in coder_dict:
            coder_dict[coder] = {}

        # Optionally apply filters based on args
        # 1. UNK type edits are only useful for detection, not correction.
        if not args.dt and not args.ds and cat == "UNK":
            continue
        # 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
        if args.single and (end - start >= 2 or len(cor.split()) >= 2):
            continue
        # 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
        if args.multi and end - start < 2 and len(cor.split()) < 2:
            continue
        # 4. If there is a filter, ignore the specified error types
        if args.filt and cat in args.filt:
            continue

        if args.dt:
            # Token Based Detection
            if start == -1:
                # Preserve noop edits.
                coder_dict[coder].setdefault((start, start), []).append(cat)
            elif start == end and start >= 0:
                # Insertions defined as affecting the token on the right
                coder_dict[coder].setdefault((start, start + 1), []).append(cat)
            else:
                # Edit spans are split for each token in the range.
                for tok_id in range(start, end):
                    coder_dict[coder].setdefault((tok_id, tok_id + 1), []).append(cat)
        elif args.ds:
            # Span Based Detection
            coder_dict[coder].setdefault((start, end), []).append(cat)
        else:
            # Span Based Correction, with (-cse) or without error type
            # classification in the key.
            key = (start, end, cat, cor) if args.cse else (start, end, cor)
            coder_dict[coder].setdefault(key, []).append(cat)
    return coder_dict
|
|
|
# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
# Input 4: Sentence ID (for verbose output only)
# Input 5: Command line args
# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
# Output 2: The corresponding error type dict for the above dict.
def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args):
    """Pick the hyp/ref annotator pair that maximizes the *global* F-score.

    NOTE: the winning combination is chosen by the cumulative corpus-level
    F-score (local counts added onto `best`), not the sentence-level score,
    so results depend on the running totals accumulated so far.
    """
    # Store the best sentence level scores and hyp+ref combination IDs
    # best_f is initialised as -1 cause 0 is a valid result.
    best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
    best_cat = {}
    # Skip non-annotatable sentences: a single reference coder whose single
    # edit carries the category "NA" means "no annotation"; contribute zeros.
    if len(ref_dict.keys()) == 1:
        ref_id = list(ref_dict.keys())[0]
        if len(ref_dict[ref_id].keys()) == 1:
            cat = list(ref_dict[ref_id].values())[0][0]
            if cat == "NA":
                best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
                return best_dict, best_cat

    # Compare each hyp and ref combination
    for hyp_id in hyp_dict.keys():
        for ref_id in ref_dict.keys():
            # Get the local counts for the current combination.
            tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
            # Compute the local sentence scores (for verbose output only)
            loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
            # Compute the global sentence scores
            p, r, f = computeFScore(
                tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
            # Save the scores if they are better in terms of:
            # 1. Higher F-score
            # 2. Same F-score, higher TP
            # 3. Same F-score and TP, lower FP
            # 4. Same F-score, TP and FP, lower FN
            if (f > best_f) or \
                    (f == best_f and tp > best_tp) or \
                    (f == best_f and tp == best_tp and fp < best_fp) or \
                    (f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
                best_tp, best_fp, best_fn = tp, fp, fn
                best_f, best_hyp, best_ref = f, hyp_id, ref_id
                best_cat = cat_dict
            # Verbose output
            if args.verbose:
                # Prepare verbose output edits.
                hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
                ref_verb = list(sorted(ref_dict[ref_id].keys()))
                # Ignore noop edits (their span starts at -1).
                if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
                if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
                # Print verbose info. src[1:] drops the leading "S" of the
                # source line so the sentence follows the id directly.
                print('{:-^40}'.format(""))
                print("SENTENCE "+str(sent_id)+src[1:])
                print('{:-^40}'.format(""))
                print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id))
                print("HYPOTHESIS EDITS :", hyp_verb)
                print("REFERENCE EDITS :", ref_verb)
                print("Local TP/FP/FN :", str(tp), str(fp), str(fn))
                print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f))
                print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"]))
                print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f))
    # Verbose output: display the best hyp+ref combination
    if args.verbose:
        print('{:-^40}'.format(""))
        print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id))
    # Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
    best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
    return best_dict, best_cat
|
|
|
# Input 1: A dictionary of hypothesis edits for a single system.
# Input 2: A dictionary of reference edits for a single annotator.
# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
# Output 4: A dictionary of the error type counts.
def compareEdits(hyp_edits, ref_edits):
    """Count TP/FP/FN between one hypothesis and one reference edit dict.

    Keys are edit spans (as built by process_edits); values are category
    lists. cat_dict maps each category to its [tp, fp, fn] counts.
    """
    tp = 0  # True Positives
    fp = 0  # False Positives
    fn = 0  # False Negatives
    cat_dict = {}  # {cat: [tp, fp, fn], ...}

    for h_edit, h_cats in hyp_edits.items():
        # noop hyp edits cannot be TP or FP
        if h_cats[0] == "noop":
            continue
        if h_edit in ref_edits.keys():
            # TRUE POSITIVES. On occasion, multiple tokens at same span;
            # use the ref categories for TP attribution.
            for h_cat in ref_edits[h_edit]:
                tp += 1
                cat_dict.setdefault(h_cat, [0, 0, 0])[0] += 1
        else:
            # FALSE POSITIVES. On occasion, multiple tokens at same span.
            for h_cat in h_cats:
                fp += 1
                cat_dict.setdefault(h_cat, [0, 0, 0])[1] += 1
    for r_edit, r_cats in ref_edits.items():
        # noop ref edits cannot be FN
        if r_cats[0] == "noop":
            continue
        if r_edit not in hyp_edits.keys():
            # FALSE NEGATIVES. On occasion, multiple tokens at same span.
            for r_cat in r_cats:
                fn += 1
                cat_dict.setdefault(r_cat, [0, 0, 0])[2] += 1
    return tp, fp, fn, cat_dict
|
|
|
# Input 1-3: True positives, false positives, false negatives
# Input 4: Value of beta in F-score.
# Output 1-3: Precision, Recall and F-score rounded to 4dp.
def computeFScore(tp, fp, fn, beta):
    """Return (precision, recall, F_beta), each rounded to 4 decimal places.

    Follows the M2 scorer convention: precision is 1.0 when there are no
    false positives, and recall is 1.0 when there are no false negatives.
    """
    precision = 1.0 if not fp else float(tp) / (tp + fp)
    recall = 1.0 if not fn else float(tp) / (tp + fn)
    b2 = beta ** 2
    if precision + recall:
        fscore = ((1 + b2) * precision * recall) / (b2 * precision + recall)
    else:
        fscore = 0.0
    return round(precision, 4), round(recall, 4), round(fscore, 4)
|
|
|
# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
# Output: The dictionaries combined with cumulative TP, FP, FN.
def merge_dict(dict1, dict2):
    """Fold dict2's per-category [tp, fp, fn] counts into dict1 and return it.

    dict1 is modified in place; categories new to dict1 take dict2's list.
    """
    for cat, stats in dict2.items():
        existing = dict1.get(cat)
        if existing is None:
            dict1[cat] = stats
        else:
            dict1[cat] = [a + b for a, b in zip(existing, stats)]
    return dict1
|
|
|
# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn]
# Input 2: Integer value denoting level of error category granularity.
# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything.
# Output: A dictionary of category TP, FP and FN based on Input 2.
def processCategories(cat_dict, setting):
    """Collapse error categories to the requested granularity.

    setting 1 keeps the operation tier (first char, e.g. "R" of "R:NOUN");
    setting 2 keeps the main tier (e.g. "NOUN"); setting 3 keeps everything.
    "UNK" is preserved as-is at every granularity.
    """
    # Setting 3 needs no merging. (The original returned cat_dict from
    # *inside* the loop on the first non-UNK category; hoisting the check
    # yields the same contents without the confusing early exit.)
    if setting == 3:
        return cat_dict
    proc_cat_dict = {}
    for cat, cnt in cat_dict.items():
        # M, U, R or UNK combined only.
        if cat == "UNK":
            proc_cat_dict[cat] = cnt
            continue
        # Setting 1: operation tier (cat[0]); setting 2: main tier (cat[2:],
        # i.e. everything after "X:").
        key = cat[0] if setting == 1 else cat[2:]
        if key in proc_cat_dict:
            proc_cat_dict[key] = [x + y for x, y in zip(proc_cat_dict[key], cnt)]
        else:
            proc_cat_dict[key] = cnt
    return proc_cat_dict
|
|
|
# Input 1: A dict of global best TP, FP and FNs
# Input 2: A dict of error types and counts for those TP, FP and FNs
# Input 3: Command line args
def print_results(best, best_cats, args):
    """Print the per-category (optional) and overall TP/FP/FN/P/R/F tables.

    NOTE: the exact layout (tab-separated overall row, '='-padded titles) is
    this script's output interface — downstream consumers parse it, so the
    format strings must not change.
    """
    # Prepare output title reflecting the chosen evaluation granularity.
    if args.dt: title = " Token-Based Detection "
    elif args.ds: title = " Span-Based Detection "
    elif args.cse: title = " Span-Based Correction + Classification "
    else: title = " Span-Based Correction "

    # Category Scores (only when -cat was given; value selects granularity).
    if args.cat:
        best_cats = processCategories(best_cats, args.cat)
        print("")
        print('{:=^66}'.format(title))
        print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
              "P".ljust(8), "R".ljust(8), "F"+str(args.beta))
        for cat, cnts in sorted(best_cats.items()):
            # Each cnts is [tp, fp, fn] for this category.
            cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
            print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
                  str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)

    # Print the overall results.
    print("")
    print('{:=^46}'.format(title))
    print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
    print("\t".join(map(str, [best["tp"], best["fp"],
          best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
    print('{:=^46}'.format(""))
    print("")
|
|
# Entry point guard: run the scorer only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    # Run the program
    main()