OpenCompass/opencompass/datasets/mbpp.py
Hubert ddb8197212
[Feat] support wizardcoder series (#344)
* [Feat] support wizardcoder series

* minor fix
2023-09-06 17:52:35 +08:00

162 lines
5.1 KiB
Python

import contextlib
import io
import re
import signal
from datasets import DatasetDict, load_dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class MBPPDataset(BaseDataset):
    """MBPP dataset loaded from a file via the ``datasets`` ``json`` builder.

    Rows ``[:10]`` form the ``train`` split and rows ``[10:510]`` form the
    ``test`` split.
    """

    @staticmethod
    def load(path: str):
        """Load ``path`` and return a ``DatasetDict`` with train/test splits.

        Every row gains ``test_case`` (the original list from ``test_list``),
        while ``test_list`` and ``test_list_2`` are both overwritten with that
        list joined by newlines.
        """

        def _prepare_example(row):
            asserts = row['test_list']
            joined_asserts = '\n'.join(asserts)
            row['test_case'] = asserts
            row['test_list'] = joined_asserts
            row['test_list_2'] = joined_asserts
            return row

        def _load_slice(split_spec):
            raw = load_dataset('json', data_files=path, split=split_spec)
            return raw.map(_prepare_example)

        splits = {}
        splits['train'] = _load_slice('train[:10]')
        splits['test'] = _load_slice('train[10:510]')
        return DatasetDict(splits)
class TimeOutException(Exception):
    """Raised by the evaluator's SIGALRM handler when a program exceeds its time limit."""
@ICL_EVALUATORS.register_module()
class MBPPEvaluator(BaseEvaluator):
    """Execution-based evaluator for MBPP predictions.

    Each prediction is cleaned with ``_process_answer``, followed by its
    reference test code, and the resulting program is run with ``exec``
    under a 2-second wall-clock limit while stdout/stderr are captured and
    stdin is made unreadable.

    NOTE(review): generated code runs inside this interpreter with no
    sandbox. The time limit relies on SIGALRM / ``setitimer``, so it only
    works on Unix and from the main thread.
    """

    def score(self, predictions, references):
        """Execute each prediction against its tests and tally the outcomes.

        Args:
            predictions: Raw model completions.
            references: Test code for each prediction, aligned one-to-one
                with ``predictions``.

        Returns:
            dict with the counts ``pass``, ``timeout``, ``failed`` and
            ``wrong_answer`` plus ``score``, the pass rate in percent
            (``0.0`` when there are no predictions).
        """
        assert len(predictions) == len(references)
        predictions = [self._process_answer(pred) for pred in predictions]

        result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0}
        for test_case, pred in zip(references, predictions):
            programs = self._process_test(test_case, pred)
            try:
                # An explicit globals dict makes the program's top-level
                # definitions visible to each other, preventing spurious
                # NameErrors for correct answers; each program starts from
                # an empty namespace.
                exec_globals = {}
                with self.swallow_io():
                    with self.time_limit(2):
                        exec(programs, exec_globals)
                result['pass'] += 1
            except TimeOutException:
                result['timeout'] += 1
            except AssertionError:
                # A reference test assertion failed.
                result['wrong_answer'] += 1
            except KeyboardInterrupt:
                # A user's Ctrl-C must abort the run, not be counted as a
                # failed sample.
                raise
            except BaseException:
                # Deliberately broad: generated code may raise anything,
                # including SystemExit from exit()/quit().
                result['failed'] += 1

        total = len(predictions)
        # Guard against ZeroDivisionError on an empty prediction list.
        result['score'] = result['pass'] / total * 100 if total else 0.0
        return result

    def _process_answer(self, text):
        """Extract the program from a completion in [BEGIN] ... [DONE] form.

        Everything from the first ``DONE``/``[DONE]`` match onward is cut
        (including a single quote and whitespace right before it), then
        everything up to and including the first ``BEGIN``/``[BEGIN]`` match
        (plus a following quote and whitespace) is cut. Finally, one leading
        and one trailing single quote are removed.

        NOTE(review): the bare ``DONE``/``BEGIN`` alternatives also match
        inside identifiers such as ``IS_DONE``.
        """
        text = text.strip()
        match = re.search(r"('\s*|)(\[DONE\]|DONE)", text)
        if match:
            text = text[:match.start()]
        match = re.search(r"(\[BEGIN\]|BEGIN)('\s*|)", text)
        if match:
            text = text[match.end():]
        text = text.strip()
        if text.startswith("'"):
            text = text[1:]
        if text.endswith("'"):
            text = text[:-1]
        return text

    def _process_test(self, test_case, pred):
        """Return the program to execute: the prediction, a newline, then the tests."""
        formatted = pred + '\n'
        formatted += test_case
        return formatted

    @contextlib.contextmanager
    def swallow_io(self):
        """Capture stdout/stderr in a throwaway buffer and block reads from stdin."""
        stream = self.WriteOnlyStringIO()
        with contextlib.redirect_stdout(stream):
            with contextlib.redirect_stderr(stream):
                with self.redirect_stdin(stream):
                    yield

    @contextlib.contextmanager
    def time_limit(self, seconds: float):
        """Raise ``TimeOutException`` if the body runs longer than ``seconds``.

        Uses ``ITIMER_REAL``/SIGALRM, so it is Unix-only and must be entered
        from the main thread. The previously installed SIGALRM handler is
        restored on exit.
        """

        def signal_handler(signum, frame):
            raise TimeOutException('Time out!')

        # Install the handler before arming the timer so an early alarm can
        # never reach the default (process-terminating) SIGALRM action.
        previous_handler = signal.signal(signal.SIGALRM, signal_handler)
        signal.setitimer(signal.ITIMER_REAL, seconds)
        try:
            yield
        finally:
            signal.setitimer(signal.ITIMER_REAL, 0)
            # signal.signal() returns None when the old handler was not
            # installed from Python; it cannot be passed back in that case.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

    class WriteOnlyStringIO(io.StringIO):
        """StringIO that throws an exception when it's read from."""

        def read(self, *args, **kwargs):
            raise IOError

        def readline(self, *args, **kwargs):
            raise IOError

        def readlines(self, *args, **kwargs):
            raise IOError

        def readable(self, *args, **kwargs):
            """Returns True if the IO object can be read."""
            return False

    class redirect_stdin(contextlib._RedirectStream):  # type: ignore
        # Counterpart of contextlib.redirect_stdout for sys.stdin; relies on
        # the private contextlib._RedirectStream base.
        _stream = 'stdin'
@ICL_EVALUATORS.register_module()
class MBPPEvaluator2(MBPPEvaluator):
    """MBPP evaluator whose answer extraction understands Markdown fences.

    Originally added for WizardCoder evaluation. Chat-style completions wrap
    the solution in triple-backtick fences, may contain a "Here ..." line,
    and may append their own tests; all of these are stripped here.
    """

    def _process_answer(self, text):
        """Extract executable code from a chat-style completion.

        1. If the text contains a triple-backtick fence, keep the first fenced
           block. When there is no closing fence, keep everything after the
           first fence instead. In both cases a language tag such as
           ``python`` on the first line is dropped. Without any fence,
           remove the text from the first ``Here`` through the end of its
           line.
        2. Truncate at the earliest ``# Test``/``#Test``/``#test``/``# test``.
        3. Apply the [BEGIN]/[DONE] cleanup and remove one leading single
           quote. Unlike the base class, a trailing quote is kept.
        """
        if '```' in text:
            blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
            if blocks:
                text = blocks[0]  # fetch the first code block
            else:
                # Unclosed fence (e.g. a truncated generation): fall back to
                # everything after the first fence.
                text = text.split('```')[1]
            # Applied to both branches. Previously the fallback kept the tag,
            # so "```python\n..." executed a bare `python` -> NameError.
            text = self._strip_fence_tag(text)
        else:
            # Remove an introductory "Here is the code:"-style line.
            text = re.sub(r'Here(.*?)\n', '', text, count=1)

        # Remove tests the model appended after its solution.
        for marker in ('# Test', '#Test', '#test', '# test'):
            if marker in text:
                text = text[:text.find(marker)]

        text = text.strip()
        match = re.search(r"('\s*|)(\[DONE\]|DONE)", text)
        if match:
            text = text[:match.start()]
        match = re.search(r"(\[BEGIN\]|BEGIN)('\s*|)", text)
        if match:
            text = text[match.end():]
        text = text.strip()
        if text.startswith("'"):
            text = text[1:]
        return text

    @staticmethod
    def _strip_fence_tag(code):
        """Drop the first line of ``code`` if it looks like a fence info string.

        An info string is a single token such as ``python``, ``py3`` or
        ``c++``; an empty first line also counts. Code without a newline is
        returned unchanged, as is a first line that looks like code
        (e.g. ``def f(x):``).
        """
        first_line, newline, rest = code.partition('\n')
        if newline and re.fullmatch(r'[\w+#.-]*', first_line.strip()):
            return rest
        return code