diff --git a/configs/datasets/mbpp_plus/mbpp_plus_gen.py b/configs/datasets/mbpp_plus/mbpp_plus_gen.py
new file mode 100644
index 00000000..5a1ce3da
--- /dev/null
+++ b/configs/datasets/mbpp_plus/mbpp_plus_gen.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .mbpp_plus_gen_94815c import mbpp_plus_datasets  # noqa: F401, F403
\ No newline at end of file
diff --git a/configs/datasets/mbpp_plus/mbpp_plus_gen_94815c.py b/configs/datasets/mbpp_plus/mbpp_plus_gen_94815c.py
new file mode 100644
index 00000000..da28a21a
--- /dev/null
+++ b/configs/datasets/mbpp_plus/mbpp_plus_gen_94815c.py
@@ -0,0 +1,64 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import MBPPEvaluator, MBPPPlusDataset
+
+mbpp_plus_reader_cfg = dict(
+    input_columns=['text', 'test_list'], output_column='task_id')
+
+mbpp_plus_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(
+            round=[
+                dict(
+                    role="HUMAN",
+                    prompt=
+                    "You are an expert Python programmer, and here is your task: Write a function to find the shared elements from the given two lists. Your code should pass these tests:\n\n assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)\n assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4) \n assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14) \n"
+                ),
+                dict(
+                    role="BOT",
+                    prompt=
+                    "[BEGIN]\n 'def similar_elements(test_tup1, test_tup2):\n return tuple(set(test_tup1) & set(test_tup2))' \n[DONE] \n\n "
+                ),
+                dict(
+                    role="HUMAN",
+                    prompt=
+                    "You are an expert Python programmer, and here is your task: Write a python function to identify non-prime numbers. Your code should pass these tests:\n\n assert is_not_prime(2) == False \n assert is_not_prime(10) == True \n assert is_not_prime(35) == True \n"
+                ),
+                dict(
+                    role="BOT",
+                    prompt=
+                    "[BEGIN]\n 'import math\ndef is_not_prime(n):\n if n == 1:\n return True\n for i in range(2, int(math.sqrt(n))+1):\n if n % i == 0:\n return True\n return False' \n[DONE] \n\n "
+                ),
+                dict(
+                    role="HUMAN",
+                    prompt=
+                    "You are an expert Python programmer, and here is your task: Write a function to find the n largest integers from a given list of numbers, returned in descending order. Your code should pass these tests:\n\n assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] \n assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] \n assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35] \n"
+                ),
+                dict(
+                    role="BOT",
+                    prompt=
+                    "[BEGIN]\n 'import heapq as hq\ndef heap_queue_largest(nums: list,n: int) -> list:\n largest_nums = hq.nlargest(n, nums)\n return largest_nums' \n[DONE] \n\n "
+                ),
+                dict(
+                    role="HUMAN",
+                    prompt=
+                    "You are an expert Python programmer, and here is your task: {text} Your code should pass these tests:\n\n {test_list} \n"
+                ),
+                dict(role="BOT", prompt="[BEGIN]\n"),
+            ], )),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer, max_out_len=512))
+
+mbpp_plus_eval_cfg = dict(evaluator=dict(type=MBPPEvaluator, metric='MBPPPlus'), pred_role="BOT")
+
+mbpp_plus_datasets = [
+    dict(
+        type=MBPPPlusDataset,
+        abbr='mbpp_plus',
+        path='./data/mbpp_plus/mbpp_plus.jsonl',
+        reader_cfg=mbpp_plus_reader_cfg,
+        infer_cfg=mbpp_plus_infer_cfg,
+        eval_cfg=mbpp_plus_eval_cfg)
+]
diff --git a/opencompass/datasets/mbpp.py b/opencompass/datasets/mbpp.py
index 6fc3dd44..5bbf76e4 100644
--- a/opencompass/datasets/mbpp.py
+++ b/opencompass/datasets/mbpp.py
@@ -1,15 +1,18 @@
 import contextlib
 import io
 import itertools
+import json
 import multiprocessing
+import os.path as osp
 import re
 import signal
+import tempfile
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import List, Sequence, Union
 
 import numpy as np
-from datasets import DatasetDict, concatenate_datasets, load_dataset
+from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
 
 from opencompass.openicl.icl_evaluator import BaseEvaluator
 from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
@@ -110,6 +113,35 @@ class SanitizedMBPPDataset(BaseDataset):
         return DatasetDict({'train': train, 'test': test})
 
 
+class MBPPPlusDataset(BaseDataset):
+
+    @staticmethod
+    def load(path: str, num_repeats: int = 1):
+        """Load the mbpp_plus dataset in pass@k mode.
+
+        Note that you can set num_repeats > 1 when your model does
+        not support `num_return_sequence` in generation; otherwise,
+        use the raw mbpp dataset and set `num_return_sequence` in
+        the model config to generate multiple responses for testing
+        pass@k > 1.
+
+        It is better to change your dataset abbr accordingly if you
+        set num_repeats > 1; otherwise the number recorded in
+        `.cache/dataset_size.json` might be inconsistent.
+
+        Args:
+            path (str): Path to the mbpp_plus jsonl file.
+            num_repeats (int): Number of repetitions of each example,
+                used to get multiple responses in special cases.
+ """ + + dataset = [] + with open(path, 'r', encoding='utf-8') as f: + for line in f: + dataset.extend( + [json.loads(line.strip()) for _ in range(num_repeats)]) + return Dataset.from_list(dataset) + + class TimeOutException(Exception): pass @@ -160,36 +192,75 @@ class redirect_stdin(contextlib._RedirectStream): # type: ignore @ICL_EVALUATORS.register_module() class MBPPEvaluator(BaseEvaluator): + """Evaluator for MBPP or MBPPPlus.""" + + def __init__(self, metric: str = 'MBPP') -> None: + self.metric = metric + assert self.metric in ['MBPP', 'MBPPPlus'] def score(self, predictions, references): assert len(predictions) == len(references) predictions = [self._process_answer(pred) for pred in predictions] - result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0} - details = {} - for index, (test_case, pred) in enumerate(zip(references, - predictions)): - programs = self._process_test(test_case, pred) - try: - # Add exec globals to prevent the exec to raise - # unnecessary NameError for correct answer - exec_globals = {} - with swallow_io(): - with time_limit(2): - exec(programs, exec_globals) - r = 'pass' - except TimeOutException: - r = 'timeout' - except AssertionError: - r = 'wrong_answer' - except BaseException: - r = 'failed' - result[r] += 1 - details[str(index)] = {'programs': programs, 'result': r} + if self.metric == 'MBPP': + result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0} + details = {} + for index, (test_case, + pred) in enumerate(zip(references, predictions)): + programs = self._process_test(test_case, pred) + try: + # Add exec globals to prevent the exec to raise + # unnecessary NameError for correct answer + exec_globals = {} + with swallow_io(): + with time_limit(2): + exec(programs, exec_globals) + r = 'pass' + except TimeOutException: + r = 'timeout' + except AssertionError: + r = 'wrong_answer' + except BaseException: + r = 'failed' + result[r] += 1 + details[str(index)] = {'programs': programs, 'result': r} - result['score'] = result['pass'] / len(predictions) * 100 - result['details'] = details - return result + result['score'] = result['pass'] / len(predictions) * 100 + result['details'] = details + return result + else: + try: + from evalplus.data import write_jsonl + from evalplus.evaluate import evaluate + self.write_jsonl = write_jsonl + self.eval = evaluate + except ImportError: + raise ImportError( + 'Please install evalplus use following steps:\n' + 'git clone --recurse-submodules git@github.com:open-compass/human-eval.git\n' # noqa + 'cd human-eval\n' + 'pip install -e .\n' + 'pip install -e evalplus\n') + mbpp_preds = [] + for preds, refer in zip(predictions, references): + if not isinstance(preds, list): + preds = [preds] + for pred in preds: + mbpp_preds.append({'task_id': refer, 'solution': pred}) + with tempfile.TemporaryDirectory() as tmp_dir: + out_dir = osp.join(tmp_dir, 'mbpp_eval.jsonl') + self.write_jsonl(out_dir, mbpp_preds) + flags = dict(dataset='mbpp', + samples=out_dir, + base_only=None, + parallel=None, + i_just_wanna_run=None, + test_details=0.2, + min_time_limit=0.2, + gt_time_limit_factor=4.0, + mini=None) + score = self.eval(flags) + return {f'mbpp_plus_{k}': score[k] * 100 for k in score} def _process_answer(self, text): text = text.strip()