add api_prompts to GenInferencerOutput

xusong28 2024-07-12 15:11:34 +08:00
parent 3aeabbc427
commit 2450a22f18
4 changed files with 20 additions and 10 deletions

View File

@@ -270,6 +270,7 @@ class HuggingFacewithChatTemplate(BaseModel):
         tokenize_kwargs['add_special_tokens'] = False
         tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)
+        prompt_list = messages
         tokens = {k: v.to(self.model.device) for k, v in tokens.items()}

         generation_kwargs = self.generation_kwargs.copy()
@@ -292,7 +293,7 @@ class HuggingFacewithChatTemplate(BaseModel):
         for stop in stopping_criteria:
             decodeds = [t.split(stop)[0] for t in decodeds]

-        return decodeds
+        return prompt_list, decodeds

     def get_token_len(self, prompt: str) -> int:
         m = _convert_chat_messages([prompt])[0]

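With `prompt_list = messages` captured before tokenization, the chat-template model now hands back the templated prompts it actually fed to the tokenizer alongside the decoded completions, i.e. a tuple instead of a bare list. A minimal sketch of the new contract as a caller would see it (the call below is illustrative, not taken from the diff):

    # Hedged sketch: `model` is a chat-template backend touched by this commit.
    out = model.generate_from_template(entries, max_out_len=256)
    if isinstance(out, tuple):
        prompts, completions = out        # templated string prompts + decoded outputs
    else:
        prompts, completions = None, out  # other models still return a plain list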
View File

@@ -118,7 +118,7 @@ class VLLMwithChatTemplate(BaseModel):
             prompt_list.append(prompt)
             output_strs.append(generated_text)

-        return output_strs
+        return prompt_list, output_strs

     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized strings.

View File

@@ -111,13 +111,15 @@ class GenInferencerOutputHandler:
         """Dump the result to a json file."""
         dump_results_dict(self.results_dict, Path(save_dir) / filename)

-    def save_results(self, origin_prompt, prediction, idx, gold=None):
-        self.results_dict[str(idx)] = {
-            'origin_prompt': origin_prompt,
-            'prediction': prediction,
-        }
+    def save_results(self, origin_prompt, prediction, idx, api_prompts=None, gold=None):
+        results = {}
+        if api_prompts:
+            results['api_prompts'] = api_prompts
+        results['origin_prompt'] = origin_prompt
+        results['prediction'] = prediction
         if gold:
-            self.results_dict[str(idx)]['gold'] = gold
+            results['gold'] = gold
+        self.results_dict[str(idx)] = results


 class PPLInferencerOutputHandler:

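For reference, an illustrative call to the reworked handler, assuming its default no-argument constructor; the field values below are made up:

    handler = GenInferencerOutputHandler()
    handler.save_results(
        origin_prompt='<|user|>\nWhat is 2+2?<|assistant|>\n',             # templated prompt string
        prediction='4',
        idx=0,
        api_prompts=[{'role': 'user', 'content': 'What is 2+2?'}],         # made-up API-style messages
        gold='4')
    # handler.results_dict['0'] now holds, in this order:
    # {'api_prompts': [...], 'origin_prompt': '...', 'prediction': '4', 'gold': '4'}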
View File

@@ -152,18 +152,25 @@ class GenInferencer(BaseInferencer):
                 results = self.model.generate_from_template(
                     entry, max_out_len=self.max_out_len, **extra_gen_kwargs)
                 generated = results
+                if isinstance(generated, tuple):
+                    api_prompts_list = parsed_entries
+                    prompts, generated = generated
+                else:
+                    api_prompts_list = [None]*len(generated)
+                    prompts = parsed_entries

             num_return_sequences = getattr(self.model, 'generation_kwargs',
                                            {}).get('num_return_sequences', 1)

             # 5-3. Save current output
-            for prompt, prediction, gold in zip(
-                    parsed_entries, batched(generated, num_return_sequences),
+            for api_prompts, prompt, prediction, gold in zip(
+                    api_prompts_list, prompts, batched(generated, num_return_sequences),
                     golds):
                 if num_return_sequences == 1:
                     prediction = prediction[0]
                 output_handler.save_results(prompt,
                                             prediction,
                                             index,
+                                            api_prompts=api_prompts,
                                             gold=gold)
                 index = index + 1
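Since `save_results` writes `api_prompts` only when it is truthy, the field is optional in the dumped JSON and consumers should read it defensively. A hedged sketch of post-processing a dumped prediction file (the path is hypothetical):

    import json

    with open('outputs/predictions/example.json') as f:   # hypothetical path
        records = json.load(f)                            # {str(idx): record, ...}

    for idx, rec in records.items():
        api_prompts = rec.get('api_prompts')              # None when the model did not report them
        print(idx, rec['prediction'], api_prompts)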