From 2931f3dcb8213d269ba2b9436a6c42c51fc9bfd9 Mon Sep 17 00:00:00 2001
From: Tong Gao
Date: Thu, 10 Aug 2023 16:31:12 +0800
Subject: [PATCH] [Enhancement] Add humaneval postprocessor for GPT models &
 eval config for GPT4, enhance the original humaneval postprocessor (#129)

* [Enhancement] Enhance humaneval postprocessor

* add human-eval testcase

* update

* update

---------

Co-authored-by: Leymore
---
 configs/datasets/glm/humaneval.py |   4 +-
 configs/eval_gpt4.py              |  40 +++++++++++
 opencompass/datasets/humaneval.py |  44 ++++++++++--
 tests/dataset/test_humaneval.py   | 110 ++++++++++++++++++++++++++++++
 4 files changed, 191 insertions(+), 7 deletions(-)
 create mode 100644 configs/eval_gpt4.py
 create mode 100644 tests/dataset/test_humaneval.py

diff --git a/configs/datasets/glm/humaneval.py b/configs/datasets/glm/humaneval.py
index ecc99087..1b785f6b 100644
--- a/configs/datasets/glm/humaneval.py
+++ b/configs/datasets/glm/humaneval.py
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.datasets import HFDataset, HumanEvaluator
+from opencompass.datasets import HFDataset, HumanEvaluator, humaneval_postprocess
 
 humaneval_reader_cfg = dict(
     input_columns=['prompt'], output_column='task_id', train_split='test')
@@ -17,7 +17,7 @@ humaneval_infer_cfg = dict(
 humaneval_eval_cfg = dict(
     evaluator=dict(type=HumanEvaluator),
     k=[1, 10, 100],  # the parameter only for humaneval
-    pred_postprocessor=dict(type='humaneval'),
+    pred_postprocessor=dict(type=humaneval_postprocess),
 )
 
 humaneval_datasets = [
diff --git a/configs/eval_gpt4.py b/configs/eval_gpt4.py
new file mode 100644
index 00000000..9e97d847
--- /dev/null
+++ b/configs/eval_gpt4.py
@@ -0,0 +1,40 @@
+from mmengine.config import read_base
+from opencompass.models import OpenAI
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners import LocalRunner
+from opencompass.tasks import OpenICLInferTask
+
+with read_base():
+    from .datasets.collections.chat_medium import datasets
+    from .summarizers.medium import summarizer
+
+# GPT4 needs a special humaneval postprocessor
+from opencompass.datasets.humaneval import humaneval_gpt_postprocess
+for _dataset in datasets:
+    if _dataset['path'] == 'openai_humaneval':
+        _dataset['eval_cfg']['pred_postprocessor']['type'] = humaneval_gpt_postprocess
+
+
+api_meta_template = dict(
+    round=[
+        dict(role='HUMAN', api_role='HUMAN'),
+        dict(role='BOT', api_role='BOT', generate=True),
+    ],
+)
+
+models = [
+    dict(abbr='GPT4',
+         type=OpenAI, path='gpt-4-0613',
+         key='ENV',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
+         meta_template=api_meta_template,
+         query_per_second=1,
+         max_out_len=2048, max_seq_len=2048, batch_size=8),
+]
+
+infer = dict(
+    partitioner=dict(type=NaivePartitioner),
+    runner=dict(
+        type=LocalRunner,
+        max_num_workers=4,
+        task=dict(type=OpenICLInferTask)),
+)
diff --git a/opencompass/datasets/humaneval.py b/opencompass/datasets/humaneval.py
index 9b1ec10a..a58ce05b 100644
--- a/opencompass/datasets/humaneval.py
+++ b/opencompass/datasets/humaneval.py
@@ -1,12 +1,11 @@
 import os.path as osp
+import re
 import tempfile
 from typing import List
 
 from opencompass.openicl.icl_evaluator import BaseEvaluator
-from opencompass.registry import ICL_EVALUATORS, TEXT_POSTPROCESSORS
 
 
-@ICL_EVALUATORS.register_module()
 class HumanEvaluator(BaseEvaluator):
     """Evaluator for human eval."""
 
@@ -41,11 +40,46 @@ class HumanEvaluator(BaseEvaluator):
             return {f'humaneval_{k}': score[k] * 100 for k in score}
 
 
-@TEXT_POSTPROCESSORS.register_module('humaneval')
 def humaneval_postprocess(text: str) -> str:
-    text = text.split('\n\n')[0]
     if '```' in text:
-        text = text.split('```')[1]
+        blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
+        if len(blocks) == 0:
+            text = text.split('```')[1]  # fall back to default strategy
+        else:
+            text = blocks[0]  # fetch the first code block
+            if not text.startswith('\n'):  # in case starting with ```python
+                text = text[max(text.find('\n') + 1, 0):]
+    if text.strip().startswith('from') or text.strip().startswith('import'):
+        def_idx = text.find('def')
+        if def_idx != -1:
+            text = text[max(text.find('\n', def_idx) + 1, 0):]
+    text = text.split('\n\n')[0]
+    if text.strip().startswith('def'):
+        text = '\n'.join(text.split('\n')[1:])
+    if not text.startswith('    '):
+        if text.startswith(' '):
+            text = '    ' + text.lstrip()
+        else:
+            text = '\n'.join(['    ' + line for line in text.split('\n')])
+    return text
+
+
+def humaneval_gpt_postprocess(text: str) -> str:
+    """Better answer postprocessor for better instruction-aligned models like
+    GPT."""
+    if '```' in text:
+        blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
+        if len(blocks) == 0:
+            text = text.split('```')[1]  # fall back to default strategy
+        else:
+            text = blocks[0]  # fetch the first code block
+            if not text.startswith('\n'):  # in case starting with ```python
+                text = text[max(text.find('\n') + 1, 0):]
+    if text.strip().startswith('from') or text.strip().startswith('import'):
+        def_idx = text.find('def')
+        if def_idx != -1:
+            text = text[max(text.find('\n', def_idx) + 1, 0):]
+    text = text.split('\n\n\n')[0]
     if text.strip().startswith('def'):
         text = '\n'.join(text.split('\n')[1:])
     if not text.startswith('    '):
diff --git a/tests/dataset/test_humaneval.py b/tests/dataset/test_humaneval.py
new file mode 100644
index 00000000..0d1c23a1
--- /dev/null
+++ b/tests/dataset/test_humaneval.py
@@ -0,0 +1,110 @@
+import unittest
+
+from opencompass.datasets.humaneval import humaneval_postprocess
+
+
+def run_humaneval_check(completion):
+    program = [
+        'def get_fraction(x: float) -> float:',
+        humaneval_postprocess(completion),
+        '',
+        'assert get_fraction(1.28) == 0.28',
+        'assert get_fraction(1.0) == 0.0',
+    ]
+    program = '\n'.join(program)
+    exec(program)
+
+
+class TestHumaneval(unittest.TestCase):
+
+    def test_vanilla(self):
+        raw = '    return x - int(x)'
+        run_humaneval_check(raw)
+
+    def test_python_quote(self):
+        lines = [
+            '```python',
+            '    return x - int(x)',
+            '```',
+        ]
+        raw = '\n'.join(lines)
+        run_humaneval_check(raw)
+
+    def test_bare_quote(self):
+        lines = [
+            '```',
+            '    return x - int(x)',
+            '```',
+        ]
+        raw = '\n'.join(lines)
+        run_humaneval_check(raw)
+
+    def test_error_space_quote(self):
+        lines = [
+            '```',
+            '  return x - int(x)',
+            '```',
+        ]
+        raw = '\n'.join(lines)
+        run_humaneval_check(raw)
+
+    def test_import_1(self):
+        lines = [
+            'import numpy as np',
+            'import math',
+            'from typing import List',
+            '',
+            'def func(x):',
+            '    return x - int(x)',
+        ]
+        raw = '\n'.join(lines)
+        run_humaneval_check(raw)
+
+    def test_import_2(self):
+        lines = [
+            'from typing import List',
+            'import numpy as np',
+            'import math',
+            'def func(x):',
+            '    return x - int(x)',
+        ]
+        raw = '\n'.join(lines)
+        run_humaneval_check(raw)
+
+    def test_import_3(self):
+        lines = [
+            'import math',
+            '',
+            '',
+            'def func(x):',
+            '    return x - int(x)',
+        ]
+        raw = '\n'.join(lines)
+        run_humaneval_check(raw)
+
+    def test_comment(self):
+        lines = [
+            'def func(x: float) -> float:',
+            "    '''",
+            '    blah blah blah',
+            '    blah blah blah',
+            "    '''",
+            '    return x - int(x)',
+        ]
+        raw = '\n'.join(lines)
+        run_humaneval_check(raw)
+
+    def test_additional(self):
+        lines = [
+            '    return x - int(x)',
+            '',
+            '',
+            'def func(x: float) -> float:',
+            "    '''",
+            '    blah blah blah',
+            '    blah blah blah',
+            "    '''",
+            '    return x - int(x)',
+        ]
+        raw = '\n'.join(lines)
+        run_humaneval_check(raw)
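
For reference, a minimal usage sketch of the two postprocessors this patch adds (not part of the patch itself). The import path follows configs/eval_gpt4.py above, the expected behaviour mirrors the bundled unit tests, and the sample completion text is made up for illustration.

from opencompass.datasets.humaneval import (humaneval_gpt_postprocess,
                                             humaneval_postprocess)

# A GPT-style reply: prose plus a fenced code block that repeats the
# function signature from the HumanEval prompt (hypothetical sample).
completion = '\n'.join([
    'Here is the implementation:',
    '```python',
    'def get_fraction(x: float) -> float:',
    '    return x - int(x)',
    '```',
])

# Both postprocessors should strip the fence and the repeated `def` line,
# leaving only the indented body that gets appended to the prompt.
assert humaneval_postprocess(completion).strip() == 'return x - int(x)'
assert humaneval_gpt_postprocess(completion).strip() == 'return x - int(x)'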