open-compass/opencompass (https://github.com/open-compass/opencompass.git)

commit 1acb3c30c0 (parent 23fb3c7fa9)

    update
@@ -1,7 +1,8 @@
-from opencompass.datasets import NejmaibenchDataset, nejmaibench_llmjudge_postprocess
+from opencompass.datasets import NejmaibenchDataset
 from opencompass.openicl.icl_inferencer import GenInferencer
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.datasets import generic_llmjudge_postprocess
 from opencompass.evaluator import GenericLLMEvaluator
 import os
 
@@ -88,7 +89,7 @@ eval_cfg = dict(
             reader_cfg=reader_cfg,
         ),
         judge_cfg=dict(),
-        dict_postprocessor=dict(type=nejmaibench_llmjudge_postprocess),
+        dict_postprocessor=dict(type=generic_llmjudge_postprocess),
     ),
 )
 
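Note: the two hunks above are the whole config-side change: the dataset-specific nejmaibench_llmjudge_postprocess import is dropped and the shared generic_llmjudge_postprocess is wired into dict_postprocessor instead. For orientation, the resulting LLM-judge eval config looks roughly like the sketch below; only reader_cfg, judge_cfg and dict_postprocessor are visible in the diff, so the surrounding fields (prompt_template, dataset_cfg, the dataset path, and the reader columns) are assumptions based on the usual GenericLLMEvaluator layout, not a copy of the real file.

    # Hedged sketch of the eval config after this change; fields not shown in
    # the diff (prompt_template, dataset_cfg, path, reader columns) are assumed.
    from opencompass.datasets import NejmaibenchDataset, generic_llmjudge_postprocess
    from opencompass.evaluator import GenericLLMEvaluator
    from opencompass.openicl.icl_prompt_template import PromptTemplate

    reader_cfg = dict(input_columns=['question', 'options'], output_column='answer')  # assumed columns

    eval_cfg = dict(
        evaluator=dict(
            type=GenericLLMEvaluator,
            prompt_template=dict(type=PromptTemplate, template=dict(round=[])),  # judge prompt omitted
            dataset_cfg=dict(
                type=NejmaibenchDataset,
                path='opencompass/nejmaibench',  # assumed dataset key
                reader_cfg=reader_cfg,
            ),
            judge_cfg=dict(),
            # was: dict_postprocessor=dict(type=nejmaibench_llmjudge_postprocess)
            dict_postprocessor=dict(type=generic_llmjudge_postprocess),
        ),
    )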
@@ -5,7 +5,7 @@ from datasets import Dataset
 
 from opencompass.openicl import BaseEvaluator
 from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
-from opencompass.utils import get_data_path, get_logger
+from opencompass.utils import get_data_path
 
 from .base import BaseDataset
 
@@ -137,116 +137,3 @@ def answer_cleansing(
         prediction[0] = prediction[0][:-1]
 
     return prediction[0]
-
-
-def _generic_llmjudge_postprocess(judgement: str):
-    match = re.search(r'(A|B)', judgement)
-    grade_letter = (match.group(0) if match else 'B'
-                    )  # Default to "INCORRECT" if no match
-    return grade_letter
-
-
-def nejmaibench_llmjudge_postprocess(
-    output: dict,
-    output_path: str,
-    dataset: Dataset,
-) -> dict:
-    # Get the original dataset
-    original_dataset = dataset.reader.dataset['test']
-
-    judged_answers = []
-    original_responses = []
-    references = []
-    details = []
-
-    # Initialize statistics dictionaries
-    stats = {'Subject': {}}
-
-    total_correct = 0
-    total_count = 0
-
-    # Process each sample
-    for k, v in output.items():
-        idx = int(k)  # Convert key to integer for indexing
-        original_responses.append(v['prediction'])
-
-        processed_judge = _generic_llmjudge_postprocess(v['prediction'])
-
-        # Get category information from the dataset
-        sample = original_dataset[idx]
-        subject = sample.get('Subject', 'unknown')
-
-        # Initialize category stats if not exists
-        for level, key in [
-            ('Subject', subject),
-        ]:
-            if key not in stats[level]:
-                stats[level][key] = {'correct': 0, 'total': 0}
-
-        # Record the judgment
-        if processed_judge is not None:
-            judged_answers.append(processed_judge)
-            try:
-                gold = v['gold']
-                references.append(gold)
-            except KeyError:
-                get_logger().warning(
-                    f'No gold answer for {k}, use empty string as reference!')
-                gold = ''
-                references.append('')
-
-            # Check if the answer is correct (A means correct)
-            is_correct = processed_judge == 'A'
-            total_count += 1
-
-            if is_correct:
-                total_correct += 1
-                # Update category stats
-                for level, key in [
-                    ('Subject', subject),
-                ]:
-                    stats[level][key]['correct'] += 1
-
-            # Update category totals
-            for level, key in [
-                ('Subject', subject),
-            ]:
-                stats[level][key]['total'] += 1
-        # Add to details
-        details.append({
-            'id': k,
-            'question': sample['question'],
-            'options': sample['options'],
-            'origin_prompt': v['origin_prompt'],
-            'llm_judge': processed_judge,
-            'gold': gold,
-            'is_correct': is_correct,
-            'Subject': subject,
-        })
-
-    # Calculate overall accuracy with two decimal places
-    overall_accuracy = (round(
-        (total_correct / total_count * 100), 2) if total_count > 0 else 0.00)
-
-    # Initialize results dictionary
-    results = {
-        'accuracy': overall_accuracy,
-        'total_correct': total_correct,
-        'total_count': total_count,
-        'details': details,
-    }
-
-    # Calculate accuracy for each category and flatten into results
-    for level in stats:
-        for key, value in stats[level].items():
-            if value['total'] > 0:
-                # Calculate accuracy with two decimal places
-                accuracy = round((value['correct'] / value['total'] * 100), 2)
-
-                # Create a flattened key for the category
-                flat_key = f'nejmaibench-{key}'
-
-                # Add to results
-                results[flat_key] = accuracy
-
-    return results
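Note: everything removed above duplicated, per dataset, the shared LLM-judge grading convention: the judge replies "A" (correct) or "B" (incorrect), an unparseable reply defaults to "B", and accuracy is the share of "A" verdicts. A minimal standalone illustration of that convention is sketched below; it is not the library's generic_llmjudge_postprocess, which now handles this step. The per-Subject breakdown the removed function emitted (the flattened nejmaibench-<Subject> keys) presumably disappears with this change, since the shared postprocessor is dataset-agnostic.

    import re

    # Illustration only: the same A/B grading idea the removed helper used,
    # NOT the implementation of opencompass.datasets.generic_llmjudge_postprocess.
    def grade_ab(judgement: str) -> str:
        match = re.search(r'(A|B)', judgement)
        return match.group(0) if match else 'B'  # default to incorrect

    def accuracy_from_judgements(output: dict) -> dict:
        # output maps sample id -> {'prediction': judge reply, ...}
        graded = [grade_ab(v['prediction']) for v in output.values()]
        correct = sum(g == 'A' for g in graded)
        total = len(graded)
        return {'accuracy': round(correct / total * 100, 2) if total else 0.0}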
@@ -448,8 +448,8 @@ DATASETS_MAPPING = {
     },
     "opencompass/nejmaibench": {
        "ms_id": "",
-       "hf_id": "SeanWu25/NEJM-AI_Benchmarking_Medical_Language_Models",
-       "local": "./opencompass/configs/datasets/nejm_ai_benchmark/data/NEJM_All_Questions_And_Answers.csv",
+       "hf_id": "",
+       "local": "./data/nejmaibench/NEJM_All_Questions_And_Answers.csv",
     },
 
 }
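Note: with this mapping change the benchmark is expected under ./data/nejmaibench/ instead of inside the configs tree, and the Hugging Face id is cleared. A quick sanity check of the relocated CSV is sketched below; the column names ('question', 'options', 'Subject') are inferred from the removed postprocessor and may not match the raw CSV headers exactly.

    import csv

    # Hedged sanity check: path from the new DATASETS_MAPPING entry,
    # expected columns inferred from the removed postprocessor.
    path = './data/nejmaibench/NEJM_All_Questions_And_Answers.csv'
    with open(path, newline='', encoding='utf-8') as f:
        rows = list(csv.DictReader(f))
    print(len(rows), 'questions; columns:', list(rows[0].keys()) if rows else [])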
@@ -803,6 +803,11 @@ DATASETS_URL = {
         "url":
         "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/ChemBench4K.zip",
         "md5": "fc23fd21b2566a5dbbebfa4601d7779c"
+    },
+    "nejmaibench": {
+        "url":
+        "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/nejmaibench.zip",
+        "md5": "e6082cae3596b3ebea73e23ba445b99e"
     }
 
 }
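Note: the new DATASETS_URL entry lets the data be fetched from the OpenCompass OSS bucket and verified by md5 when no local copy is present. The sketch below performs the equivalent download-and-verify by hand with the standard library; OpenCompass itself resolves this through get_data_path and its own download utilities, so treat it only as an illustration of what the url/md5 pair is for.

    import hashlib
    import urllib.request

    URL = 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/nejmaibench.zip'
    EXPECTED_MD5 = 'e6082cae3596b3ebea73e23ba445b99e'

    def fetch_and_verify(url: str, expected_md5: str, dest: str = 'nejmaibench.zip') -> str:
        # Download the archive and compare its md5 against the catalog entry.
        urllib.request.urlretrieve(url, dest)
        md5 = hashlib.md5(open(dest, 'rb').read()).hexdigest()
        if md5 != expected_md5:
            raise ValueError(f'md5 mismatch: got {md5}, expected {expected_md5}')
        return dest

    # fetch_and_verify(URL, EXPECTED_MD5)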