[Feat] Add codegeex2 and Humanevalx (#210)

* add codegeex2

* add humanevalx dataset

* add evaluator

* update evaluator

* update configs

* update clean code

* update configs

* fix lint

* remove sleep

* fix lint

* update docs

* fix lint
Ezra-Yu 2023-08-17 11:03:16 +08:00 committed by GitHub
parent 0fe2366a72
commit 17ccaa5980
10 changed files with 448 additions and 9 deletions

View File: configs/datasets/humanevalx/humanevalx_gen.py

@@ -0,0 +1,4 @@
from mmengine.config import read_base
with read_base():
from .humanevalx_gen_fd5822 import humanevalx_datasets # noqa: F401, F403

View File: configs/datasets/humanevalx/humanevalx_gen_fd5822.py

@@ -0,0 +1,37 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HumanevalXDataset, HumanevalXEvaluator
humanevalx_reader_cfg = dict(
input_columns=['prompt'], output_column='task_id', train_split='test')
humanevalx_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template='{prompt}'),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=1024))
humanevalx_eval_cfg_dict = {
    lang: dict(
        evaluator=dict(
            type=HumanevalXEvaluator,
            language=lang,
            ip_address="localhost",  # replace with your code_eval_server IP address
            port=5000),  # refer to https://github.com/Ezra-Yu/code-evaluator to launch a server
        pred_role='BOT')
    for lang in ['python', 'cpp', 'go', 'java', 'js']  # rust is not supported yet
}
humanevalx_datasets = [
dict(
type=HumanevalXDataset,
abbr=f'humanevalx-{lang}',
language=lang,
path='./data/humanevalx',
reader_cfg=humanevalx_reader_cfg,
infer_cfg=humanevalx_infer_cfg,
eval_cfg=humanevalx_eval_cfg_dict[lang])
for lang in ['python', 'cpp', 'go', 'java', 'js']
]

View File: configs/eval_codegeex2.py

@@ -0,0 +1,7 @@
from mmengine.config import read_base
with read_base():
from .datasets.humanevalx.humanevalx_gen import humanevalx_datasets
from .models.hf_codegeex2_6b import models
datasets = humanevalx_datasets

View File: configs/models/hf_codegeex2_6b.py

@@ -0,0 +1,25 @@
from opencompass.models import HuggingFace
# refer to https://github.com/THUDM/CodeGeeX2/tree/main
# For pass@1:   n=20,  temperature=0.2, top_p=0.95
# For pass@10:  n=200, temperature=0.8, top_p=0.95
# For pass@100: n=200, temperature=0.8, top_p=0.95
models = [
dict(
type=HuggingFace,
abbr='codegeex2-6b',
path='THUDM/codegeex2-6b',
tokenizer_path='THUDM/codegeex2-6b',
tokenizer_kwargs=dict(
padding_side='left',
truncation_side='left',
trust_remote_code=True,
),
max_out_len=1024,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(trust_remote_code=True, device_map='auto'),
run_cfg=dict(num_gpus=1, num_procs=1),
)
]

View File: docs/en/advanced_guides/code_eval_service.md

@@ -0,0 +1,85 @@
# Code Evaluation Service
We support evaluating datasets in multiple programming languages, similar to [humaneval-x](https://huggingface.co/datasets/THUDM/humaneval-x). Before starting, make sure you have launched the code evaluation service; refer to the [code-evaluator](https://github.com/Ezra-Yu/code-evaluator) project to set it up.
## Launching the Code Evaluation Service
Make sure you have installed Docker, then build an image and run a container service.
Build the Docker image:
```shell
git clone https://github.com/Ezra-Yu/code-evaluator.git
cd code-evaluator/docker
sudo docker build -t code-eval:latest .
```
After obtaining the image, create a container with the following commands:
```shell
# Run with logs printed to the console
sudo docker run -it -p 5000:5000 code-eval:latest python server.py
# Run the program in the background
# sudo docker run -itd -p 5000:5000 code-eval:latest python server.py
# Use a different port
# sudo docker run -itd -p 5001:5001 code-eval:latest python server.py --port 5001
```
Ensure that you can access the service with the following commands (skip this step if the service is running on your local host):
```shell
ping your_service_ip_address
telnet your_service_ip_address your_service_port
```
```note
If your compute nodes cannot connect to the evaluation service, you can still run `python run.py xxx...` directly; the generated code will be saved in the 'outputs' folder. After migrating it to a machine that can reach the service, use [code-evaluator](https://github.com/Ezra-Yu/code-evaluator) directly to obtain the results (the `eval_cfg` configuration below can then be ignored).
```
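For reference, the request below mirrors what `HumanevalXEvaluator` issues internally through `curl`. It is a minimal sketch for submitting a migrated result file by hand; the file path and service address are placeholders, assuming the service runs on `localhost:5000`.

```python
import subprocess

# Sketch: post a saved prediction file to the evaluation service.
# The file path and address below are placeholders for your environment.
result = subprocess.run(
    [
        'curl', '-X', 'POST',
        '-F', 'file=@outputs/humanevalx_python.json',
        '-F', 'dataset=humanevalx/python',
        'localhost:5000/evaluate',
    ],
    timeout=180,
    capture_output=True)
print(result.stdout.decode())  # pass@k results returned by the service
```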
## Configuration File
We provide a [configuration file](https://github.com/InternLM/opencompass/blob/main/configs/eval_codegeex2.py) for evaluating humaneval-x with codegeex2.
The dataset and related post-processing configuration files can be found at this [link](https://github.com/InternLM/opencompass/tree/main/configs/datasets/humanevalx). Note the `evaluator` field in `humanevalx_eval_cfg_dict`.
```python
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HumanevalXDataset, HumanevalXEvaluator
humanevalx_reader_cfg = dict(
input_columns=['prompt'], output_column='task_id', train_split='test')
humanevalx_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template='{prompt}'),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=1024))
humanevalx_eval_cfg_dict = {
    lang: dict(
        evaluator=dict(
            type=HumanevalXEvaluator,
            language=lang,
            ip_address="localhost",  # replace with your code_eval_server IP address
            port=5000),  # refer to https://github.com/Ezra-Yu/code-evaluator to launch a server
        pred_role='BOT')
    for lang in ['python', 'cpp', 'go', 'java', 'js']  # rust is not supported yet
}
humanevalx_datasets = [
dict(
type=HumanevalXDataset,
abbr=f'humanevalx-{lang}',
language=lang,
path='./data/humanevalx',
reader_cfg=humanevalx_reader_cfg,
infer_cfg=humanevalx_infer_cfg,
eval_cfg=humanevalx_eval_cfg_dict[lang])
for lang in ['python', 'cpp', 'go', 'java', 'js']
]
```
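Since `humanevalx_eval_cfg_dict` holds plain dictionaries, pointing every evaluator at a remote code evaluation server only requires overriding two fields before `humanevalx_datasets` is built. A minimal sketch, assuming a hypothetical server at `10.0.0.2:5001`:

```python
# Sketch only: redirect all evaluators to a remote code-eval server.
# 10.0.0.2 and 5001 are placeholders for your own environment.
for lang_cfg in humanevalx_eval_cfg_dict.values():
    lang_cfg['evaluator']['ip_address'] = '10.0.0.2'
    lang_cfg['evaluator']['port'] = 5001
```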

View File: docs/en/index.rst

@@ -37,15 +37,6 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
user_guides/experimentation.md
user_guides/metrics.md
.. _AdvancedGuides:
.. toctree::
:maxdepth: 1
:caption: Advanced Guides
advanced_guides/new_dataset.md
advanced_guides/new_model.md
advanced_guides/evaluation_turbomind.md
.. _Prompt:
.. toctree::
:maxdepth: 1
@@ -56,6 +47,17 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
prompt/meta_template.md
prompt/chain_of_thought.md
.. _AdvancedGuides:
.. toctree::
:maxdepth: 1
:caption: Advanced Guides
advanced_guides/new_dataset.md
advanced_guides/new_model.md
advanced_guides/evaluation_turbomind.md
advanced_guides/code_eval_service.md
.. _Tools:
.. toctree::
:maxdepth: 1

View File: docs/zh_cn/advanced_guides/code_eval_service.md

@@ -0,0 +1,86 @@
# Code Evaluation Service
We support evaluating datasets in multiple programming languages, similar to [humaneval-x](https://huggingface.co/datasets/THUDM/humaneval-x). Before starting, make sure you have launched the code evaluation service; refer to the [code-evaluator](https://github.com/Ezra-Yu/code-evaluator) project.
## Launching the Code Evaluation Service
Make sure you have installed Docker, then build an image and run a container service.
Build the Docker image:
```shell
git clone https://github.com/Ezra-Yu/code-evaluator.git
cd code-evaluator/docker
sudo docker build -t code-eval:latest .
```
After obtaining the image, create a container with the following commands:
```shell
# Run with logs printed to the console
sudo docker run -it -p 5000:5000 code-eval:latest python server.py
# Run the program in the background
# sudo docker run -itd -p 5000:5000 code-eval:latest python server.py
# Use a different port
# sudo docker run -itd -p 5001:5001 code-eval:latest python server.py --port 5001
```
Make sure you can access the service with the following commands (skip this step if the service is running on your local host):
```shell
ping your_service_ip_address
telnet your_service_ip_address your_service_port
```
```note
If your compute nodes cannot connect to the evaluation service, you can still run `python run.py xxx...` directly; the generated code will be saved in the 'outputs' folder. After migrating it, use [code-evaluator](https://github.com/Ezra-Yu/code-evaluator) directly to obtain the results (the `eval_cfg` configuration below can then be ignored).
```
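Equivalently, the migrated result file can be posted from Python. A sketch under the assumption that the `requests` package is installed and that the server accepts the same multipart form as the `curl -F` call used by the evaluator:

```python
import requests  # assumed available; any HTTP client would work

# Sketch: submit a saved prediction file to the evaluation service.
# The address and file path are placeholders for your environment.
with open('outputs/humanevalx_python.json', 'rb') as f:
    resp = requests.post(
        'http://localhost:5000/evaluate',
        data={'dataset': 'humanevalx/python'},
        files={'file': f},
        timeout=180)
print(resp.text)  # pass@k results returned by the service
```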
## Configuration File
We provide a [configuration file](https://github.com/InternLM/opencompass/blob/main/configs/eval_codegeex2.py) for evaluating humaneval-x with codegeex2.
The dataset and related post-processing configuration files can be found at this [link](https://github.com/InternLM/opencompass/tree/main/configs/datasets/humanevalx). Note the `evaluator` field in `humanevalx_eval_cfg_dict`.
```python
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HumanevalXDataset, HumanevalXEvaluator
humanevalx_reader_cfg = dict(
input_columns=['prompt'], output_column='task_id', train_split='test')
humanevalx_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template='{prompt}'),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=1024))
humanevalx_eval_cfg_dict = {
    lang: dict(
        evaluator=dict(
            type=HumanevalXEvaluator,
            language=lang,
            ip_address="localhost",  # replace with your code_eval_server IP address
            port=5000),  # refer to https://github.com/Ezra-Yu/code-evaluator to launch a server
        pred_role='BOT')
    for lang in ['python', 'cpp', 'go', 'java', 'js']  # rust is not supported yet
}
humanevalx_datasets = [
dict(
type=HumanevalXDataset,
abbr=f'humanevalx-{lang}',
language=lang,
path='./data/humanevalx',
reader_cfg=humanevalx_reader_cfg,
infer_cfg=humanevalx_infer_cfg,
eval_cfg=humanevalx_eval_cfg_dict[lang])
for lang in ['python', 'cpp', 'go', 'java', 'js']
]
```
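Because the datasets are built from a plain language list, evaluating only a subset is a one-line change. A sketch assuming only Python and C++ are of interest, reusing the configs defined above:

```python
# Sketch: build datasets for a subset of languages only.
langs = ['python', 'cpp']
humanevalx_datasets = [
    dict(
        type=HumanevalXDataset,
        abbr=f'humanevalx-{lang}',
        language=lang,
        path='./data/humanevalx',
        reader_cfg=humanevalx_reader_cfg,
        infer_cfg=humanevalx_infer_cfg,
        eval_cfg=humanevalx_eval_cfg_dict[lang])
    for lang in langs
]
```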

View File: docs/zh_cn/index.rst

@@ -56,6 +56,7 @@ OpenCompass 上手路线
advanced_guides/new_dataset.md
advanced_guides/new_model.md
advanced_guides/evaluation_turbomind.md
advanced_guides/code_eval_service.md
.. _工具:
.. toctree::

View File: opencompass/datasets/__init__.py

@@ -31,6 +31,7 @@ from .gsm8k import * # noqa: F401, F403
from .hellaswag import * # noqa: F401, F403
from .huggingface import * # noqa: F401, F403
from .humaneval import * # noqa: F401, F403
from .humanevalx import * # noqa: F401, F403
from .iwslt2017 import * # noqa: F401, F403
from .jigsawmultilingual import * # noqa: F401, F403
from .lambada import * # noqa: F401, F403

View File: opencompass/datasets/humanevalx.py

@@ -0,0 +1,191 @@
import gzip
import json
import os
import os.path as osp
import re
import subprocess
import tempfile
from shutil import copyfile
from typing import Dict, Iterable
from datasets import Dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from .base import BaseDataset
_LANGUAGE_NAME_DICT = {
'cpp': 'CPP',
'go': 'Go',
'java': 'Java',
'js': 'JavaScript',
'python': 'Python',
'rust': 'Rust',
}
class HumanevalXDataset(BaseDataset):
@staticmethod
def load(path, language, **kwargs):
assert language in _LANGUAGE_NAME_DICT.keys(), (
f'language must be in {list(_LANGUAGE_NAME_DICT.keys())}')
file_path = osp.join(path, f'humanevalx_{language}.jsonl.gz')
dataset = HumanevalXDataset._stream_jsonl_all(file_path)
return Dataset.from_list(dataset)
@staticmethod
def _stream_jsonl_all(filename: str) -> Iterable[Dict]:
results = []
if filename.endswith('.gz'):
fp = gzip.open(open(filename, 'rb'), 'rt')
else:
fp = open(filename, 'r')
for line in fp:
if any(not x.isspace() for x in line):
results.append(json.loads(line))
fp.close()
return results
class HumanevalXEvaluator(BaseEvaluator):
"""Evaluator for humanevalx.
Before you use this Evaluator, launch a code eval service according
to to readme of https://github.com/Ezra-Yu/code-evaluator.
Set `ip_address` and `port` according your environment.
Args:
language (str): the program language to evaluate.
ip_address (str): The IP Address of HumanevalX code evaluate service.
refer to https://github.com/Ezra-Yu/code-evaluator to launch a
code evaluate service. Defaults to 'localhost'.
port (int): The port of HumanevalX code evaluate service.
Defaults to 5000.
timeout (int): Maximum wait time when accessing the service,
Defaults to 100.
TODO: support 'k' of pass@k. default to use k = [1, 10, 100]
"""
def __init__(self,
language,
ip_address='localhost',
port=5000,
timeout=180) -> None:
assert language in _LANGUAGE_NAME_DICT.keys(), (
f'language must be in {list(_LANGUAGE_NAME_DICT.keys())}')
if language == 'rust':
            timeout *= 10  # rust needs more time to compile
self.language = language
self.ip_address = ip_address
self.port = port
self.timeout = timeout
super().__init__()
def score(self, predictions, references):
predictions = [{
'task_id': f'{_LANGUAGE_NAME_DICT[self.language]}/{i}',
'generation': _clean_up_code(pred, self.language),
} for i, pred in enumerate(predictions)]
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_out_path = osp.join(tmp_dir,
f'humanevalx_{self.language}.json')
with open(tmp_out_path, 'w') as f:
for pred in predictions:
f.write(json.dumps(pred) + '\n')
succeed, output = self._code_eval_service(file_path=tmp_out_path)
if succeed:
if isinstance(output, str):
return json.loads(output)
elif isinstance(output, dict):
return output
ref_url = 'https://github.com/Ezra-Yu/code-evaluator'
result_file_path = os.path.join(
'outputs', f'humanevalx_{self.language}.json')
copyfile(tmp_out_path, result_file_path)
            raise Exception(
                'Failed to call the code evaluation service in '
                '`HumanevalXEvaluator`. The results have been saved to '
                f"'{result_file_path}'. Check that your code evaluation "
                'service is running and reachable over the network; you '
                'can also submit the saved results directly with `curl`, '
                f'see {ref_url}.\nError information: {output}')
def _code_eval_service(self, file_path):
exec_result = subprocess.run([
'curl', '-X', 'POST', '-F', f'file=@{file_path}', '-F',
f'dataset=humanevalx/{self.language}',
f'{self.ip_address}:{self.port}/evaluate'
],
timeout=self.timeout,
capture_output=True)
if exec_result.returncode == 0 and re.match(
"\"{.*:.*}\"", exec_result.stdout.decode('utf-8')):
return True, json.loads(exec_result.stdout.decode('utf-8'))
else:
if exec_result.stderr:
try:
err = exec_result.stderr.decode()
except Exception:
err = exec_result.stderr
else:
try:
err = exec_result.stdout.decode()
except Exception:
err = exec_result.stdout
return False, err
def _clean_up_code(text: str, language_type: str) -> str:
"""Cleans up the generated code."""
    if language_type.lower() == 'python':
        text_splits = text.split('\n')
        is_empty_line = False
        ind_empty_line = None
        # Truncate at the first non-empty line that is not indented, since
        # it marks the end of the generated function body.
        for i, line in enumerate(text_splits):
            if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t':
                is_empty_line = True
                ind_empty_line = i
                break
if is_empty_line:
text = '\n'.join(text_splits[:ind_empty_line])
else:
end_words = [
'\ndef', '\nclass', '\n#', '\nassert', '\n"""', '\nprint',
'\nif', '\n\n\n'
]
for w in end_words:
if w in text:
text = text[:text.rfind(w)]
elif language_type.lower() == 'java':
main_pos = text.find('public static void main')
if main_pos != -1:
text = text[:main_pos] + '}'
if '}' in text:
text = text[:text.rfind('}')] + '}'
if text.count('{') + 1 == text.count('}'):
text += '\n}'
elif language_type.lower() == 'go':
if '\nfunc main(' in text:
text = text[:text.rfind('func main(')]
if '}' in text:
text = text[:text.rfind('}')] + '}'
elif language_type.lower() == 'cpp':
if '\nint main()' in text:
text = text[:text.rfind('int main()')]
if '}' in text:
text = text[:text.rfind('}')] + '}'
elif language_type.lower() == 'js':
if '}' in text:
text = text[:text.rfind('}')] + '}'
elif language_type.lower() == 'rust':
if '}' in text:
text = text[:text.rfind('}')] + '}'
return text