OpenCompass/opencompass/datasets/compassbench_obj.py
liushz e49fcfd3a3
[Update] Update MATH dataset with model judge (#1711)
* Update math with llm judge

* Update math with llm judge

* Update math with llm judge

* Update math with llm judge

* Update math with llm judge
2024-11-25 15:14:55 +08:00

114 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import copy
import json
import re
from datasets import Dataset
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from .base import BaseDataset
def get_number(options) -> str:
    """Render answer options as lettered lines.

    Options are labelled in order starting from ``'A'``; each one becomes
    a line of the form ``<letter>. <option>`` terminated by a newline.

    Args:
        options: Iterable of option values. Each value is interpolated
            with f-string formatting, so non-string values are accepted.

    Returns:
        str: All option lines concatenated, or an empty string when
        ``options`` is empty.
    """
    # A single join avoids rebuilding the string on every iteration.
    return ''.join(f'{chr(code)}. {option}\n'
                   for code, option in enumerate(options, start=ord('A')))
@LOAD_DATASET.register_module()
class CompassBenchObjectiveV1_3(BaseDataset):
    """Loader for CompassBench objective (v1.3) JSON-lines files."""

    @staticmethod
    def load(path: str, name: str):
        """Read a JSON-lines file and build a ``Dataset`` from it.

        The substring found in ``name`` selects how each record is
        converted:

        * ``'cloze'``: keep only the stripped ``question`` and ``answer``.
        * ``'circular'``: emit four copies of the record, one per rotation
          of the option order. The answer letter is remapped to the new
          position and stored as ``<line index>--<letter>--<pattern>``.
        * anything else: treat the record as a single-choice question and
          append the lettered options to the question text.

        Args:
            path: Path of the JSON-lines file; one JSON object per line.
            name: Dataset name; checked for ``'cloze'`` first, then
                ``'circular'``.

        Returns:
            The dataset built with ``Dataset.from_list``.
        """
        circular_patterns = ['ABCD', 'BCDA', 'CDAB', 'DABC']
        data = []
        # NOTE(review): errors='ignore' silently drops undecodable bytes;
        # kept as-is because callers may rely on this leniency.
        with open(path, 'r', encoding='utf-8', errors='ignore') as infile:
            for line_idx, raw_line in enumerate(infile):
                # Blank lines are not JSON documents; skip them instead of
                # letting json.loads raise. line_idx still counts them, so
                # ids of the remaining records are unchanged.
                if not raw_line.strip():
                    continue
                entry = json.loads(raw_line)
                if 'cloze' in name:
                    data.append({
                        'question': entry['question'].strip(),
                        'answer': entry['answer'].strip(),
                    })
                elif 'circular' in name:
                    for pattern in circular_patterns:
                        rotated = copy.deepcopy(entry)
                        # Position k of the rotated list holds the option
                        # that was originally labelled pattern[k].
                        rotated['options'] = [
                            rotated['options'][ord(letter) - ord('A')]
                            for letter in pattern
                        ]
                        # Map the original answer letter to the letter of
                        # the position it now occupies.
                        relabel = dict(zip(pattern, 'ABCD'))
                        new_letter = relabel[rotated['answer']]
                        rotated['answer'] = (
                            f'{line_idx}--{new_letter}--{pattern}')
                        rotated['question'] = (
                            rotated['question'].strip() + '\n' +
                            get_number(rotated['options']))
                        data.append(rotated)
                else:
                    # treat as normal single choice question
                    entry['question'] = (entry['question'].strip() + '\n' +
                                         get_number(entry['options']))
                    data.append(entry)
        return Dataset.from_list(data)
@LOAD_DATASET.register_module()
class CompassBenchObjectiveMath(BaseDataset):
    """Loader for CompassBench objective math JSON-lines files."""

    @staticmethod
    def load(path: str):
        """Read a JSON-lines file and build a ``Dataset`` from it.

        For every record that has a non-empty ``options`` field, the
        lettered options are appended to ``question``. If such a record's
        ``question_type`` is ``'multiple-answer'`` or ``'多选题'``, an
        instruction to select all correct answers is prepended as well:
        the English one when ``'_en_'`` occurs in ``path``, otherwise the
        Chinese one. Records without options are left untouched.

        Args:
            path: Path of the JSON-lines file; one JSON object per line.

        Returns:
            The dataset built with ``Dataset.from_list``.
        """
        # Decode explicitly as UTF-8: the data contains Chinese text and
        # the default encoding of open() depends on the platform locale.
        with open(path, 'r', encoding='utf-8') as infile:
            # Blank lines are not JSON documents, so they are skipped.
            data = [json.loads(line) for line in infile if line.strip()]
        for item in data:
            prefix = ''
            if item.get('question_type') in ('multiple-answer', '多选题'):
                # Adjacent literals are used instead of a backslash
                # continuation inside the string, so source indentation
                # can never leak into the prompt. The wording is kept
                # verbatim so existing prompts stay unchanged.
                if '_en_' in path:
                    prefix = ('This question may has multiple answers, '
                              'please select all correct answers. '
                              'like this: A, B, C as your final answer\n')
                else:
                    prefix = ('这道题可能有多个正确答案,请选择所有正确的答案,'
                              '例如A, B, C 作为你的最终答案\n')
            if item.get('options'):
                item['question'] = (prefix + item['question'] + '\n' +
                                    get_number(item['options']))
        return Dataset.from_list(data)
@TEXT_POSTPROCESSORS.register_module()
def compassbench_objective_v1_3_postprocess(text: str, name) -> str:
    """Pull a numeric answer out of a model response.

    The response is split on an answer marker: ``'答案是'`` when ``name``
    contains ``'_cn'``, otherwise ``'The answer is'``. If the marker is
    present, only the stripped text between its first and second
    occurrence (or the end of the response) is searched; otherwise the
    whole response is searched.

    Args:
        text: Raw model response.
        name: Dataset name, used only to choose the answer marker.

    Returns:
        str: The first number found when the marker was present, the last
        number found when it was not, or the searched text itself when it
        contains no number.
    """
    marker = '答案是' if '_cn' in name else 'The answer is'
    pieces = text.split(marker)
    marker_found = len(pieces) > 1
    candidate = pieces[1].strip() if marker_found else text

    # Drop commas that sit directly between two digits (e.g. "1,000")
    # before looking for numbers.
    digits_joined = re.sub(r'(\d),(\d)', r'\1\2', candidate)
    matches = re.findall(r'-?\d*\.?/?\d+|\d+', digits_joined)
    if not matches:
        return candidate
    # After an explicit marker the answer is expected right away; without
    # one, the last number in the response is taken.
    return matches[0] if marker_found else matches[-1]