mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
55 lines
2.4 KiB
Python
55 lines
2.4 KiB
Python
import math
|
|
|
|
import re
|
|
|
|
#法律判决预测-刑期预测
|
|
def compute_ljp_imprison(data_dict):
|
|
import cn2an
|
|
|
|
score_list, abstentions = [], 0
|
|
|
|
for example in data_dict:
|
|
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
|
|
# get answer digit, which is the number between "刑期:" and "个月"
|
|
if "死刑" in answer or "无期" in answer:
|
|
# TODO: data imperfection
|
|
continue
|
|
|
|
assert answer.startswith("刑期:") and answer.endswith("个月"), f"answer: {answer}, question: {question}"
|
|
answer = answer.replace("刑期:", "")
|
|
answer = answer.replace("个月", "")
|
|
answer_digit = int(answer)
|
|
prediction = cn2an.transform(prediction, "cn2an")
|
|
|
|
# use regular expression to extract the digits from prediction, only consider digits before "个月" or "月"
|
|
prediction_digit_month_list = re.findall(r"\d+个月", prediction)
|
|
prediction_digit_month_list = [int(digit.replace("个月", "")) for digit in prediction_digit_month_list]
|
|
prediction_digit_month_list2 = re.findall(r"\d+月", prediction)
|
|
prediction_digit_month_list2 = [int(digit.replace("月", "")) for digit in prediction_digit_month_list2]
|
|
prediction_digit_month_list.extend(prediction_digit_month_list2)
|
|
# catches the digits before "年"
|
|
prediction_digit_year_list = re.findall(r"\d+年", prediction)
|
|
prediction_digit_year_list = [int(digit.replace("年", "")) for digit in prediction_digit_year_list]
|
|
|
|
if len(prediction_digit_month_list) > 0:
|
|
prediction_digit_month = int(prediction_digit_month_list[0])
|
|
elif len(prediction_digit_year_list) > 0:
|
|
prediction_digit_month = int(prediction_digit_year_list[0]) * 12
|
|
else:
|
|
abstentions += 1
|
|
prediction_digit_month = -1
|
|
|
|
if prediction_digit_month != -1:
|
|
score_list.append(abs(math.log(answer_digit + 1) - math.log(prediction_digit_month + 1)))
|
|
else:
|
|
score_list.append(math.log(216))
|
|
|
|
if abstentions == len(score_list):
|
|
log_distance = 0
|
|
else:
|
|
# compute the average of score_list (log distance)
|
|
log_distance = sum(score_list) / len(score_list)
|
|
# normalize the score to between 0 and 1
|
|
log_distance = (math.log(216) - log_distance)/math.log(216)
|
|
return {"score": log_distance, "abstention_rate": abstentions/len(data_dict)}
|