Support Mbpp_plus dataset (#770)

* support mbpp+

* support mbpp+

* minor fix

* [Feat] minor fix

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>
Connor-Shen 2024-01-05 22:01:57 +08:00 committed by GitHub
parent 3c606cb712
commit 30a90d8dd8
3 changed files with 164 additions and 25 deletions

@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .mbpp_plus_gen_94815c import mbpp_plus_datasets  # noqa: F401, F403
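This stub follows the usual OpenCompass convention: `read_base()` lets one config import another as a plain Python module, so a stable alias (conventionally a file like `mbpp_plus_gen.py`; the exact filename is not shown in this diff) can keep pointing at the current versioned prompt file, `mbpp_plus_gen_94815c`. Below is a minimal sketch of a top-level run config that consumes it; the model import path is purely illustrative and not something this commit adds:

from mmengine.config import read_base

with read_base():
    # the dataset list introduced by this commit
    from .datasets.mbpp_plus.mbpp_plus_gen import mbpp_plus_datasets
    # any model config can stand in here; this path is only an example
    from .models.hf_llama.hf_llama2_7b import models

datasets = mbpp_plus_datasets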

@@ -0,0 +1,64 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import MBPPEvaluator, MBPPPlusDataset

mbpp_plus_reader_cfg = dict(
    input_columns=['text', 'test_list'], output_column='task_id')

mbpp_plus_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role="HUMAN",
prompt=
"You are an expert Python programmer, and here is your task: Write a function to find the shared elements from the given two lists. Your code should pass these tests:\n\n assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)\n assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4) \n assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14) \n"
),
dict(
role="BOT",
prompt=
"[BEGIN]\n 'def similar_elements(test_tup1, test_tup2):\n return tuple(set(test_tup1) & set(test_tup2))' \n[DONE] \n\n "
),
dict(
role="HUMAN",
prompt=
"You are an expert Python programmer, and here is your task: Write a python function to identify non-prime numbers. Your code should pass these tests:\n\n assert is_not_prime(2) == False \n assert is_not_prime(10) == True \n assert is_not_prime(35) == True \n"
),
dict(
role="BOT",
prompt=
"[BEGIN]\n 'import math\ndef is_not_prime(n):\n if n == 1:\n return True\n for i in range(2, int(math.sqrt(n))+1):\n if n % i == 0:\n return True\n return False' \n[DONE] \n\n "
),
dict(
role="HUMAN",
prompt=
"You are an expert Python programmer, and here is your task: Write a function to find the n largest integers from a given list of numbers, returned in descending order. Your code should pass these tests:\n\n assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] \n assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] \n assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35] \n"
),
dict(
role="BOT",
prompt=
"[BEGIN]\n 'import heapq as hq\ndef heap_queue_largest(nums: list,n: int) -> list:\n largest_nums = hq.nlargest(n, nums)\n return largest_nums' \n[DONE] \n\n "
),
dict(
role="HUMAN",
prompt=
"You are an expert Python programmer, and here is your task: {text} Your code should pass these tests:\n\n {test_list} \n"
),
dict(role="BOT", prompt="[BEGIN]\n"),
], )),
retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=512))

mbpp_plus_eval_cfg = dict(
    evaluator=dict(type=MBPPEvaluator, metric='MBPPPlus'), pred_role="BOT")

mbpp_plus_datasets = [
dict(
type=MBPPPlusDataset,
abbr='mbpp_plus',
path='./data/mbpp_plus/mbpp_plus.jsonl',
reader_cfg=mbpp_plus_reader_cfg,
infer_cfg=mbpp_plus_infer_cfg,
eval_cfg=mbpp_plus_eval_cfg)
]
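Each few-shot turn wraps the reference solution in `[BEGIN]` / `[DONE]` markers and quotes it, and the final `BOT` turn primes the model with a bare `[BEGIN]\n`, so completions arrive in the same bracketed form. The following rough sketch shows the kind of post-processing this format enables; the real cleanup happens in `MBPPEvaluator._process_answer` (whose tail appears in the diff below), so this standalone helper is illustrative only:

def extract_solution(completion: str) -> str:
    """Illustrative: recover code from a '[BEGIN] ... [DONE]' completion,
    assuming the model imitates the quoted few-shot answers above."""
    text = completion.split('[DONE]')[0]        # drop everything after [DONE]
    text = text.replace('[BEGIN]', '').strip()  # drop the opening marker
    return text.strip("'")                      # few-shot answers are quoted

print(extract_solution("[BEGIN]\n 'def add(a, b):\n    return a + b' \n[DONE]"))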

@@ -1,15 +1,18 @@
 import contextlib
 import io
 import itertools
+import json
 import multiprocessing
+import os.path as osp
 import re
 import signal
+import tempfile
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import List, Sequence, Union
 
 import numpy as np
-from datasets import DatasetDict, concatenate_datasets, load_dataset
+from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
 
 from opencompass.openicl.icl_evaluator import BaseEvaluator
 from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
@@ -110,6 +113,35 @@ class SanitizedMBPPDataset(BaseDataset):
 
         return DatasetDict({'train': train, 'test': test})
 
 
+class MBPPPlusDataset(BaseDataset):
+
+    @staticmethod
+    def load(path: str, num_repeats: int = 1):
+        """Load the mbpp dataset in pass@k mode. Note that you can use
+        num_repeats > 1 when your model does not support
+        `num_return_sequence` in generation; otherwise use the raw mbpp
+        dataset and set `num_return_sequence` in the model config to
+        generate multiple responses for testing pass@k > 1.
+
+        It is better to change your dataset abbr accordingly if you set
+        num_repeats > 1, otherwise the number in
+        `.cache/dataset_size.json` might be inconsistent.
+
+        Args:
+            num_repeats(int): Number of repetitions for this dataset to get
+                multiple responses in special cases.
+        """
+        dataset = []
+        with open(path, 'r', encoding='utf-8') as f:
+            for line in f:
+                dataset.extend(
+                    [json.loads(line.strip()) for _ in range(num_repeats)])
+        return Dataset.from_list(dataset)
+
+
 class TimeOutException(Exception):
     pass
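`num_repeats` simply duplicates every task so that a model restricted to a single completion per prompt can still produce the n samples that pass@k scoring needs. For context, this is the standard unbiased pass@k estimator (Chen et al., 2021) that such repeated generations feed into; the helper below is illustrative and not part of this commit:

import math

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k: chance that at least one of k samples drawn
    from n generations (of which c passed) is correct."""
    if n - c < k:
        return 1.0  # fewer failures than draws: a pass is guaranteed
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)

# e.g. num_repeats=10 gave 3 passing samples; estimate pass@5
print(round(pass_at_k(10, 3, 5), 4))  # 0.9167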
@@ -160,36 +192,75 @@ class redirect_stdin(contextlib._RedirectStream):  # type: ignore
 
 @ICL_EVALUATORS.register_module()
 class MBPPEvaluator(BaseEvaluator):
+    """Evaluator for MBPP or MBPPPlus."""
+
+    def __init__(self, metric: str = 'MBPP') -> None:
+        self.metric = metric
+        assert self.metric in ['MBPP', 'MBPPPlus']
 
     def score(self, predictions, references):
         assert len(predictions) == len(references)
         predictions = [self._process_answer(pred) for pred in predictions]
-        result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0}
-        details = {}
-        for index, (test_case, pred) in enumerate(zip(references,
-                                                      predictions)):
-            programs = self._process_test(test_case, pred)
-            try:
-                # Add exec globals to prevent exec from raising an
-                # unnecessary NameError for a correct answer
-                exec_globals = {}
-                with swallow_io():
-                    with time_limit(2):
-                        exec(programs, exec_globals)
-                r = 'pass'
-            except TimeOutException:
-                r = 'timeout'
-            except AssertionError:
-                r = 'wrong_answer'
-            except BaseException:
-                r = 'failed'
-            result[r] += 1
-            details[str(index)] = {'programs': programs, 'result': r}
-        result['score'] = result['pass'] / len(predictions) * 100
-        result['details'] = details
-        return result
+        if self.metric == 'MBPP':
+            result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0}
+            details = {}
+            for index, (test_case,
+                        pred) in enumerate(zip(references, predictions)):
+                programs = self._process_test(test_case, pred)
+                try:
+                    # Add exec globals to prevent exec from raising an
+                    # unnecessary NameError for a correct answer
+                    exec_globals = {}
+                    with swallow_io():
+                        with time_limit(2):
+                            exec(programs, exec_globals)
+                    r = 'pass'
+                except TimeOutException:
+                    r = 'timeout'
+                except AssertionError:
+                    r = 'wrong_answer'
+                except BaseException:
+                    r = 'failed'
+                result[r] += 1
+                details[str(index)] = {'programs': programs, 'result': r}
+
+            result['score'] = result['pass'] / len(predictions) * 100
+            result['details'] = details
+            return result
+        else:
+            try:
+                from evalplus.data import write_jsonl
+                from evalplus.evaluate import evaluate
+                self.write_jsonl = write_jsonl
+                self.eval = evaluate
+            except ImportError:
+                raise ImportError(
+                    'Please install evalplus using the following steps:\n'
+                    'git clone --recurse-submodules git@github.com:open-compass/human-eval.git\n'  # noqa
+                    'cd human-eval\n'
+                    'pip install -e .\n'
+                    'pip install -e evalplus\n')
+            mbpp_preds = []
+            for preds, refer in zip(predictions, references):
+                if not isinstance(preds, list):
+                    preds = [preds]
+                for pred in preds:
+                    mbpp_preds.append({'task_id': refer, 'solution': pred})
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                out_dir = osp.join(tmp_dir, 'mbpp_eval.jsonl')
+                self.write_jsonl(out_dir, mbpp_preds)
+                flags = dict(dataset='mbpp',
+                             samples=out_dir,
+                             base_only=None,
+                             parallel=None,
+                             i_just_wanna_run=None,
+                             test_details=0.2,
+                             min_time_limit=0.2,
+                             gt_time_limit_factor=4.0,
+                             mini=None)
+                score = self.eval(flags)
+            return {f'mbpp_plus_{k}': score[k] * 100 for k in score}
 
     def _process_answer(self, text):
         text = text.strip()
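The MBPP branch above runs untrusted model code inside `swallow_io()` and `time_limit(2)`. For readers unfamiliar with that guard, here is a simplified, Unix-only sketch of the signal-based timeout idiom it relies on (the real `time_limit` is defined earlier in this same file; this rendition is illustrative):

import contextlib
import signal

@contextlib.contextmanager
def time_limit(seconds: float):
    """Simplified sketch: raise if the body runs longer than `seconds`."""
    def handler(signum, frame):
        raise TimeoutError('execution timed out')

    signal.signal(signal.SIGALRM, handler)         # install the alarm handler
    signal.setitimer(signal.ITIMER_REAL, seconds)  # start the countdown
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)    # always cancel the timer

with time_limit(2):
    print(sum(range(10_000)))  # finishes well inside the limit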