mirror of https://github.com/open-compass/opencompass.git, synced 2025-05-30 16:03:24 +08:00
add some features (#32)
* [Feature] Support answer extraction of QwQ when evaluating HuSimpleQA
* [Feature] Support multi-language summarization in HuSimpleQASummarizer
* [Feature] Support DeepSeek-R1-Distill-Qwen-32B_turbomind
This commit is contained in:
parent b6c8165ca3
commit 879b181c1b
@@ -3,7 +3,7 @@ from mmengine.config import read_base
 from opencompass.summarizers.subjective.husimpleqa import HuSimpleQASummarizer
 
 with read_base():
-    from opencompass.configs.datasets.OpenHuEval.HuSimpleQA.HuSimpleQA import HuSimpleQA_datasets
+    from opencompass.configs.datasets.OpenHuEval.HuSimpleQA.HuSimpleQA import HuSimpleQA_datasets, PROMPT_LANGUAGES
 
     from opencompass.configs.models.openai.gpt_4o_mini_20240718 import models as gpt_4o_mini_20240718_model
     from opencompass.configs.models.openai.gpt_4o_2024_11_20 import models as gpt_4o_2024_11_20_model
@@ -14,7 +14,7 @@ with read_base():
     from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_8b_instruct import models as lmdeploy_llama3_1_8b_instruct_model
     from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_70b_instruct import models as lmdeploy_llama3_1_70b_instruct_model
 
-    from opencompass.configs.models.hf_internlm.lmdeploy_internlm3_8b_instruct import models as lmdeploy_internlm3_8b_instruct_model
+    # from opencompass.configs.models.hf_internlm.lmdeploy_internlm3_8b_instruct import models as lmdeploy_internlm3_8b_instruct_model
 
     from opencompass.configs.models.qwq.lmdeploy_qwq_32b_preview import models as lmdeploy_qwq_32b_preview_model
     from opencompass.configs.models.deepseek.deepseek_r1_api_aliyun import models as deepseek_r1_api_aliyun_model
@@ -45,6 +45,12 @@ for model in models:
             'type': 'rm_<think>_before_eval'
         }
     }
+    if model['abbr'].startswith('QwQ'):
+        model['pred_postprocessor'] = {
+            'OpenHuEval_*': {
+                'type': 'extract_qwq_answer_before_eval_for_husimpleqa'
+            }
+        }
 del model
 
 
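For context, the 'OpenHuEval_*' key is a wildcard over dataset abbreviations, and the new branch only attaches the QwQ-specific extractor to models whose abbr starts with 'QwQ'. Below is a minimal, self-contained sketch of that selection logic; the dataset abbrs, the model abbr, and the resolve_postprocessor helper are made up for illustration and are not OpenCompass internals.

from fnmatch import fnmatch

# Hypothetical model config and dataset abbrs, for illustration only.
model = {
    'abbr': 'QwQ-32B-Preview_turbomind',
    'pred_postprocessor': {
        'OpenHuEval_*': {'type': 'extract_qwq_answer_before_eval_for_husimpleqa'},
    },
}
dataset_abbrs = ['OpenHuEval_HuSimpleQA_en', 'OpenHuEval_HuSimpleQA_hu', 'gsm8k']

def resolve_postprocessor(model_cfg, dataset_abbr):
    # Return the postprocessor name whose wildcard pattern matches the dataset abbr.
    for pattern, proc_cfg in model_cfg.get('pred_postprocessor', {}).items():
        if fnmatch(dataset_abbr, pattern):
            return proc_cfg['type']
    return None

for abbr in dataset_abbrs:
    print(abbr, '->', resolve_postprocessor(model, abbr))
# The OpenHuEval_* datasets resolve to the QwQ extractor; 'gsm8k' resolves to None.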
@@ -92,7 +98,8 @@ eval = dict(
         task=dict(type=SubjectiveEvalTask)),
 )
 
-summarizer = dict(type=HuSimpleQASummarizer)
+summarizer = dict(type=HuSimpleQASummarizer,
+                  prompt_languages=PROMPT_LANGUAGES)
 
 work_dir = (
     './outputs/' + __file__.split('/')[-1].split('.')[0] + '/'
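The prompt_languages key added to the summarizer dict ends up as a keyword argument of HuSimpleQASummarizer.__init__ (see the signature change further down). The snippet below is a simplified sketch of that 'type plus kwargs' config convention, using a dummy class and a toy builder rather than OpenCompass's own machinery.

# Simplified illustration of the 'type' + kwargs config pattern; DummySummarizer
# and build_from_cfg are stand-ins, not the real OpenCompass builder.
def build_from_cfg(cfg: dict):
    cfg = dict(cfg)        # copy so the original config stays untouched
    cls = cfg.pop('type')  # the class to instantiate
    return cls(**cfg)      # remaining keys become constructor kwargs

class DummySummarizer:
    def __init__(self, prompt_languages=None):
        self.prompt_languages = prompt_languages

summarizer = build_from_cfg(dict(type=DummySummarizer, prompt_languages=['en', 'hu']))
print(summarizer.prompt_languages)  # ['en', 'hu']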
@@ -0,0 +1,15 @@
+from opencompass.models import TurboMindModelwithChatTemplate
+
+models = [
+    dict(
+        type=TurboMindModelwithChatTemplate,
+        abbr='deepseek_r1_distill_qwen_32b_turbomind',
+        path='deepseek-ai/DeepSeek-R1-Distill-Qwen-32B',
+        engine_config=dict(session_len=16384, max_batch_size=16, tp=2),
+        gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=4096),
+        max_seq_len=16384,
+        max_out_len=4096,
+        batch_size=16,
+        run_cfg=dict(num_gpus=2),
+    )
+]
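To pick up this new TurboMind config from an eval script, it would be imported through read_base like the other model configs above. The module path in the sketch below is a guess, since the new file's location is not visible in this view.

from mmengine.config import read_base

with read_base():
    # Assumed path of the new config module; adjust to its actual location in the repo.
    from opencompass.configs.models.deepseek.lmdeploy_deepseek_r1_distill_qwen_32b import \
        models as lmdeploy_deepseek_r1_distill_qwen_32b_model

models = [*lmdeploy_deepseek_r1_distill_qwen_32b_model]

In the config itself, top_k=1 with a near-zero temperature makes decoding effectively greedy, while tp=2 and num_gpus=2 shard the 32B model across two GPUs; session_len=16384 leaves room for long reasoning traces plus a 4096-token answer.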
@@ -60,6 +60,7 @@ def get_capability_results(
         writer.writerow(col_name)
         writer.writerow(column)
 
 
 class HuSimpleQASummarizer:
     """Do the subjectivity analyze based on evaluation results.
@@ -67,10 +68,11 @@ class HuSimpleQASummarizer:
         config (ConfigDict): The configuration object of the evaluation task.
     """
 
-    def __init__(self, config: ConfigDict) -> None:
+    def __init__(self, config: ConfigDict, prompt_languages) -> None:
        self.judge_type = 'single'
         self.tasks = []
         self.cfg = config
+        self.prompt_languages = prompt_languages
 
         self.eval_model_cfgs = self.cfg['eval']['partitioner']['models']
         self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_models'][0])
@@ -85,30 +87,32 @@ class HuSimpleQASummarizer:
         Returns:
             pd.DataFrame: The summary results.
         """
-        dataset_cfgs = self.cfg['datasets']
-        output_dir, results_folder = get_outdir(self.cfg, time_str)
-        fout_flag = 0
-        for eval_model_cfg in self.eval_model_cfgs:
-            eval_model_abbr = model_abbr_from_cfg(eval_model_cfg)
-            show_model_abbr = model_abbr_from_cfg_used_in_summarizer(eval_model_cfg)
-            subdir_path = os.path.join(results_folder, eval_model_abbr + '_judged-by--' + self.judge_abbr)
-            if os.path.isdir(subdir_path):
-                fout = osp.join(output_dir, 'judged-by--' + self.judge_abbr + '-capability.csv')
-                overall_judged_answers, overall_references = [], []
-                for dataset in dataset_cfgs:
-                    judged_answers, references = get_judgeanswer_and_reference(dataset, subdir_path, self.judge_function)
-                    judged_answers = [item['judge'] for item in judged_answers]
-                    overall_judged_answers += judged_answers
-                    overall_references += references
-
-                get_capability_results(
-                    overall_judged_answers,
-                    overall_references,
-                    fout,
-                    fout_flag,
-                    show_model_abbr,
-                )
-                fout_flag += 1
-            else:
-                print(subdir_path + ' is not exist! please check!')
+        for language in self.prompt_languages:
+            dataset_cfgs = self.cfg['datasets']
+            output_dir, results_folder = get_outdir(self.cfg, time_str)
+            fout_flag = 0
+            for eval_model_cfg in self.eval_model_cfgs:
+                eval_model_abbr = model_abbr_from_cfg(eval_model_cfg)
+                show_model_abbr = model_abbr_from_cfg_used_in_summarizer(eval_model_cfg)
+                subdir_path = os.path.join(results_folder, eval_model_abbr + '_judged-by--' + self.judge_abbr)
+                if os.path.isdir(subdir_path):
+                    fout = osp.join(output_dir, 'judged-by--' + self.judge_abbr + '-capability' + '_' + language + '.csv')
+                    overall_judged_answers, overall_references = [], []
+                    for dataset in dataset_cfgs:
+                        if not dataset['abbr'].endswith('_' + language):
+                            continue
+                        judged_answers, references = get_judgeanswer_and_reference(dataset, subdir_path, self.judge_function)
+                        judged_answers = [item['judge'] for item in judged_answers]
+                        overall_judged_answers += judged_answers
+                        overall_references += references
+
+                    get_capability_results(
+                        overall_judged_answers,
+                        overall_references,
+                        fout,
+                        fout_flag,
+                        show_model_abbr,
+                    )
+                    fout_flag += 1
+                else:
+                    print(subdir_path + ' is not exist! please check!')
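The rewritten summarize loop assumes each dataset abbr carries a language suffix and writes one capability CSV per language. Below is a small self-contained sketch of that selection and of the file names it yields; the abbrs, languages, judge name, and output directory are placeholders, and in a real run PROMPT_LANGUAGES comes from the dataset config.

import os.path as osp

# Placeholder values for illustration only.
prompt_languages = ['en', 'hu']
dataset_abbrs = ['HuSimpleQA_en', 'HuSimpleQA_hu']
judge_abbr = 'gpt_4o_2024_11_20'
output_dir = './outputs/demo'

for language in prompt_languages:
    # Only datasets whose abbr ends with the current language suffix are summarized.
    selected = [abbr for abbr in dataset_abbrs if abbr.endswith('_' + language)]
    fout = osp.join(output_dir,
                    'judged-by--' + judge_abbr + '-capability' + '_' + language + '.csv')
    print(language, selected, '->', fout)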
@@ -248,10 +248,11 @@ def extract_answer_before_evaluation(text: str):
     """Overall, there are three situations in responses of QWQ:
 
     1. There is a **Final Answer** title in the whole context.
     2. There is only one sentence in the context.
     3. There are more than one sentences in the context, \
        and the last one is the answer.
     """
+    text = text.strip('\n')
     if '**Final Answer**' in text:
         answer = text.split('\n\n**Final Answer**\n\n')[-1]
     else:
@@ -264,3 +265,45 @@ def extract_answer_before_evaluation(text: str):
         else:
             answer = text_split[-1] + '.'
     return answer
+
+
+@TEXT_POSTPROCESSORS.register_module(
+    'extract_qwq_answer_before_eval_for_husimpleqa')
+def extract_answer_before_evaluation(text: str):
+    """The format of the answer from QwQ when inferring HuSimpleQA is \
+    different with others models due to the special prompt."""
+    max_sentence_len = 6
+    text_split = text.split('\n\n')
+    last_try_idx = max(len(text_split) - max_sentence_len, 0)
+    ans_start_idx = last_try_idx
+    has_answer = False
+    has_score = False
+    score_flags = [
+        'score', 'Score', 'confidence', 'Confidence', 'szcore', 'Szcore',
+        'pontszám', 'Pontszám', 'Biztonság', 'biztonság', 'Biztoságskor',
+        'biztoságskor', 'Biztoság', '信心', '分数'
+    ]
+    answer_flags = ['answer', 'Answer', 'Válasz', 'válasz', '答案', '回答']
+
+    for idx, s in enumerate(reversed(text_split)):
+        sen_idx = len(text_split) - 1 - idx
+        if sen_idx < last_try_idx:
+            break
+
+        for sf in score_flags:
+            if sf in s:
+                has_score = True
+                break
+
+        for af in answer_flags:
+            if af in s:
+                has_answer = True
+                break
+
+        if has_answer and has_score:
+            ans_start_idx = sen_idx
+            break
+
+    answer = '\n\n'.join(text_split[max(ans_start_idx - 1, 0):])
+
+    return answer
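As a rough illustration of what the new QwQ extractor keeps, here is a trimmed-down re-sketch of its tail scan applied to a made-up HuSimpleQA-style response; the real function uses the full multilingual flag lists shown in the diff and is resolved through the TEXT_POSTPROCESSORS registry at eval time, so everything below is illustrative only.

# Made-up QwQ-style response for illustration.
qwq_response = (
    'I need to recall when the Chain Bridge in Budapest opened.\n\n'
    'Let me think about the history first.\n\n'
    'Answer: The Chain Bridge opened in 1849.\n\n'
    'Confidence score: 90'
)

# Trimmed-down version of the scan: walk the last few paragraphs backwards and,
# once both an answer flag and a score flag have been seen, keep everything from
# one paragraph before that point onward.
answer_flags = ('Answer', 'answer', 'Válasz', 'válasz')
score_flags = ('score', 'Score', 'Confidence', 'pontszám')

paragraphs = qwq_response.split('\n\n')
last_try_idx = max(len(paragraphs) - 6, 0)
ans_start_idx = last_try_idx
has_answer = has_score = False
for sen_idx in range(len(paragraphs) - 1, last_try_idx - 1, -1):
    s = paragraphs[sen_idx]
    has_score = has_score or any(flag in s for flag in score_flags)
    has_answer = has_answer or any(flag in s for flag in answer_flags)
    if has_answer and has_score:
        ans_start_idx = sen_idx
        break

print('\n\n'.join(paragraphs[max(ans_start_idx - 1, 0):]))
# Keeps the last three paragraphs: the lead-in sentence, the answer, and the confidence score,
# so the judge model sees the final answer and its confidence rather than the full reasoning trace.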