add api_prompts to GenInferencerOutput

xusong28 2024-07-12 15:11:34 +08:00
parent 3aeabbc427
commit 2450a22f18
4 changed files with 20 additions and 10 deletions


@@ -270,6 +270,7 @@ class HuggingFacewithChatTemplate(BaseModel):
         tokenize_kwargs['add_special_tokens'] = False
         tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)
+        prompt_list = messages
         tokens = {k: v.to(self.model.device) for k, v in tokens.items()}

         generation_kwargs = self.generation_kwargs.copy()
@@ -292,7 +293,7 @@ class HuggingFacewithChatTemplate(BaseModel):
         for stop in stopping_criteria:
             decodeds = [t.split(stop)[0] for t in decodeds]

-        return decodeds
+        return prompt_list, decodeds

     def get_token_len(self, prompt: str) -> int:
         m = _convert_chat_messages([prompt])[0]
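Here prompt_list holds the strings that are handed to tokenizer.batch_encode_plus, i.e. the prompts after the tokenizer's chat template has been applied, and the method now returns them together with the decoded outputs so the caller can record exactly what was sent to the model. A minimal sketch of how such a rendered prompt is produced; the model name and message are illustrative and not part of this commit:

from transformers import AutoTokenizer

# Illustrative chat model; any tokenizer that ships a chat template behaves the same way.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")

messages = [{"role": "user", "content": "What is 2 + 2?"}]

# The rendered string below is what prompt_list carries back to the caller,
# alongside the decoded generations.
rendered = tok.apply_chat_template(messages,
                                   add_generation_prompt=True,
                                   tokenize=False)
print(rendered)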


@@ -118,7 +118,7 @@ class VLLMwithChatTemplate(BaseModel):
             prompt_list.append(prompt)
             output_strs.append(generated_text)

-        return output_strs
+        return prompt_list, output_strs

     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized strings.


@@ -111,13 +111,15 @@ class GenInferencerOutputHandler:
         """Dump the result to a json file."""
         dump_results_dict(self.results_dict, Path(save_dir) / filename)

-    def save_results(self, origin_prompt, prediction, idx, gold=None):
-        self.results_dict[str(idx)] = {
-            'origin_prompt': origin_prompt,
-            'prediction': prediction,
-        }
+    def save_results(self, origin_prompt, prediction, idx, api_prompts=None, gold=None):
+        results = {}
+        if api_prompts:
+            results['api_prompts'] = api_prompts
+        results['origin_prompt'] = origin_prompt
+        results['prediction'] = prediction
         if gold:
-            self.results_dict[str(idx)]['gold'] = gold
+            results['gold'] = gold
+        self.results_dict[str(idx)] = results


 class PPLInferencerOutputHandler:
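With the updated save_results, a record can carry two prompt representations side by side: whatever is passed as origin_prompt plus an optional api_prompts field, followed by the prediction and the optional gold answer. An illustrative entry of results_dict after this change; all values are made up:

results_dict = {
    "0": {
        # Written first, and only when api_prompts was passed in.
        "api_prompts": [{"role": "user", "content": "What is 2 + 2?"}],
        "origin_prompt": "<|im_start|>user\nWhat is 2 + 2?<|im_end|>\n<|im_start|>assistant\n",
        "prediction": "4",
        # Present only when a reference answer was provided.
        "gold": "4",
    }
}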


@@ -152,18 +152,25 @@ class GenInferencer(BaseInferencer):
                 results = self.model.generate_from_template(
                     entry, max_out_len=self.max_out_len, **extra_gen_kwargs)
             generated = results
+            if isinstance(generated, tuple):
+                api_prompts_list = parsed_entries
+                prompts, generated = generated
+            else:
+                api_prompts_list = [None]*len(generated)
+                prompts = parsed_entries

             num_return_sequences = getattr(self.model, 'generation_kwargs',
                                            {}).get('num_return_sequences', 1)

             # 5-3. Save current output
-            for prompt, prediction, gold in zip(
-                    parsed_entries, batched(generated, num_return_sequences),
+            for api_prompts, prompt, prediction, gold in zip(
+                    api_prompts_list, prompts, batched(generated, num_return_sequences),
                     golds):
                 if num_return_sequences == 1:
                     prediction = prediction[0]
                 output_handler.save_results(prompt,
                                             prediction,
                                             index,
+                                            api_prompts=api_prompts,
                                             gold=gold)
                 index = index + 1
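The zip above relies on batched() to regroup the flat list of generations into one chunk per input prompt when num_return_sequences is greater than one; with a single sequence the chunk is unwrapped back to a plain string before saving. A minimal sketch of that grouping, where batched is a stand-in for the helper the inferencer imports and all values are illustrative:

from itertools import islice

def batched(iterable, n):
    # Stand-in for the utility used by GenInferencer: yields chunks of size n.
    it = iter(iterable)
    while chunk := list(islice(it, n)):
        yield chunk

prompts = ["prompt-a", "prompt-b"]
api_prompts_list = [None, None]           # no extra prompt representation here
generated = ["a-1", "a-2", "b-1", "b-2"]  # two sampled answers per prompt
num_return_sequences = 2

for api_prompts, prompt, prediction in zip(
        api_prompts_list, prompts, batched(generated, num_return_sequences)):
    if num_return_sequences == 1:
        prediction = prediction[0]
    print(prompt, prediction)   # here prediction stays a list of two samples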