diff --git a/opencompass/utils/text_postprocessors.py b/opencompass/utils/text_postprocessors.py index 899610c3..b8aa8924 100644 --- a/opencompass/utils/text_postprocessors.py +++ b/opencompass/utils/text_postprocessors.py @@ -307,3 +307,41 @@ def extract_answer_before_evaluation(text: str): answer = '\n\n'.join(text_split[max(ans_start_idx - 1, 0):]) return answer + + +@TEXT_POSTPROCESSORS.register_module( + 'extract_qwq_answer_before_eval_for_humatchingfib') +def extract_answer_before_evaluation(text: str): + """The format of the answer from QwQ when inferring HuSimpleQA is \ + different with others models due to the special prompt.""" + max_sentence_len = 30 + if len(re.findall(r'\n\n', text)) > 2: + split_mark = '\n\n' + else: + split_mark = '\n' + text_split = text.split(split_mark) + last_try_idx = max(len(text_split) - max_sentence_len, 0) + ans_start_idx = last_try_idx + has_answer = False + answer_flags = [ + 'answer', 'Answer', 'Válasz', 'válasz', '答案', '回答', 'Summar', 'summar', + 'Summar', 'summar' + ] + + for idx, s in enumerate(reversed(text_split)): + sen_idx = len(text_split) - 1 - idx + if sen_idx < last_try_idx: + break + + for af in answer_flags: + if af in s: + has_answer = True + break + + if has_answer: + ans_start_idx = sen_idx + break + + answer = '\n\n'.join(text_split[max(ans_start_idx - 1, 0):]) + + return answer, has_answer