comments update

This commit is contained in:
Dongsheng Zhu 2025-03-21 03:49:54 +00:00
parent f92947c900
commit 492bf320af
2 changed files with 25 additions and 13 deletions

View File

@ -13,7 +13,7 @@ multiple_reader_cfg = dict(input_columns=['language', 'prompt'], output_column='
multiple_infer_cfg = dict(
prompt_template=dict(type=PromptTemplate, template='Based on the provided {language} code snippet, complete the subsequent content. The initial part of the completed code must match the provided code snippet exactly:\n{prompt}'),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=2048),
inferencer=dict(type=GenInferencer),
)
multiple_eval_cfg = {

View File

@ -31,20 +31,17 @@ class MultiplEDataset(BaseDataset):
num_repeats: int = 1,
tag: str = 'humaneval',
local_mode: bool = False):
"""Load humaneval dataset for pass k mode.
Note that you can use num_repeats > 1 when your model does not support
`num_return_sequence` in generation, otherwise use the raw
humaneval dataset and set `num_return_sequence` in model config to
generate multiple responses for testing pass@k>1.
It is better to change your dataset abbr correspondingly if you want to
change num_repeats>1, otherwise the number in
`.cache/dataset_size.json` might be inconsistent.
"""Load dataset for pass k mode.
Args:
num_repeats(int): Number of repetition for this dataset to get
multiple responses in special cases.
path(str): The path to the dataset.
language(str): The language of the dataset.
num_repeats(int): Number of repetition for this dataset to get.
tag(str): The tag of the dataset.
local_mode(bool): Whether to load the dataset in local mode.
Returns:
Dataset: A PyTorch dataset.
"""
path = get_data_path(path, local_mode=local_mode)
assert tag in ['humaneval',
@ -72,6 +69,13 @@ class MultiplEEvaluator(CodeEvaluator):
WARNING: the decoded_string *must not* include the prompt,
which may have stop tokens itself.
Args:
decoded_string: A string generated by the model.
stop_tokens: A list of strings, where each string is a stop token.
Returns:
The decoded_string, truncated at the first occurrence of a stop
token.
"""
min_stop_index = len(decoded_string)
for stop_token in stop_tokens:
@ -81,6 +85,14 @@ class MultiplEEvaluator(CodeEvaluator):
return decoded_string[:min_stop_index]
def _process_completions(self, test_case, completions):
"""Process completions with a test case.
Args:
test_case: A test case.
completions: A list of completions.
Returns:
A list of processed completions.
"""
processed_completions = []
for comp in completions:
comp = self._extract_code(comp)