mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
comments upadate
This commit is contained in:
parent
f92947c900
commit
492bf320af
@ -13,7 +13,7 @@ multiple_reader_cfg = dict(input_columns=['language', 'prompt'], output_column='
|
||||
multiple_infer_cfg = dict(
|
||||
prompt_template=dict(type=PromptTemplate, template='Based on the provided {language} code snippet, complete the subsequent content. The initial part of the completed code must match the provided code snippet exactly:\n{prompt}'),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_out_len=2048),
|
||||
inferencer=dict(type=GenInferencer),
|
||||
)
|
||||
|
||||
multiple_eval_cfg = {
|
||||
|
@ -31,20 +31,17 @@ class MultiplEDataset(BaseDataset):
|
||||
num_repeats: int = 1,
|
||||
tag: str = 'humaneval',
|
||||
local_mode: bool = False):
|
||||
"""Load humaneval dataset for pass k mode.
|
||||
|
||||
Note that you can use num_repeats > 1 when your model does not support
|
||||
`num_return_sequence` in generation, otherwise use the raw
|
||||
humaneval dataset and set `num_return_sequence` in model config to
|
||||
generate multiple responses for testing pass@k>1.
|
||||
|
||||
It better to change your dataset abbr correspondingly if you want to
|
||||
change num_repeats>1, otherwise the number in
|
||||
`.cache/dataset_size.json` might be inconsistent.
|
||||
"""Load dataset for pass k mode.
|
||||
|
||||
Args:
|
||||
num_repeats(int): Number of repetition for this dataset to get
|
||||
multiple responses in special cases.
|
||||
path(str): The path to the dataset.
|
||||
language(str): The language of the dataset.
|
||||
num_repeats(int): Number of repetition for this dataset to get.
|
||||
tag(str): The tag of the dataset.
|
||||
local_mode(bool): Whether to load the dataset in local mode.
|
||||
|
||||
Returns:
|
||||
Dataset: A PyTorch dataset.
|
||||
"""
|
||||
path = get_data_path(path, local_mode=local_mode)
|
||||
assert tag in ['humaneval',
|
||||
@ -72,6 +69,13 @@ class MultiplEEvaluator(CodeEvaluator):
|
||||
|
||||
WARNING: the decoded_string *must not* include the prompt,
|
||||
which may have stop tokens itself.
|
||||
|
||||
Args:
|
||||
decoded_string: A string generated by the model.
|
||||
stop_tokens: A list of strings, where each string is a stop token.
|
||||
Returns:
|
||||
The decoded_string, truncated at the first occurrence of a stop
|
||||
token.
|
||||
"""
|
||||
min_stop_index = len(decoded_string)
|
||||
for stop_token in stop_tokens:
|
||||
@ -81,6 +85,14 @@ class MultiplEEvaluator(CodeEvaluator):
|
||||
return decoded_string[:min_stop_index]
|
||||
|
||||
def _process_completions(self, test_case, completions):
|
||||
"""Process completions with a test case.
|
||||
|
||||
Args:
|
||||
test_case: A test case.
|
||||
completions: A list of completions.
|
||||
Returns:
|
||||
A list of processed completions.
|
||||
"""
|
||||
processed_completions = []
|
||||
for comp in completions:
|
||||
comp = self._extract_code(comp)
|
||||
|
Loading…
Reference in New Issue
Block a user