This commit is contained in:
xusong28 2024-07-12 16:07:24 +08:00
parent 2450a22f18
commit ea517e3377
2 changed files with 9 additions and 4 deletions

View File

@@ -111,7 +111,12 @@ class GenInferencerOutputHandler:
"""Dump the result to a json file."""
dump_results_dict(self.results_dict, Path(save_dir) / filename)
def save_results(self, origin_prompt, prediction, idx, api_prompts=None, gold=None):
def save_results(self,
origin_prompt,
prediction,
idx,
api_prompts=None,
gold=None):
results = {}
if api_prompts:
results['api_prompts'] = api_prompts

View File

@@ -156,15 +156,15 @@ class GenInferencer(BaseInferencer):
api_prompts_list = parsed_entries
prompts, generated = generated
else:
api_prompts_list = [None]*len(generated)
api_prompts_list = [None] * len(generated)
prompts = parsed_entries
num_return_sequences = getattr(self.model, 'generation_kwargs',
{}).get('num_return_sequences', 1)
# 5-3. Save current output
for api_prompts, prompt, prediction, gold in zip(
api_prompts_list, prompts, batched(generated, num_return_sequences),
golds):
api_prompts_list, prompts,
batched(generated, num_return_sequences), golds):
if num_return_sequences == 1:
prediction = prediction[0]
output_handler.save_results(prompt,