mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Enhancement] Add humaneval postprocessor for GPT models & eval config for GPT4, enhance the original humaneval postprocessor (#129)
* [Enhancement] Enhance humaneval postprocessor * add human-eval testcase * update * update --------- Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent
3f36db3b06
commit
2931f3dcb8
@ -1,7 +1,7 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import HFDataset, HumanEvaluator
|
||||
from opencompass.datasets import HFDataset, HumanEvaluator, humaneval_postprocess
|
||||
|
||||
humaneval_reader_cfg = dict(
|
||||
input_columns=['prompt'], output_column='task_id', train_split='test')
|
||||
@ -17,7 +17,7 @@ humaneval_infer_cfg = dict(
|
||||
humaneval_eval_cfg = dict(
|
||||
evaluator=dict(type=HumanEvaluator),
|
||||
k=[1, 10, 100], # the parameter only for humaneval
|
||||
pred_postprocessor=dict(type='humaneval'),
|
||||
pred_postprocessor=dict(type=humaneval_postprocess),
|
||||
)
|
||||
|
||||
humaneval_datasets = [
|
||||
|
40
configs/eval_gpt4.py
Normal file
40
configs/eval_gpt4.py
Normal file
@ -0,0 +1,40 @@
|
||||
from mmengine.config import read_base
|
||||
from opencompass.models import OpenAI
|
||||
from opencompass.partitioners import NaivePartitioner
|
||||
from opencompass.runners import LocalRunner
|
||||
from opencompass.tasks import OpenICLInferTask
|
||||
|
||||
with read_base():
|
||||
from .datasets.collections.chat_medium import datasets
|
||||
from .summarizers.medium import summarizer
|
||||
|
||||
# GPT4 needs a special humaneval postprocessor
|
||||
from opencompass.datasets.humaneval import humaneval_gpt_postprocess
|
||||
for _dataset in datasets:
|
||||
if _dataset['path'] == 'openai_humaneval':
|
||||
_dataset['eval_cfg']['pred_postprocessor']['type'] = humaneval_gpt_postprocess
|
||||
|
||||
|
||||
api_meta_template = dict(
|
||||
round=[
|
||||
dict(role='HUMAN', api_role='HUMAN'),
|
||||
dict(role='BOT', api_role='BOT', generate=True),
|
||||
],
|
||||
)
|
||||
|
||||
models = [
|
||||
dict(abbr='GPT4',
|
||||
type=OpenAI, path='gpt-4-0613',
|
||||
key='ENV', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
|
||||
meta_template=api_meta_template,
|
||||
query_per_second=1,
|
||||
max_out_len=2048, max_seq_len=2048, batch_size=8),
|
||||
]
|
||||
|
||||
infer = dict(
|
||||
partitioner=dict(type=NaivePartitioner),
|
||||
runner=dict(
|
||||
type=LocalRunner,
|
||||
max_num_workers=4,
|
||||
task=dict(type=OpenICLInferTask)),
|
||||
)
|
@ -1,12 +1,11 @@
|
||||
import os.path as osp
|
||||
import re
|
||||
import tempfile
|
||||
from typing import List
|
||||
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
from opencompass.registry import ICL_EVALUATORS, TEXT_POSTPROCESSORS
|
||||
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class HumanEvaluator(BaseEvaluator):
|
||||
"""Evaluator for human eval."""
|
||||
|
||||
@ -41,11 +40,46 @@ class HumanEvaluator(BaseEvaluator):
|
||||
return {f'humaneval_{k}': score[k] * 100 for k in score}
|
||||
|
||||
|
||||
@TEXT_POSTPROCESSORS.register_module('humaneval')
|
||||
def humaneval_postprocess(text: str) -> str:
|
||||
text = text.split('\n\n')[0]
|
||||
if '```' in text:
|
||||
text = text.split('```')[1]
|
||||
blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
|
||||
if len(blocks) == 0:
|
||||
text = text.split('```')[1] # fall back to default strategy
|
||||
else:
|
||||
text = blocks[0] # fetch the first code block
|
||||
if not text.startswith('\n'): # in case starting with ```python
|
||||
text = text[max(text.find('\n') + 1, 0):]
|
||||
if text.strip().startswith('from') or text.strip().startswith('import'):
|
||||
def_idx = text.find('def')
|
||||
if def_idx != -1:
|
||||
text = text[max(text.find('\n', def_idx) + 1, 0):]
|
||||
text = text.split('\n\n')[0]
|
||||
if text.strip().startswith('def'):
|
||||
text = '\n'.join(text.split('\n')[1:])
|
||||
if not text.startswith(' '):
|
||||
if text.startswith(' '):
|
||||
text = ' ' + text.lstrip()
|
||||
else:
|
||||
text = '\n'.join([' ' + line for line in text.split('\n')])
|
||||
return text
|
||||
|
||||
|
||||
def humaneval_gpt_postprocess(text: str) -> str:
|
||||
"""Better answer postprocessor for better instruction-aligned models like
|
||||
GPT."""
|
||||
if '```' in text:
|
||||
blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
|
||||
if len(blocks) == 0:
|
||||
text = text.split('```')[1] # fall back to default strategy
|
||||
else:
|
||||
text = blocks[0] # fetch the first code block
|
||||
if not text.startswith('\n'): # in case starting with ```python
|
||||
text = text[max(text.find('\n') + 1, 0):]
|
||||
if text.strip().startswith('from') or text.strip().startswith('import'):
|
||||
def_idx = text.find('def')
|
||||
if def_idx != -1:
|
||||
text = text[max(text.find('\n', def_idx) + 1, 0):]
|
||||
text = text.split('\n\n\n')[0]
|
||||
if text.strip().startswith('def'):
|
||||
text = '\n'.join(text.split('\n')[1:])
|
||||
if not text.startswith(' '):
|
||||
|
110
tests/dataset/test_humaneval.py
Normal file
110
tests/dataset/test_humaneval.py
Normal file
@ -0,0 +1,110 @@
|
||||
import unittest
|
||||
|
||||
from opencompass.datasets.humaneval import humaneval_postprocess
|
||||
|
||||
|
||||
def run_humaneval_check(completion):
|
||||
program = [
|
||||
'def get_fraction(x: float) -> float:',
|
||||
humaneval_postprocess(completion),
|
||||
'',
|
||||
'assert get_fraction(1.28) == 0.28',
|
||||
'assert get_fraction(1.0) == 0.0',
|
||||
]
|
||||
program = '\n'.join(program)
|
||||
exec(program)
|
||||
|
||||
|
||||
class TestHumaneval(unittest.TestCase):
|
||||
|
||||
def test_vanilla(self):
|
||||
raw = ' return x - int(x)'
|
||||
run_humaneval_check(raw)
|
||||
|
||||
def test_python_quote(self):
|
||||
lines = [
|
||||
'```python',
|
||||
' return x - int(x)',
|
||||
'```',
|
||||
]
|
||||
raw = '\n'.join(lines)
|
||||
run_humaneval_check(raw)
|
||||
|
||||
def test_bare_quote(self):
|
||||
lines = [
|
||||
'```',
|
||||
' return x - int(x)',
|
||||
'```',
|
||||
]
|
||||
raw = '\n'.join(lines)
|
||||
run_humaneval_check(raw)
|
||||
|
||||
def test_error_space_quote(self):
|
||||
lines = [
|
||||
'```',
|
||||
' return x - int(x)',
|
||||
'```',
|
||||
]
|
||||
raw = '\n'.join(lines)
|
||||
run_humaneval_check(raw)
|
||||
|
||||
def test_import_1(self):
|
||||
lines = [
|
||||
'import numpy as np',
|
||||
'import math',
|
||||
'from typing import List',
|
||||
'',
|
||||
'def func(x):',
|
||||
' return x - int(x)',
|
||||
]
|
||||
raw = '\n'.join(lines)
|
||||
run_humaneval_check(raw)
|
||||
|
||||
def test_import_2(self):
|
||||
lines = [
|
||||
'from typing import List',
|
||||
'import numpy as np',
|
||||
'import math',
|
||||
'def func(x):',
|
||||
' return x - int(x)',
|
||||
]
|
||||
raw = '\n'.join(lines)
|
||||
run_humaneval_check(raw)
|
||||
|
||||
def test_import_3(self):
|
||||
lines = [
|
||||
'import math',
|
||||
'',
|
||||
'',
|
||||
'def func(x):',
|
||||
' return x - int(x)',
|
||||
]
|
||||
raw = '\n'.join(lines)
|
||||
run_humaneval_check(raw)
|
||||
|
||||
def test_comment(self):
|
||||
lines = [
|
||||
'def func(x: float) -> float:',
|
||||
" '''",
|
||||
' blah blah blah',
|
||||
' blah blah blah',
|
||||
" '''",
|
||||
' return x - int(x)',
|
||||
]
|
||||
raw = '\n'.join(lines)
|
||||
run_humaneval_check(raw)
|
||||
|
||||
def test_additional(self):
|
||||
lines = [
|
||||
' return x - int(x)',
|
||||
'',
|
||||
'',
|
||||
'def func(x: float) -> float:',
|
||||
" '''",
|
||||
' blah blah blah',
|
||||
' blah blah blah',
|
||||
" '''",
|
||||
' return x - int(x)',
|
||||
]
|
||||
raw = '\n'.join(lines)
|
||||
run_humaneval_check(raw)
|
Loading…
Reference in New Issue
Block a user