OpenCompass/opencompass/datasets/livecodebench/extract_utils.py
Songyang Zhang fb43dd1906
[Update] Update Skywork/Qwen-QwQ (#1728)
* Update JuderBench

* Support O1-style Prompts

* Update Code

* Update OpenAI

* Update BigCodeBench

* Update BigCodeBench

* Update BigCodeBench

* Update BigCodeBench

* Update BigCodeBench

* Update
2024-12-05 19:30:43 +08:00

96 lines
3.0 KiB
Python

# Copyright LiveCodeBench @ 2024,
import re
def extract_code_generation(model_output: str, model_type: str = 'chat'):
    """Extract the first fenced code block from a model response.

    Args:
        model_output: Raw text produced by the model.
        model_type: ``'base'`` returns the whole output, stripped;
            ``'chat'`` looks for Markdown-style code fences.

    Returns:
        For ``'base'``, the stripped output. For ``'chat'``, the lines
        between the first two lines containing a triple-backtick fence,
        joined by newlines, or ``''`` when fewer than two such lines exist.

    Raises:
        ValueError: If ``model_type`` is neither ``'base'`` nor ``'chat'``.
    """
    # TODO: handle codellama
    if model_type == 'base':
        return model_output.strip()
    if model_type != 'chat':
        raise ValueError(f'Invalid model type: {model_type}')

    # Only the chat path needs the line-by-line view of the output.
    outputlines = model_output.split('\n')
    indexlines = [i for i, line in enumerate(outputlines) if '```' in line]
    if len(indexlines) < 2:
        return ''
    return '\n'.join(outputlines[indexlines[0] + 1:indexlines[1]])
def extract_code_generation_v2(model_output: str, model_type: str = 'chat'):
    """Extract the last complete fenced code block from a model response.

    Args:
        model_output: Raw text produced by the model.
        model_type: ``'base'`` returns the whole output, stripped;
            ``'chat'`` looks for Markdown-style code fences.

    Returns:
        For ``'base'``, the stripped output. For ``'chat'``, the lines inside
        the last complete pair of fence lines, joined by newlines, or ``''``
        when fewer than two fence lines exist.

    Raises:
        ValueError: If ``model_type`` is neither ``'base'`` nor ``'chat'``.
    """
    # TODO: handle codellama
    if model_type == 'base':
        return model_output.strip()
    if model_type != 'chat':
        raise ValueError(f'Invalid model type: {model_type}')

    outputlines = model_output.split('\n')
    indexlines = [i for i, line in enumerate(outputlines) if '```' in line]
    if len(indexlines) < 2:
        return ''

    # Fence lines pair up in order: open, close, open, close, ...
    # With an odd count the final fence has no partner (e.g. the response
    # was cut off mid-block); drop it so the pair we pick really delimits
    # code rather than the prose sitting between two blocks.
    if len(indexlines) % 2:
        indexlines = indexlines[:-1]

    # Only keep the last code block.
    start, end = indexlines[-2:]
    return '\n'.join(outputlines[start + 1:end])
def extract_code_execution(model_output: str, cot: bool = False):
    """Extract the predicted value from a code-execution style response.

    Processing steps, in order:

    1. If the text contains ``[PYTHON]...[/PYTHON]`` spans, keep only the
       content of the last one.
    2. If ``cot`` is true and ``[ANSWER]`` is present, keep what follows the
       first ``[ANSWER]`` marker.
    3. If ``==`` is present, keep what follows the first ``==``.
    4. If ``[/ANSWER]`` is present, keep what precedes it; otherwise keep
       only the first line.

    Args:
        model_output: Raw text produced by the model.
        cot: Whether the response may wrap its result in ``[ANSWER]`` tags
            after free-form reasoning.

    Returns:
        The extracted text with surrounding whitespace removed.
    """
    pattern = r'\[PYTHON\](.*?)\[\/PYTHON\]'
    matches = re.findall(pattern, model_output, re.DOTALL)
    if matches:
        # fetch the last one
        model_output = matches[-1]

    if cot and '[ANSWER]' in model_output:
        model_output = model_output.split('[ANSWER]')[1].strip()
    if '==' in model_output:
        # Split on the first '==' only, so a right-hand side that itself
        # contains '==' is not truncated.
        model_output = model_output.split('==', 1)[1].strip()
    if '[/ANSWER]' in model_output:
        model_output = model_output.split('[/ANSWER]')[0].strip()
    else:
        model_output = model_output.split('\n')[0].strip()
    return model_output.strip()
def extract_test_output_code(model_output: str):
    """Extract a test assertion or fenced snippet from a model response.

    The last line that starts with ``assert`` is returned when one exists.
    Otherwise the content of a fenced block is returned: the block opened by
    the first python-tagged fence if there is one, else the block between the
    first two fence lines.

    Args:
        model_output: Raw text produced by the model.

    Returns:
        The selected ``assert`` line, or the lines inside the selected fenced
        block joined by newlines, or ``''`` when no closed block is found.
    """
    lines = model_output.split('\n')

    # Scan from the bottom so the last ``assert`` line wins.
    for text in reversed(lines):
        if text.startswith('assert'):
            return text

    # TODO: handle codellama format

    fence_positions = [pos for pos, text in enumerate(lines) if '```' in text]

    # A python-tagged fence, when present, is the preferred opening fence.
    python_fence = next(
        (pos for pos, text in enumerate(lines)
         if '```python' in text or '```Python' in text),
        None,
    )
    if python_fence is not None:
        later_fences = [pos for pos in fence_positions if pos > python_fence]
        fence_positions = [python_fence] + later_fences

    if len(fence_positions) < 2:
        return ''
    opening, closing = fence_positions[0], fence_positions[1]
    return '\n'.join(lines[opening + 1:closing])