mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Update] Internal humaneval add (#1641)
* [Update] internal_humaneval_add * update
This commit is contained in:
parent
84be90669b
commit
a61e8a0803
@ -0,0 +1,33 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import HumanevalDataset, HumanEvalEvaluator, humaneval_internal_v2_postprocess
|
||||
|
||||
humaneval_reader_cfg = dict(
|
||||
input_columns=['prompt'], output_column='task_id', train_split='test')
|
||||
|
||||
# TODO: allow empty output-column
|
||||
humaneval_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template='# Complete the following python code:\n{prompt}',
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_out_len=512))
|
||||
|
||||
humaneval_eval_cfg = dict(
|
||||
evaluator=dict(type=HumanEvalEvaluator),
|
||||
pred_role='BOT',
|
||||
k=[1, 10, 100], # the parameter only for humaneval
|
||||
pred_postprocessor=dict(type=humaneval_internal_v2_postprocess),
|
||||
)
|
||||
|
||||
humaneval_datasets = [
|
||||
dict(
|
||||
abbr='openai_humaneval',
|
||||
type=HumanevalDataset,
|
||||
path='opencompass/humaneval',
|
||||
reader_cfg=humaneval_reader_cfg,
|
||||
infer_cfg=humaneval_infer_cfg,
|
||||
eval_cfg=humaneval_eval_cfg)
|
||||
]
|
@ -0,0 +1,33 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import HumanevalDataset, HumanEvalEvaluator, humaneval_internal_v1_postprocess
|
||||
|
||||
humaneval_reader_cfg = dict(
|
||||
input_columns=['prompt'], output_column='task_id', train_split='test')
|
||||
|
||||
# TODO: allow empty output-column
|
||||
humaneval_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template='Complete the following python code:\n{prompt}',
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_out_len=512))
|
||||
|
||||
humaneval_eval_cfg = dict(
|
||||
evaluator=dict(type=HumanEvalEvaluator),
|
||||
pred_role='BOT',
|
||||
k=[1, 10, 100], # the parameter only for humaneval
|
||||
pred_postprocessor=dict(type=humaneval_internal_v1_postprocess),
|
||||
)
|
||||
|
||||
humaneval_datasets = [
|
||||
dict(
|
||||
abbr='openai_humaneval',
|
||||
type=HumanevalDataset,
|
||||
path='opencompass/humaneval',
|
||||
reader_cfg=humaneval_reader_cfg,
|
||||
infer_cfg=humaneval_infer_cfg,
|
||||
eval_cfg=humaneval_eval_cfg)
|
||||
]
|
@ -198,3 +198,57 @@ def humaneval_internal_v2_postprocess(text: str):
|
||||
break
|
||||
return_list.append(line)
|
||||
return '\n'.join(return_list)
|
||||
|
||||
def humaneval_internal_v1_postprocess(text: str) -> str:
|
||||
"""This is an advanced version of previous postprocess to handle more
|
||||
situations, better to use this one."""
|
||||
try:
|
||||
# for chatGLM related text
|
||||
eval_text = eval(text)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
if isinstance(eval_text, str):
|
||||
text = eval_text
|
||||
text = text.lstrip('\n')
|
||||
if '```' in text:
|
||||
blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
|
||||
if len(blocks) == 0:
|
||||
text = text.split('```')[1] # fall back to default strategy
|
||||
else:
|
||||
text = blocks[0] # fetch the first code block
|
||||
if not text.startswith('\n'): # in case starting with ```python
|
||||
text = text[max(text.find('\n') + 1, 0) :]
|
||||
if text.strip().startswith('from') or text.strip().startswith('import'):
|
||||
def_idx = text.find('def')
|
||||
if def_idx != -1:
|
||||
text = text[max(text.find('\n', def_idx) + 1, 0) :]
|
||||
# remove empty lines
|
||||
text = '\n'.join([line for line in text.split('\n') if line != ''])
|
||||
text = text.lstrip('\n')
|
||||
if text.strip().startswith('def'):
|
||||
text = '\n'.join(text.split('\n')[1:])
|
||||
# deal with the indentation error
|
||||
if text.startswith(' '):
|
||||
text = ' ' + text.lstrip()
|
||||
else:
|
||||
text = '\n'.join([' ' + line for line in text.split('\n')])
|
||||
text = text.split('\n')
|
||||
|
||||
# If number of leading space reduces, we assume that the code block ends.
|
||||
min_leading_space = None
|
||||
end_index = None
|
||||
for index, line in enumerate(text):
|
||||
if line.strip() == '' or line.strip()[0] in ["'", '"', '#']:
|
||||
continue
|
||||
current_leading_space = len(line.rstrip()) - len(line.strip())
|
||||
if min_leading_space is None:
|
||||
min_leading_space = current_leading_space
|
||||
elif current_leading_space < min_leading_space:
|
||||
end_index = index
|
||||
break
|
||||
if end_index is not None:
|
||||
text = '\n'.join(text[:end_index])
|
||||
else:
|
||||
text = '\n'.join(text)
|
||||
return text
|
||||
|
Loading…
Reference in New Issue
Block a user