huihui 2025-04-30 05:29:04 +00:00
parent dfa26b24bd
commit 44aadf627b
2 changed files with 4 additions and 125 deletions


@@ -47,7 +47,7 @@ eval_cfg = dict(
 sciknoweval_dataset_biology = dict(
     type=SciKnowEvalDataset,
     abbr='sciknoweval_biology',
-    path='/root/workspace/opencompass/opencompass/configs/datasets/SciKnowEval/SciKnowEval/data/Biology',
+    path='hicai-zju/SciKnowEval',
     prompt_mode='zero-shot',
     reader_cfg=reader_cfg,
     infer_cfg=infer_cfg,
@@ -57,7 +57,7 @@ sciknoweval_dataset_biology = dict(
 sciknoweval_dataset_chemistry = dict(
     type=SciKnowEvalDataset,
     abbr='sciknoweval_chemistry',
-    path='/root/workspace/opencompass/opencompass/configs/datasets/SciKnowEval/SciKnowEval/data/Chemistry',
+    path='hicai-zju/SciKnowEval',
     prompt_mode='zero-shot',
     reader_cfg=reader_cfg,
     infer_cfg=infer_cfg,
@@ -66,7 +66,7 @@ sciknoweval_dataset_chemistry = dict(
 sciknoweval_dataset_material = dict(
     type=SciKnowEvalDataset,
     abbr='sciknoweval_material',
-    path='/root/workspace/opencompass/opencompass/configs/datasets/SciKnowEval/SciKnowEval/data/Material',
+    path='hicai-zju/SciKnowEval',
     prompt_mode='zero-shot',
     reader_cfg=reader_cfg,
     infer_cfg=infer_cfg,
@@ -76,7 +76,7 @@ sciknoweval_dataset_material = dict(
 sciknoweval_dataset_physics = dict(
     type=SciKnowEvalDataset,
     abbr='sciknoweval_physics',
-    path='/root/workspace/opencompass/opencompass/configs/datasets/SciKnowEval/SciKnowEval/data/Physics',
+    path='hicai-zju/SciKnowEval',
     prompt_mode='zero-shot',
     reader_cfg=reader_cfg,
     infer_cfg=infer_cfg,

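Together, these four hunks replace a hardcoded local checkout path with the dataset's Hugging Face Hub ID, so each subject config points at hicai-zju/SciKnowEval rather than a path that exists only on one machine. As a minimal sketch of what resolving that ID looks like (the name and split values below are assumptions for illustration, not the actual SciKnowEvalDataset loading logic):

from datasets import load_dataset

# Hypothetical example: fetch one subject subset from the Hub.
# 'Biology' and 'test' are assumed identifiers; the real slicing of the
# repo into the four subject datasets is done inside SciKnowEvalDataset.
ds = load_dataset('hicai-zju/SciKnowEval', name='Biology', split='test')
print(len(ds), ds[0])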

@@ -108,124 +108,3 @@ def answer_cleansing(
     return prediction[0]
-
-
-def _generic_llmjudge_postprocess(judgement: str):
-    match = re.search(r'(A|B)', judgement)
-    grade_letter = (match.group(0) if match else 'B'
-                    )  # Default to "INCORRECT" if no match
-    return grade_letter
-
-
-def SciKnowEval_llmjudge_postprocess(
-    output: dict,
-    output_path: str,
-    dataset: Dataset,
-) -> dict:
-    # Get the original dataset
-    original_dataset = dataset.reader.dataset['test']
-
-    judged_answers = []
-    original_responses = []
-    references = []
-    details = []
-
-    # Initialize statistics dictionaries
-    stats = {'medical_task': {}, 'body_system': {}, 'question_type': {}}
-
-    total_correct = 0
-    total_count = 0
-
-    # Process each sample
-    for k, v in output.items():
-        idx = int(k)  # Convert key to integer for indexing
-        original_responses.append(v['prediction'])
-        processed_judge = _generic_llmjudge_postprocess(v['prediction'])
-
-        # Get category information from the dataset
-        sample = original_dataset[idx]
-        medical_task = sample.get('medical_task', 'unknown')
-        body_system = sample.get('body_system', 'unknown')
-        question_type = sample.get('question_type', 'unknown')
-
-        # Initialize category stats if not exists
-        for level, key in [
-            ('medical_task', medical_task),
-            ('body_system', body_system),
-            ('question_type', question_type),
-        ]:
-            if key not in stats[level]:
-                stats[level][key] = {'correct': 0, 'total': 0}
-
-        # Record the judgment
-        if processed_judge is not None:
-            judged_answers.append(processed_judge)
-            try:
-                gold = v['gold']
-                references.append(gold)
-            except KeyError:
-                get_logger().warning(
-                    f'No gold answer for {k}, use empty string as reference!')
-                gold = ''
-                references.append('')
-
-            # Check if the answer is correct (A means correct)
-            is_correct = processed_judge == 'A'
-            total_count += 1
-
-            if is_correct:
-                total_correct += 1
-                # Update category stats
-                for level, key in [
-                    ('medical_task', medical_task),
-                    ('body_system', body_system),
-                    ('question_type', question_type),
-                ]:
-                    stats[level][key]['correct'] += 1
-
-            # Update category totals
-            for level, key in [
-                ('medical_task', medical_task),
-                ('body_system', body_system),
-                ('question_type', question_type),
-            ]:
-                stats[level][key]['total'] += 1
-
-        # Add to details
-        details.append({
-            'id': k,
-            'question': sample['question'],
-            'options': sample['options'],
-            'origin_prompt': v['origin_prompt'],
-            'llm_judge': processed_judge,
-            'gold': gold,
-            'is_correct': is_correct,
-            'medical_task': medical_task,
-            'body_system': body_system,
-            'question_type': question_type,
-        })
-
-    # Calculate overall accuracy with two decimal places
-    overall_accuracy = (round(
-        (total_correct / total_count * 100), 2) if total_count > 0 else 0.00)
-
-    # Initialize results dictionary
-    results = {
-        'accuracy': overall_accuracy,
-        'total_correct': total_correct,
-        'total_count': total_count,
-        'details': details,
-    }
-
-    # Calculate accuracy for each category and flatten into results
-    for level in stats:
-        for key, value in stats[level].items():
-            if value['total'] > 0:
-                # Calculate accuracy with two decimal places
-                accuracy = round((value['correct'] / value['total'] * 100), 2)
-                # Create a flattened key for the category
-                flat_key = f'MedXpertQA-{key}'
-                # Add to results
-                results[flat_key] = accuracy
-
-    return results
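
The removed SciKnowEval_llmjudge_postprocess appears to have been carried over from a MedXpertQA postprocessor: it groups results by medical_task, body_system, and question_type and prefixes per-category keys with MedXpertQA-, none of which applies to SciKnowEval's subject splits. The convention it implemented is simple: the judge model replies 'A' for correct or 'B' for incorrect, and accuracy is the share of 'A' grades. A minimal self-contained sketch of that convention (function names here are illustrative, not OpenCompass APIs):

import re

def grade(judgement: str) -> str:
    # Judge output is expected to contain 'A' (correct) or 'B' (incorrect);
    # unparseable output defaults to 'B', matching the removed helper.
    match = re.search(r'(A|B)', judgement)
    return match.group(0) if match else 'B'

def overall_accuracy(judgements: list[str]) -> float:
    # Share of 'A' grades as a percentage, rounded to two decimals
    # as in the removed code.
    grades = [grade(j) for j in judgements]
    return round(100 * grades.count('A') / len(grades), 2) if grades else 0.0

# e.g. overall_accuracy(['A', 'The answer is B.', 'A']) -> 66.67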