mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
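
# Dataset loader, evaluator, and LLM-judge post-processing for the Medbullets
# multiple-choice QA benchmark (added summary comment, not in the upstream file).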
import re

import pandas as pd
from datasets import Dataset

from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path, get_logger

from .base import BaseDataset


def _parse(item: dict, prompt_mode: str) -> dict:
    # Build the option list, ignoring 'ope' when it is an empty string
    options_keys = ['opa', 'opb', 'opc', 'opd']
    if item.get('ope', '') != '':
        options_keys.append('ope')
    options_list = [item.get(k, '') for k in options_keys]
    item['options'] = options_list

    # Build the lettered option string
    options_str = '\n'.join(
        [f'{chr(65 + i)}. {opt}' for i, opt in enumerate(options_list)])

    # Append the options to the end of the question
    item['question'] = f"{item.get('question', '')}\n{options_str}"

    # Label and other fields
    item['label'] = item.get('answer_idx')
    item['prompt_mode'] = prompt_mode
    item['start'] = chr(65)
    item['end'] = chr(65 + len(options_list) - 1)
    return item
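

# Illustrative sketch (added commentary, not in the upstream file): assuming a
# CSV row such as
#     {'question': 'Which vitamin deficiency causes scurvy?',
#      'opa': 'Vitamin A', 'opb': 'Vitamin B12', 'opc': 'Vitamin C',
#      'opd': 'Vitamin D', 'ope': '', 'answer_idx': 'C'}
# _parse appends "A. Vitamin A ... D. Vitamin D" to the question, sets
# item['options'] to the four option texts, item['label'] to 'C', and
# item['start'] / item['end'] to 'A' / 'D'.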


@LOAD_DATASET.register_module()
class MedbulletsDataset(BaseDataset):

    @staticmethod
    def load(path: str, prompt_mode: str = 'zero-shot', **kwargs):
        # Read the CSV file into a DataFrame and replace NaN with empty strings
        path = get_data_path(path)
        df = pd.read_csv(path, encoding='utf-8')
        df = df.fillna('')

        # Convert to a list of dicts
        data_list = df.to_dict(orient='records')

        # Wrap the list of records as a Dataset
        dataset = Dataset.from_list(data_list)

        # Parse according to the prompt mode
        if prompt_mode == 'zero-shot':
            dataset = dataset.map(lambda item: _parse(item, prompt_mode))
        elif prompt_mode == 'few-shot':
            pass  # TODO: Implement few-shot prompt handling
        return dataset
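

# Usage sketch (illustrative; the CSV path below is an assumption, not a file
# shipped with this repository):
#     ds = MedbulletsDataset.load('data/medbullets/medbullets.csv',
#                                 prompt_mode='zero-shot')
#     print(ds[0]['question'], ds[0]['label'])
# In the zero-shot branch each record passes through _parse, so 'question'
# already carries the lettered options and 'label' holds the gold letter.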


class MedbulletsEvaluator(BaseEvaluator):

    def score(self, predictions, references, test_set):
        method = test_set['prompt_mode'][0]

        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different lengths'
            }
        correct = 0
        count = 0
        details = []
        for idx, (i, j) in enumerate(zip(predictions, references)):
            i = answer_cleansing(method, i, test_set['options'][idx],
                                 test_set['label'][idx])
            detail = {
                'pred': i,
                'answer': j,
                'correct': False,
                'question_type': test_set['question_type'][idx]
            }
            count += 1
            if i == j:
                correct += 1
                detail['correct'] = True
            details.append(detail)
        result = {'accuracy': 100 * correct / count, 'details': details}
        return result
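

# Added note (not in the upstream file): score() expects raw model outputs in
# `predictions`, gold option letters in `references`, and a `test_set` whose
# columns include 'prompt_mode', 'options', 'label' and 'question_type'; each
# prediction is first reduced to a single letter by answer_cleansing before it
# is compared with the reference.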


@TEXT_POSTPROCESSORS.register_module()
def answer_cleansing(
    method: str,
    prediction: str,
    options: list,
    label: str,
) -> str:

    # Clean up unwanted phrases in the prediction
    for unwanted_phrase in [
            'I understand',
            'A through J',
            'A through E',
            'A through D',
    ]:
        prediction = prediction.replace(unwanted_phrase, '')

    options_num = len(options)
    options = [chr(65 + i) for i in range(options_num)]
    options_str = r'\b(' + '|'.join(options) + r')\b'
    prediction = re.findall(options_str, prediction)

    if len(prediction) == 0:
        prediction = []
        return prediction
    else:
        # If there is a "label" and its length is 1,
        # process prediction accordingly
        if len(label) == 1:
            if method == 'few-shot':
                answer_flag = len(prediction) > 1
                # choose the first or last element based on the answer_flag
                if answer_flag:
                    prediction = [prediction[0]]
                else:
                    prediction = [prediction[-1]]
            elif method == 'zero-shot':
                # choose the first element in list
                prediction = [prediction[0]]
            else:
                raise ValueError('Method is not properly defined ...')

            # Remove trailing period if it exists
            if prediction[0] and prediction[0].endswith('.'):
                prediction[0] = prediction[0][:-1]

        return prediction[0]
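

# Illustrative behaviour (hypothetical strings, not taken from the dataset):
#     answer_cleansing('zero-shot', 'The answer is C. Scurvy ...',
#                      ['w', 'x', 'y', 'z'], 'C')                    # -> 'C'
#     answer_cleansing('few-shot',
#                      'The answer is D. Explanation: A is wrong ...',
#                      ['w', 'x', 'y', 'z'], 'D')                    # -> 'D'
# When no option letter is found at all, an empty list is returned instead of
# a string.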


def _generic_llmjudge_postprocess(judgement: str):
    match = re.search(r'(A|B)', judgement)
    # Default to 'B' ("INCORRECT") if no match
    grade_letter = match.group(0) if match else 'B'
    return grade_letter
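

# Added note: the LLM judge is expected to reply with 'A' (correct) or 'B'
# (incorrect); the first capital A or B found anywhere in the judgement is
# used, e.g.
#     _generic_llmjudge_postprocess('A. The answer is correct')  # -> 'A'
#     _generic_llmjudge_postprocess('no verdict given')          # -> 'B'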


def medbullets_llmjudge_postprocess(
    output: dict,
    output_path: str,
    dataset: Dataset,
) -> dict:
    original_dataset = dataset.reader.dataset['test']

    judged_answers = []
    original_responses = []
    references = []
    details = []

    # Initialize statistics dictionaries
    stats = {'question_type': {}}

    total_correct = 0
    total_count = 0

    # Process each sample
    for k, v in output.items():
        idx = int(k)  # Convert key to integer for indexing
        original_responses.append(v['prediction'])
        processed_judge = _generic_llmjudge_postprocess(v['prediction'])

        # Get category information from the dataset
        sample = original_dataset[idx]
        question_type = sample.get('question_type', 'unknown')

        # Initialize category stats if not exists
        for level, key in [
            ('question_type', question_type),
        ]:
            if key not in stats[level]:
                stats[level][key] = {'correct': 0, 'total': 0}

        # Record the judgment
        if processed_judge is not None:
            judged_answers.append(processed_judge)
            try:
                gold = v['gold']
                references.append(gold)
            except KeyError:
                get_logger().warning(
                    f'No gold answer for {k}, use empty string as reference!')
                gold = ''
                references.append('')

            # Check if the answer is correct (A means correct)
            is_correct = processed_judge == 'A'
            total_count += 1

            if is_correct:
                total_correct += 1
                # Update category stats
                for level, key in [
                    ('question_type', question_type),
                ]:
                    stats[level][key]['correct'] += 1

            # Update category totals
            for level, key in [
                ('question_type', question_type),
            ]:
                stats[level][key]['total'] += 1

        # Add to details
        details.append({
            'id': k,
            'origin_prompt': v['origin_prompt'],
            'llm_judge': processed_judge,
            'gold': gold,
            'is_correct': is_correct,
            'question_type': question_type,
        })

    # Calculate overall accuracy with two decimal places
    overall_accuracy = (round(
        (total_correct / total_count * 100), 2) if total_count > 0 else 0.00)

    # Initialize results dictionary
    results = {
        'accuracy': overall_accuracy,
        'total_correct': total_correct,
        'total_count': total_count,
        'details': details,
    }

    # Calculate accuracy for each category and flatten into results
    for level in stats:
        for key, value in stats[level].items():
            if value['total'] > 0:
                # Calculate accuracy with two decimal places
                accuracy = round((value['correct'] / value['total'] * 100), 2)

                # Create a flattened key for the category
                flat_key = f'Medbullets-{key}'

                # Add to results
                results[flat_key] = accuracy

    return results
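

# Shape of the returned dict (sketch with made-up numbers):
#     {'accuracy': 72.5, 'total_correct': 58, 'total_count': 80,
#      'details': [...],
#      'Medbullets-<question_type>': 71.43, ...}
# i.e. one flattened per-question-type accuracy key alongside the overall
# accuracy, total counts, and per-sample details.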