mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
Support Mbpp_plus dataset (#770)
* support mbpp+ * support mbpp+ * minor fix * [Feat] minor fix --------- Co-authored-by: yingfhu <yingfhu@gmail.com>
This commit is contained in:
parent
3c606cb712
commit
30a90d8dd8
4
configs/datasets/mbpp_plus/mbpp_plus_gen.py
Normal file
4
configs/datasets/mbpp_plus/mbpp_plus_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from.mbpp_plus_gen_94815c import mbpp_plus_datasets # noqa: F401, F403
|
64
configs/datasets/mbpp_plus/mbpp_plus_gen_94815c.py
Normal file
64
configs/datasets/mbpp_plus/mbpp_plus_gen_94815c.py
Normal file
@ -0,0 +1,64 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import MBPPEvaluator, MBPPPlusDataset
|
||||
|
||||
mbpp_plus_reader_cfg = dict(
|
||||
input_columns=['text', 'test_list'], output_column='task_id')
|
||||
|
||||
mbpp_plus_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"You are an expert Python programmer, and here is your task: Write a function to find the shared elements from the given two lists. Your code should pass these tests:\n\n assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)\n assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4) \n assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14) \n"
|
||||
),
|
||||
dict(
|
||||
role="BOT",
|
||||
prompt=
|
||||
"[BEGIN]\n 'def similar_elements(test_tup1, test_tup2):\n return tuple(set(test_tup1) & set(test_tup2))' \n[DONE] \n\n "
|
||||
),
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"You are an expert Python programmer, and here is your task: Write a python function to identify non-prime numbers. Your code should pass these tests:\n\n assert is_not_prime(2) == False \n assert is_not_prime(10) == True \n assert is_not_prime(35) == True \n"
|
||||
),
|
||||
dict(
|
||||
role="BOT",
|
||||
prompt=
|
||||
"[BEGIN]\n 'import math\ndef is_not_prime(n):\n if n == 1:\n return True\n for i in range(2, int(math.sqrt(n))+1):\n if n % i == 0:\n return True\n return False' \n[DONE] \n\n "
|
||||
),
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"You are an expert Python programmer, and here is your task: Write a function to find the n largest integers from a given list of numbers, returned in descending order. Your code should pass these tests:\n\n assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] \n assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] \n assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35] \n"
|
||||
),
|
||||
dict(
|
||||
role="BOT",
|
||||
prompt=
|
||||
"[BEGIN]\n 'import heapq as hq\ndef heap_queue_largest(nums: list,n: int) -> list:\n largest_nums = hq.nlargest(n, nums)\n return largest_nums' \n[DONE] \n\n "
|
||||
),
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"You are an expert Python programmer, and here is your task: {text} Your code should pass these tests:\n\n {test_list} \n"
|
||||
),
|
||||
dict(role="BOT", prompt="[BEGIN]\n"),
|
||||
], )),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_out_len=512))
|
||||
|
||||
mbpp_plus_eval_cfg = dict(evaluator=dict(type=MBPPEvaluator, metric='MBPPPlus'), pred_role="BOT")
|
||||
|
||||
mbpp_plus_datasets = [
|
||||
dict(
|
||||
type=MBPPPlusDataset,
|
||||
abbr='mbpp_plus',
|
||||
path='./data/mbpp_plus/mbpp_plus.jsonl',
|
||||
reader_cfg=mbpp_plus_reader_cfg,
|
||||
infer_cfg=mbpp_plus_infer_cfg,
|
||||
eval_cfg=mbpp_plus_eval_cfg)
|
||||
]
|
@ -1,15 +1,18 @@
|
||||
import contextlib
|
||||
import io
|
||||
import itertools
|
||||
import json
|
||||
import multiprocessing
|
||||
import os.path as osp
|
||||
import re
|
||||
import signal
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import DatasetDict, concatenate_datasets, load_dataset
|
||||
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
|
||||
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
|
||||
@ -110,6 +113,35 @@ class SanitizedMBPPDataset(BaseDataset):
|
||||
return DatasetDict({'train': train, 'test': test})
|
||||
|
||||
|
||||
class MBPPPlusDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, num_repeats: int = 1):
|
||||
"""Load mbpp dataset for pass k mode. Note that you can use
|
||||
num_repeats.
|
||||
|
||||
> 1 when your model does not support `num_return_sequence` in
|
||||
generation, otherwise use the raw mbpp dataset and set
|
||||
`num_return_sequence` in model config to generate multiple responses
|
||||
for testing pass@k>1.
|
||||
|
||||
It better to change your dataset abbr correspondingly if you want to
|
||||
change num_repeats>1, otherwise the number in
|
||||
`.cache/dataset_size.json` might be inconsistent.
|
||||
|
||||
Args:
|
||||
num_repeats(int): Number of repetition for this dataset to get
|
||||
multiple responses in special cases.
|
||||
"""
|
||||
|
||||
dataset = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
dataset.extend(
|
||||
[json.loads(line.strip()) for _ in range(num_repeats)])
|
||||
return Dataset.from_list(dataset)
|
||||
|
||||
|
||||
class TimeOutException(Exception):
|
||||
pass
|
||||
|
||||
@ -160,15 +192,21 @@ class redirect_stdin(contextlib._RedirectStream): # type: ignore
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class MBPPEvaluator(BaseEvaluator):
|
||||
"""Evaluator for MBPP or MBPPPlus."""
|
||||
|
||||
def __init__(self, metric: str = 'MBPP') -> None:
|
||||
self.metric = metric
|
||||
assert self.metric in ['MBPP', 'MBPPPlus']
|
||||
|
||||
def score(self, predictions, references):
|
||||
assert len(predictions) == len(references)
|
||||
predictions = [self._process_answer(pred) for pred in predictions]
|
||||
|
||||
if self.metric == 'MBPP':
|
||||
result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0}
|
||||
details = {}
|
||||
for index, (test_case, pred) in enumerate(zip(references,
|
||||
predictions)):
|
||||
for index, (test_case,
|
||||
pred) in enumerate(zip(references, predictions)):
|
||||
programs = self._process_test(test_case, pred)
|
||||
try:
|
||||
# Add exec globals to prevent the exec to raise
|
||||
@ -190,6 +228,39 @@ class MBPPEvaluator(BaseEvaluator):
|
||||
result['score'] = result['pass'] / len(predictions) * 100
|
||||
result['details'] = details
|
||||
return result
|
||||
else:
|
||||
try:
|
||||
from evalplus.data import write_jsonl
|
||||
from evalplus.evaluate import evaluate
|
||||
self.write_jsonl = write_jsonl
|
||||
self.eval = evaluate
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please install evalplus use following steps:\n'
|
||||
'git clone --recurse-submodules git@github.com:open-compass/human-eval.git\n' # noqa
|
||||
'cd human-eval\n'
|
||||
'pip install -e .\n'
|
||||
'pip install -e evalplus\n')
|
||||
mbpp_preds = []
|
||||
for preds, refer in zip(predictions, references):
|
||||
if not isinstance(preds, list):
|
||||
preds = [preds]
|
||||
for pred in preds:
|
||||
mbpp_preds.append({'task_id': refer, 'solution': pred})
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
out_dir = osp.join(tmp_dir, 'mbpp_eval.jsonl')
|
||||
self.write_jsonl(out_dir, mbpp_preds)
|
||||
flags = dict(dataset='mbpp',
|
||||
samples=out_dir,
|
||||
base_only=None,
|
||||
parallel=None,
|
||||
i_just_wanna_run=None,
|
||||
test_details=0.2,
|
||||
min_time_limit=0.2,
|
||||
gt_time_limit_factor=4.0,
|
||||
mini=None)
|
||||
score = self.eval(flags)
|
||||
return {f'mbpp_plus_{k}': score[k] * 100 for k in score}
|
||||
|
||||
def _process_answer(self, text):
|
||||
text = text.strip()
|
||||
|
Loading…
Reference in New Issue
Block a user