[Feat] support humaneval and mbpp pass@k (#598)

* [Feat] support pass@k

* [Feat] support pass@k

* [Feat] support pass@k

* [Feat] support pass@k

* [Feat] support pass@k

* [Feat] support pass@k docs

* update naming

---------

Co-authored-by: Leymore <zfz-960727@163.com>
Hubert 2023-11-16 21:22:06 +08:00 committed by GitHub
parent c0acd06b05
commit 91fba2c2e9
13 changed files with 622 additions and 65 deletions

View File

@@ -0,0 +1,54 @@
# This config is used for pass@k evaluation with `num_return_sequences`,
# i.e. for models that can generate multiple responses for a single input.
from mmengine.config import read_base
from opencompass.partitioners import SizePartitioner
from opencompass.models import HuggingFaceCausalLM
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.datasets import MBPPDataset_V2, MBPPPassKEvaluator
with read_base():
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
from .datasets.mbpp.mbpp_gen_1e1056 import mbpp_datasets
mbpp_datasets[0]['type'] = MBPPDataset_V2
mbpp_datasets[0]['eval_cfg']['evaluator']['type'] = MBPPPassKEvaluator
mbpp_datasets[0]['reader_cfg']['output_column'] = 'test_column'
datasets = []
datasets += humaneval_datasets
datasets += mbpp_datasets
models = [
dict(
type=HuggingFaceCausalLM,
abbr='CodeLlama-7b-Python',
path="codellama/CodeLlama-7b-Python-hf",
tokenizer_path='codellama/CodeLlama-7b-Python-hf',
tokenizer_kwargs=dict(
padding_side='left',
truncation_side='left',
trust_remote_code=True,
),
max_out_len=1024,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(trust_remote_code=True, device_map='auto'),
generation_kwargs=dict(
num_return_sequences=10,
do_sample=True,
top_p=0.95,
temperature=0.8,
),
run_cfg=dict(num_gpus=1, num_procs=1),
),
]
infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=300),
runner=dict(
type=LocalRunner, max_num_workers=16,
task=dict(type=OpenICLInferTask)),
)
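
As a side note (not part of the config file above), the sketch below illustrates what `num_return_sequences` does at the Hugging Face level: `generate()` returns that many sampled sequences per input prompt, which is what the evaluator later regroups per example. The checkpoint and prompt are illustrative only.

```python
# Minimal sketch (assumption: transformers + accelerate are installed).
from transformers import AutoModelForCausalLM, AutoTokenizer

name = 'codellama/CodeLlama-7b-Python-hf'
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, device_map='auto')

inputs = tok('def fibonacci(n):', return_tensors='pt').to(model.device)
out = model.generate(**inputs, do_sample=True, top_p=0.95, temperature=0.8,
                     num_return_sequences=10, max_new_tokens=64)
print(out.shape[0])  # 10 candidate completions for a single prompt
```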

View File

@@ -0,0 +1,65 @@
# This config is used for pass@k evaluation with dataset repetition,
# i.e. for models that cannot generate multiple responses for a single input.
from mmengine.config import read_base
from opencompass.partitioners import SizePartitioner
from opencompass.models import HuggingFaceCausalLM
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.datasets import MBPPDataset_V2, MBPPPassKEvaluator
with read_base():
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
from .datasets.mbpp.mbpp_gen_1e1056 import mbpp_datasets
humaneval_datasets[0]['abbr'] = 'openai_humaneval_pass10'
humaneval_datasets[0]['num_repeats'] = 10
mbpp_datasets[0]['abbr'] = 'mbpp_pass10'
mbpp_datasets[0]['num_repeats'] = 10
mbpp_datasets[0]['type'] = MBPPDataset_V2
mbpp_datasets[0]['eval_cfg']['evaluator']['type'] = MBPPPassKEvaluator
mbpp_datasets[0]['reader_cfg']['output_column'] = 'test_column'
datasets = []
datasets += humaneval_datasets
datasets += mbpp_datasets
_meta_template = dict(
round=[
dict(role="HUMAN", begin="<|User|>:", end="\n"),
dict(role="BOT", begin="<|Bot|>:", end="<eoa>\n", generate=True),
],
)
models = [
dict(
abbr="internlm-chat-7b-hf-v11",
type=HuggingFaceCausalLM,
path="internlm/internlm-chat-7b-v1_1",
tokenizer_path="internlm/internlm-chat-7b-v1_1",
tokenizer_kwargs=dict(
padding_side="left",
truncation_side="left",
use_fast=False,
trust_remote_code=True,
),
max_seq_len=2048,
meta_template=_meta_template,
model_kwargs=dict(trust_remote_code=True, device_map="auto"),
generation_kwargs=dict(
do_sample=True,
top_p=0.95,
temperature=0.8,
),
run_cfg=dict(num_gpus=1, num_procs=1),
batch_size=8,
)
]
infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=600),
runner=dict(
type=LocalRunner, max_num_workers=16,
task=dict(type=OpenICLInferTask)),
)

View File

@@ -0,0 +1,104 @@
# Code Evaluation Tutorial
This tutorial primarily focuses on evaluating a model's coding proficiency, using `humaneval` and `mbpp` as examples.
## pass@1
If you only need to generate a single response to evaluate the pass@1 performance, you can directly use [configs/datasets/humaneval/humaneval_gen_8e312c.py](https://github.com/open-compass/opencompass/blob/main/configs/datasets/humaneval/humaneval_gen_8e312c.py) and [configs/datasets/mbpp/mbpp_gen_1e1056.py](https://github.com/open-compass/opencompass/blob/main/configs/datasets/mbpp/mbpp_gen_1e1056.py), referring to the general [quick start tutorial](../get_started/quick_start.md).
For multilingual evaluation, please refer to the [Multilingual Code Evaluation Tutorial](./code_eval_service.md).
## pass@k
If you need to generate multiple responses for a single example to evaluate the pass@k performance, consider the following two situations. Here we take 10 responses as an example:
### Typical Situation
Most models support the `num_return_sequences` parameter in HF generation, which lets us obtain multiple responses directly. Refer to the following configuration:
```python
from opencompass.datasets import MBPPDataset_V2, MBPPPassKEvaluator
with read_base():
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
from .datasets.mbpp.mbpp_gen_1e1056 import mbpp_datasets
mbpp_datasets[0]['type'] = MBPPDataset_V2
mbpp_datasets[0]['eval_cfg']['evaluator']['type'] = MBPPPassKEvaluator
mbpp_datasets[0]['reader_cfg']['output_column'] = 'test_column'
datasets = []
datasets += humaneval_datasets
datasets += mbpp_datasets
models = [
dict(
type=HuggingFaceCausalLM,
...,
generation_kwargs=dict(
num_return_sequences=10,
do_sample=True,
top_p=0.95,
temperature=0.8,
),
...,
)
]
```
For `mbpp`, the dataset and evaluation need additional changes, so we modify the `type`, `eval_cfg.evaluator.type`, and `reader_cfg.output_column` fields accordingly.
We also need randomness in the model responses, so `generation_kwargs` must be set. Note that `num_return_sequences` controls the number of responses generated per example.
Note: `num_return_sequences` must be greater than or equal to k, since pass@k is a probability estimate.
For a complete example, refer to [configs/eval_code_passk.py](https://github.com/open-compass/opencompass/blob/main/configs/eval_code_passk.py).
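To make the estimate concrete, here is a small, self-contained sketch of the unbiased pass@k estimator that the evaluator introduced in this commit uses, 1 - C(n-c, k)/C(n, k) for n samples with c correct ones; the counts below are made up for illustration.

```python
import itertools
from typing import List, Union

import numpy as np


def estimate_pass_at_k(num_samples: Union[int, List[int], np.ndarray],
                       num_correct: Union[List[int], np.ndarray],
                       k: int) -> np.ndarray:
    """Unbiased per-problem pass@k estimate: 1 - C(n - c, k) / C(n, k)."""

    def estimator(n: int, c: int, k: int) -> float:
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    if isinstance(num_samples, int):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct)
        num_samples_it = iter(num_samples)
    return np.array([
        estimator(int(n), int(c), k)
        for n, c in zip(num_samples_it, num_correct)
    ])


# 3 problems, 10 samples each, with 5 / 1 / 0 correct samples respectively.
print(estimate_pass_at_k(10, [5, 1, 0], 1).mean())   # ~0.2
print(estimate_pass_at_k(10, [5, 1, 0], 10).mean())  # ~0.667
```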
### For Models That Do Not Support Multiple Responses
This applies to some HF models whose APIs are poorly designed or lack this feature. In this case, we repeat the dataset so that each example is inferred multiple times. Refer to the following configuration:
```python
from opencompass.datasets import MBPPDataset_V2, MBPPPassKEvaluator
with read_base():
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
from .datasets.mbpp.mbpp_gen_1e1056 import mbpp_datasets
humaneval_datasets[0]['abbr'] = 'openai_humaneval_pass10'
humaneval_datasets[0]['num_repeats'] = 10
mbpp_datasets[0]['abbr'] = 'mbpp_pass10'
mbpp_datasets[0]['num_repeats'] = 10
mbpp_datasets[0]['type'] = MBPPDataset_V2
mbpp_datasets[0]['eval_cfg']['evaluator']['type'] = MBPPPassKEvaluator
mbpp_datasets[0]['reader_cfg']['output_column'] = 'test_column'
datasets = []
datasets += humaneval_datasets
datasets += mbpp_datasets
models = [
dict(
type=HuggingFaceCausalLM,
...,
generation_kwargs=dict(
do_sample=True,
top_p=0.95,
temperature=0.8,
),
...,
)
]
```
Since the prompts themselves are unchanged, we only need to override the corresponding fields to repeat the dataset.
You need to modify these fields:
- `num_repeats`: the number of times the dataset is repeated
- `abbr`: it is best to change the dataset abbreviation along with the number of repetitions, because the dataset size changes and a stale entry in `.cache/dataset_size.json` could otherwise cause problems.
For `mbpp`, also modify the `type`, `eval_cfg.evaluator.type`, and `reader_cfg.output_column` fields.
We also need randomness in the model responses, so `generation_kwargs` must be set.
For a complete example, refer to [configs/eval_code_passk_repeat_dataset.py](https://github.com/open-compass/opencompass/blob/main/configs/eval_code_passk_repeat_dataset.py).
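For intuition on why the cached size matters, the sketch below loads the dataset with `num_repeats=10` and checks its length; it assumes `HumanevalDataset` is exported from `opencompass.datasets` like the other classes above, and the data path is an assumption that may differ in your setup.

```python
# Hypothetical check (path is an assumption): with num_repeats=10 the
# 164 HumanEval problems become 1640 rows, which no longer matches a
# previously cached entry in .cache/dataset_size.json unless the abbr changes.
from opencompass.datasets import HumanevalDataset

ds = HumanevalDataset.load(path='data/humaneval/human-eval-v2-20210705.jsonl',
                           num_repeats=10)
print(len(ds))  # expected: 1640
```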

View File

@@ -1,4 +1,4 @@
# Code Evaluation Tutorial
# Multilingual Code Evaluation Tutorial
To evaluate an LLM's code capability, we need to set up an independent evaluation environment to avoid executing faulty code in the development environment, which could cause irreparable damage. The Code Evaluation Service currently used in OpenCompass is based on the [code-evaluator](https://github.com/open-compass/code-evaluator.git) project, which already supports evaluating datasets in multiple programming languages such as [humaneval-x](https://huggingface.co/datasets/THUDM/humaneval-x). The following tutorials introduce how to run the code evaluation service under different requirements.

View File

@@ -0,0 +1,106 @@
# Code Evaluation Tutorial
This tutorial takes `humaneval` and `mbpp` as examples to introduce how to evaluate a model's coding ability.
## pass@1
If you only need to generate a single response to evaluate the pass@1 performance, you can directly use [configs/datasets/humaneval/humaneval_gen_8e312c.py](https://github.com/open-compass/opencompass/blob/main/configs/datasets/humaneval/humaneval_gen_8e312c.py) and [configs/datasets/mbpp/mbpp_gen_1e1056.py](https://github.com/open-compass/opencompass/blob/main/configs/datasets/mbpp/mbpp_gen_1e1056.py), referring to the general [quick start tutorial](../get_started/quick_start.md).
For multilingual evaluation, please refer to the [Multilingual Code Evaluation Tutorial](./code_eval_service.md).
## pass@k
If you need to generate multiple responses for a single example to evaluate the pass@k performance, refer to the following two situations. Here we take 10 responses as an example:
### Typical Situation
Most models support the `num_return_sequences` parameter in HF generation, which lets us obtain multiple responses directly. Refer to the following configuration:
```python
from opencompass.datasets import MBPPDataset_V2, MBPPPassKEvaluator
with read_base():
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
from .datasets.mbpp.mbpp_gen_1e1056 import mbpp_datasets
mbpp_datasets[0]['type'] = MBPPDataset_V2
mbpp_datasets[0]['eval_cfg']['evaluator']['type'] = MBPPPassKEvaluator
mbpp_datasets[0]['reader_cfg']['output_column'] = 'test_column'
datasets = []
datasets += humaneval_datasets
datasets += mbpp_datasets
models = [
dict(
type=HuggingFaceCausalLM,
...,
generation_kwargs=dict(
num_return_sequences=10,
do_sample=True,
top_p=0.95,
temperature=0.8,
),
...,
)
]
```
For `mbpp`, the dataset and evaluation need additional changes, so we modify the `type`, `eval_cfg.evaluator.type`, and `reader_cfg.output_column` fields accordingly.
We also need randomness in the model responses, so `generation_kwargs` must be set. Note that `num_return_sequences` controls the number of responses generated per example.
Note: `num_return_sequences` must be greater than or equal to k, since pass@k is a probability estimate.
For a complete example, refer to
[configs/eval_code_passk.py](https://github.com/open-compass/opencompass/blob/main/configs/eval_code_passk.py)
### For Models That Do Not Support Multiple Responses
This applies to some HF models whose APIs are poorly designed or lack this feature. In this case, we repeat the dataset so that each example is inferred multiple times. Refer to the following configuration:
```python
from opencompass.datasets import MBPPDataset_V2, MBPPPassKEvaluator
with read_base():
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
from .datasets.mbpp.mbpp_gen_1e1056 import mbpp_datasets
humaneval_datasets[0]['abbr'] = 'openai_humaneval_pass10'
humaneval_datasets[0]['num_repeats'] = 10
mbpp_datasets[0]['abbr'] = 'mbpp_pass10'
mbpp_datasets[0]['num_repeats'] = 10
mbpp_datasets[0]['type'] = MBPPDataset_V2
mbpp_datasets[0]['eval_cfg']['evaluator']['type'] = MBPPPassKEvaluator
mbpp_datasets[0]['reader_cfg']['output_column'] = 'test_column'
datasets = []
datasets += humaneval_datasets
datasets += mbpp_datasets
models = [
dict(
type=HuggingFaceCausalLM,
...,
generation_kwargs=dict(
do_sample=True,
top_p=0.95,
temperature=0.8,
),
...,
)
]
```
Since the prompts themselves are unchanged, we only need to override the corresponding fields to repeat the dataset.
You need to modify these fields:
- `num_repeats`: the number of times the dataset is repeated
- `abbr`: it is best to change the dataset abbreviation along with the number of repetitions, because the dataset size changes and a stale entry in `.cache/dataset_size.json` could otherwise cause problems.
For `mbpp`, also modify the `type`, `eval_cfg.evaluator.type`, and `reader_cfg.output_column` fields.
We also need randomness in the model responses, so `generation_kwargs` must be set.
For a complete example, refer to
[configs/eval_code_passk_repeat_dataset.py](https://github.com/open-compass/opencompass/blob/main/configs/eval_code_passk_repeat_dataset.py)

View File

@@ -1,4 +1,4 @@
# Code Evaluation Tutorial
# Multilingual Code Evaluation Tutorial
To evaluate an LLM's code capability, we need to set up an independent evaluation environment to avoid executing faulty code in the development environment, which could cause irreparable damage. The code evaluation service currently used by OpenCompass is based on the [code-evaluator](https://github.com/open-compass/code-evaluator) project, which already supports evaluating multilingual datasets such as [humaneval-x](https://huggingface.co/datasets/THUDM/humaneval-x). The following tutorials introduce how to run the code evaluation service under different requirements.

View File

@@ -16,11 +16,27 @@ from .base import BaseDataset
class HumanevalDataset(BaseDataset):
@staticmethod
def load(path):
def load(path: str, num_repeats: int = 1):
"""Load humaneval dataset for pass k mode.
Note that you can use num_repeats > 1 when your model does not support
`num_return_sequence` in generation, otherwise use the raw
humaneval dataset and set `num_return_sequence` in model config to
generate multiple responses for testing pass@k>1.
It better to change your dataset abbr correspondingly if you want to
change num_repeats>1, otherwise the number in
`.cache/dataset_size.json` might be inconsistent.
Args:
num_repeats(int): Number of repetition for this dataset to get
multiple responses in special cases.
"""
dataset = []
with open(path, 'r', encoding='utf-8') as f:
for line in f:
dataset.append(json.loads(line.strip()))
dataset.extend(
[json.loads(line.strip()) for _ in range(num_repeats)])
return Dataset.from_list(dataset)
@@ -42,14 +58,19 @@ class HumanEvaluator(BaseEvaluator):
super().__init__()
def score(self, predictions, references):
predictions = [{
'task_id': f'HumanEval/{i}',
'completion': predictions[i]
} for i in range(len(predictions))]
humaneval_preds = []
# create json file in human_eval format
for preds, refer in zip(predictions, references):
# suits for two cases:
# 1. use repeated dataset
# 2. use `num_return_sequences` to generate multiple responses
if not isinstance(preds, list):
preds = [preds]
for pred in preds:
humaneval_preds.append({'task_id': refer, 'completion': pred})
with tempfile.TemporaryDirectory() as tmp_dir:
out_dir = osp.join(tmp_dir, 'human_eval.json')
self.write_jsonl(out_dir, predictions)
self.write_jsonl(out_dir, humaneval_preds)
score = self.eval(out_dir,
self.k,
n_workers=4,

View File

@@ -1,9 +1,15 @@
import contextlib
import io
import itertools
import multiprocessing
import re
import signal
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Sequence, Union
from datasets import DatasetDict, load_dataset
import numpy as np
from datasets import DatasetDict, concatenate_datasets, load_dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
@@ -30,10 +36,89 @@ class MBPPDataset(BaseDataset):
return DatasetDict({'train': train, 'test': test})
class MBPPDataset_V2(BaseDataset):
@staticmethod
def load(path: str, num_repeats: int = 1):
"""Load mbpp dataset for pass k mode.
Note that you can use num_repeats > 1 when your model does not support
`num_return_sequence` in generation, otherwise use the raw
mbpp dataset and set `num_return_sequence` in model config to
generate multiple responses for testing pass@k>1.
It better to change your dataset abbr correspondingly if you want to
change num_repeats>1, otherwise the number in
`.cache/dataset_size.json` might be inconsistent.
Args:
num_repeats(int): Number of repetition for this dataset to get
multiple responses in special cases.
"""
def processing_test(example):
example['test_case'] = example['test_list']
example['test_list'] = '\n'.join(example['test_list'])
example['test_column'] = dict(test_list_2=example['test_list'],
task_id=example['task_id'])
return example
train = load_dataset('json', data_files=path,
split='train[:10]').map(processing_test)
test = load_dataset('json', data_files=path,
split='train[10:510]').map(processing_test)
test = concatenate_datasets([test] * num_repeats)
return DatasetDict({'train': train, 'test': test})
class TimeOutException(Exception):
pass
@contextlib.contextmanager
def swallow_io():
stream = WriteOnlyStringIO()
with contextlib.redirect_stdout(stream):
with contextlib.redirect_stderr(stream):
with redirect_stdin(stream):
yield
@contextlib.contextmanager
def time_limit(seconds: float):
def signal_handler(signum, frame):
raise TimeOutException('Time out!')
signal.setitimer(signal.ITIMER_REAL, seconds)
signal.signal(signal.SIGALRM, signal_handler)
try:
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
class WriteOnlyStringIO(io.StringIO):
"""StringIO that throws an exception when it's read from."""
def read(self, *args, **kwargs):
raise IOError
def readline(self, *args, **kwargs):
raise IOError
def readlines(self, *args, **kwargs):
raise IOError
def readable(self, *args, **kwargs):
"""Returns True if the IO object can be read."""
return False
class redirect_stdin(contextlib._RedirectStream): # type: ignore
_stream = 'stdin'
@ICL_EVALUATORS.register_module()
class MBPPEvaluator(BaseEvaluator):
@@ -48,8 +133,8 @@ class MBPPEvaluator(BaseEvaluator):
# Add exec globals to prevent the exec to raise
# unnecessary NameError for correct answer
exec_globals = {}
with self.swallow_io():
with self.time_limit(2):
with swallow_io():
with time_limit(2):
exec(programs, exec_globals)
result['pass'] += 1
except TimeOutException:
@@ -82,46 +167,6 @@ class MBPPEvaluator(BaseEvaluator):
formatted += test_case
return formatted
@contextlib.contextmanager
def swallow_io(self):
stream = self.WriteOnlyStringIO()
with contextlib.redirect_stdout(stream):
with contextlib.redirect_stderr(stream):
with self.redirect_stdin(stream):
yield
@contextlib.contextmanager
def time_limit(self, seconds: float):
def signal_handler(signum, frame):
raise TimeOutException('Time out!')
signal.setitimer(signal.ITIMER_REAL, seconds)
signal.signal(signal.SIGALRM, signal_handler)
try:
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
class WriteOnlyStringIO(io.StringIO):
"""StringIO that throws an exception when it's read from."""
def read(self, *args, **kwargs):
raise IOError
def readline(self, *args, **kwargs):
raise IOError
def readlines(self, *args, **kwargs):
raise IOError
def readable(self, *args, **kwargs):
"""Returns True if the IO object can be read."""
return False
class redirect_stdin(contextlib._RedirectStream): # type: ignore
_stream = 'stdin'
@ICL_EVALUATORS.register_module()
class MBPPEvaluator2(MBPPEvaluator):
@@ -159,3 +204,140 @@ class MBPPEvaluator2(MBPPEvaluator):
if text.startswith("'"):
text = text[1:]
return text
def execution(programs, task_id, timeout):
"""Execution function for running generation code.
Args:
programs(str): Python code to be executed.
task_id(int): Task id of the current example.
timeout(int): Time limit for execution, avoid unnecessary
blocking.
In pass@k scenario, a lot of programs should be executed.
Some internal error cannot be handled properly, such as
`RecursionError` might cause system break. It is better to
separate the execution in thread or multiprocess to better
control the process.
"""
def _execution(programs, timeout):
try:
# Add exec globals to prevent the exec to raise
# unnecessary NameError for correct answer
exec_globals = {}
with swallow_io():
with time_limit(timeout):
exec(programs, exec_globals)
key.append('pass')
except TimeOutException:
key.append('timeout')
except AssertionError:
key.append('wrong_answer')
except BaseException as e:
print(e)
key.append('failed')
manager = multiprocessing.Manager()
key = manager.list()
# `signal` cannot be used in a child thread, so we create a
# separate process inside the thread.
p = multiprocessing.Process(target=_execution,
args=(programs, timeout - 1))
p.start()
p.join(timeout=timeout)
if p.is_alive():
p.kill()
# `key` might be empty if the process was killed
return task_id, 'timeout'
return task_id, key[0]
class MBPPPassKEvaluator(MBPPEvaluator):
"""Better use for pass k evaluation.
Args:
k(Tuple[int]): Choices of Pass@k. Defaults to (1, 10, 100)
"""
def __init__(self, k=(1, 10, 100)) -> None:
if not isinstance(k, Sequence):
k = (k, )
self.k = k
@staticmethod
def estimate_pass_at_k(
num_samples: Union[int, List[int], np.ndarray],
num_correct: Union[List[int], np.ndarray],
k: int,
) -> np.ndarray:
"""Estimates pass@k of each problem and returns them in an array."""
def estimator(n: int, c: int, k: int) -> float:
"""
Calculates 1 - comb(n - c, k) / comb(n, k).
"""
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
if isinstance(num_samples, int):
num_samples_it = itertools.repeat(num_samples, len(num_correct))
else:
assert len(num_samples) == len(num_correct)
num_samples_it = iter(num_samples)
return np.array([
estimator(int(n), int(c), k)
for n, c in zip(num_samples_it, num_correct)
])
def score(self, predictions, references):
assert len(predictions) == len(references)
task_pass = defaultdict(int)
task_total = defaultdict(int)
result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0}
with ThreadPoolExecutor() as executor:
futures = []
for refer, preds in zip(references, predictions):
# suits for two cases:
# 1. use repeated dataset
# 2. use `num_return_sequences` to generate multiple responses
if not isinstance(preds, list):
preds = [preds]
test_case = refer['test_list_2']
task_id = refer['task_id']
# create empty task_pass entry in case all examples failed
if task_id not in task_pass:
task_pass[task_id] = 0
for pred in preds:
pred = self._process_answer(pred)
programs = self._process_test(test_case, pred)
future = executor.submit(execution, programs, task_id, 3)
futures.append(future)
from tqdm import tqdm
for future in tqdm(as_completed(futures), total=len(futures)):
task_id, key = future.result()
result[key] += 1
task_total[task_id] += 1
if key == 'pass':
task_pass[task_id] += 1
def get_number(tasks):
return np.array([
task[1] for task in sorted(tasks.items(), key=lambda x: x[0])
])
task_pass = get_number(task_pass)
task_total = get_number(task_total)
pass_at_k = {
f'pass@{k}':
self.estimate_pass_at_k(task_total, task_pass, k).mean() * 100
for k in self.k if (task_total >= k).all()
}
result.update(pass_at_k)
return result
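
For reference, a rough sketch (not part of the diff above) of the shapes `MBPPPassKEvaluator.score` consumes, matching the `test_column` dict built in `MBPPDataset_V2`: one reference dict per example, and either a single completion string or a list of sampled completions per example. The toy task below is hypothetical.

```python
from opencompass.datasets import MBPPPassKEvaluator

# Hypothetical example: each reference carries the raw test cases plus the
# task id; a prediction entry may be a list when multiple samples exist.
references = [
    dict(task_id=601,
         test_list_2='assert is_odd(3) is True\nassert is_odd(4) is False'),
]
predictions = [
    [
        'def is_odd(n):\n    return n % 2 == 1',   # should pass
        'def is_odd(n):\n    return n % 2 == 0',   # wrong answer
    ],
]
evaluator = MBPPPassKEvaluator(k=(1, 2))
# evaluator.score(predictions, references) would run both samples in a
# sandboxed subprocess and report the counts plus pass@1 / pass@2.
```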

View File

@@ -63,6 +63,7 @@ class HuggingFace(BaseModel):
peft_path: Optional[str] = None,
tokenizer_only: bool = False,
model_kwargs: dict = dict(device_map='auto'),
generation_kwargs: dict = dict(),
meta_template: Optional[Dict] = None,
extract_pred_after_decode: bool = False,
batch_padding: bool = False,
@@ -89,6 +90,7 @@ class HuggingFace(BaseModel):
self._load_model(path=path,
model_kwargs=model_kwargs,
peft_path=peft_path)
self.generation_kwargs = generation_kwargs
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
tokenizer_kwargs: dict):
@@ -193,13 +195,15 @@
Returns:
List[str]: A list of generated strings.
"""
generation_kwargs = kwargs.copy()
# generation_kwargs from the model config take precedence over
# the per-call kwargs
generation_kwargs.update(self.generation_kwargs)
if self.batch_padding and len(inputs) > 1:
return self._batch_generate(inputs=inputs,
max_out_len=max_out_len,
**kwargs)
**generation_kwargs)
else:
return sum((self._single_generate(
inputs=[input_], max_out_len=max_out_len, **kwargs)
inputs=[input_], max_out_len=max_out_len, **generation_kwargs)
for input_ in inputs), [])
def _batch_generate(self, inputs: List[str], max_out_len: int,

View File

@@ -10,6 +10,7 @@ from tqdm import tqdm
from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS
from opencompass.utils import batched
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
@@ -129,9 +130,14 @@ class GenInferencer(BaseInferencer):
entry, max_out_len=self.max_out_len)
generated = results
num_return_sequences = self.model.generation_kwargs.get(
'num_return_sequences', 1)
# 5-3. Save current output
for prompt, prediction, gold in zip(parsed_entries, generated,
golds):
for prompt, prediction, gold in zip(
parsed_entries, batched(generated, num_return_sequences),
golds):
if num_return_sequences == 1:
prediction = prediction[0]
output_handler.save_results(prompt,
prediction,
index,

View File

@@ -122,7 +122,7 @@ class OpenICLEvalTask(BaseTask):
preds = {k: [pred.get(k) for pred in preds] for k in preds[0]}
pred_strs = preds.pop('prediction')
pred_list_flag = isinstance(pred_strs[0], list)
if ('pred_role' in self.eval_cfg
and 'meta_template' in self.model_cfg
and not MODELS.get(self.model_cfg['type']).is_api):
@@ -131,16 +131,16 @@ class OpenICLEvalTask(BaseTask):
parser = LMTemplateParser(self.model_cfg['meta_template'])
role = parser.roles[self.eval_cfg['pred_role']]
if sc_size is not None:
assert pred_list_flag, (
'The prediction for Self-Consistency'
'must be list.')
if pred_list_flag:
for pred in pred_strs:
if not isinstance(pred, list):
raise TypeError(
'The prediction for Self-Consistency'
' must be a list.')
pred_strs.append([
self._extract_role_pred(sc_pred,
self._extract_role_pred(_pred,
role.get('begin', None),
role.get('end', None))
for sc_pred in pred
for _pred in pred
])
else:
pred_strs = [
@@ -155,7 +155,7 @@ class OpenICLEvalTask(BaseTask):
proc = kwargs.pop('type')
if isinstance(proc, str):
proc = TEXT_POSTPROCESSORS.get(proc)
if sc_size is not None:
if pred_list_flag:
pred_strs = [[proc(s, **kwargs) for s in preds]
for preds in pred_strs]
else:

View File

@@ -1,4 +1,5 @@
from .abbr import * # noqa
from .auxiliary import * # noqa
from .build import * # noqa
from .collect_env import * # noqa
from .dependency import * # noqa

View File

@@ -0,0 +1,14 @@
from itertools import islice
try:
# itertools.batched is available from Python 3.12
from itertools import batched
except ImportError:
def batched(iterable, n):
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError('n must be at least one')
it = iter(iterable)
while batch := tuple(islice(it, n)):
yield batch
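
As a quick illustration of how `GenInferencer` uses this helper: with `num_return_sequences=3`, the flat list of generations is regrouped into one tuple per prompt. The values are illustrative only.

```python
from opencompass.utils import batched

# 2 prompts x num_return_sequences=3 sampled completions each (toy values).
generated = ['a1', 'a2', 'a3', 'b1', 'b2', 'b3']
print(list(batched(generated, 3)))
# [('a1', 'a2', 'a3'), ('b1', 'b2', 'b3')]
```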