Compare commits

...

3 Commits

Author     SHA1         Message                                  Date
Xu Song    d755c7a834   Merge ea517e3377 into 8a5029b121         2025-03-23 14:49:50 +08:00
xusong28   ea517e3377   fix yapf                                 2024-07-12 16:07:24 +08:00
xusong28   2450a22f18   add api_prompts to GenInferencerOutput   2024-07-12 15:11:34 +08:00

4 changed files with 26 additions and 11 deletions

View File

@@ -444,6 +444,7 @@ class HuggingFacewithChatTemplate(BaseModel):
             tokenize_kwargs['add_special_tokens'] = False
             tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)

+        prompt_list = messages
         tokens = {k: v.to(self.model.device) for k, v in tokens.items()}

         if self.mode == 'mid':
@@ -486,7 +487,7 @@ class HuggingFacewithChatTemplate(BaseModel):
         for stop in stopping_criteria:
             decodeds = [t.split(stop)[0] for t in decodeds]

-        return decodeds
+        return prompt_list, decodeds

     def get_token_len(self, prompt: str) -> int:
         m = _convert_chat_messages([prompt])[0]
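
Taken together, the two hunks above change the return contract of the generation method in HuggingFacewithChatTemplate: the rendered string prompts (prompt_list) now come back alongside the decoded completions rather than the completions alone. A minimal caller-side sketch of handling both contracts; FakeModel and its return values are stand-ins for illustration and are not part of the diff:

# Sketch only: FakeModel mimics the new (prompt_list, decodeds) return value.
class FakeModel:
    def generate(self, inputs, max_out_len):
        rendered = [f'<rendered prompt for {i!r}>' for i in inputs]
        completions = ['4' for _ in inputs]
        return rendered, completions  # new contract: prompts + completions

outputs = FakeModel().generate(['What is 2 + 2?'], max_out_len=256)
if isinstance(outputs, tuple):
    prompt_list, decodeds = outputs        # updated models
else:
    prompt_list, decodeds = None, outputs  # models still returning a bare list
print(prompt_list, decodeds)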

View File

@@ -126,7 +126,7 @@ class VLLMwithChatTemplate(BaseModel):
             prompt_list.append(prompt)
             output_strs.append(generated_text)

-        return output_strs
+        return prompt_list, output_strs

     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized strings.

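VLLMwithChatTemplate receives the symmetric change: its generation method now returns the rendered prompts together with the generated strings. The pairing comes from vLLM's RequestOutput objects, which carry both the prompt and its completions. A small sketch of that collection step; the helper name and its outputs argument are illustrative, while .prompt and .outputs[0].text are vLLM's actual RequestOutput attributes:

from typing import List, Tuple

def collect_vllm_outputs(outputs) -> Tuple[List[str], List[str]]:
    """Illustrative helper: turn vLLM RequestOutput objects into the
    index-aligned (prompt_list, output_strs) pair returned above."""
    prompt_list, output_strs = [], []
    for output in outputs:
        prompt_list.append(output.prompt)           # rendered string prompt
        output_strs.append(output.outputs[0].text)  # first sampled completion
    return prompt_list, output_strs
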
View File

@@ -111,13 +111,20 @@ class GenInferencerOutputHandler:
         """Dump the result to a json file."""
         dump_results_dict(self.results_dict, Path(save_dir) / filename)

-    def save_results(self, origin_prompt, prediction, idx, gold=None):
-        self.results_dict[str(idx)] = {
-            'origin_prompt': origin_prompt,
-            'prediction': prediction,
-        }
+    def save_results(self,
+                     origin_prompt,
+                     prediction,
+                     idx,
+                     api_prompts=None,
+                     gold=None):
+        results = {}
+        if api_prompts:
+            results['api_prompts'] = api_prompts
+        results['origin_prompt'] = origin_prompt
+        results['prediction'] = prediction
         if gold:
-            self.results_dict[str(idx)]['gold'] = gold
+            results['gold'] = gold
+        self.results_dict[str(idx)] = results


 class PPLInferencerOutputHandler:
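
The reworked save_results stores api_prompts ahead of origin_prompt whenever it is provided and leaves the record shape unchanged otherwise. A hedged usage sketch, assuming the handler's constructor takes no arguments; the import is omitted, and the prompt values and the 'HUMAN' role are illustrative:

# GenInferencerOutputHandler is the class from the diff above (import omitted).
handler = GenInferencerOutputHandler()
handler.save_results(
    origin_prompt='<rendered string prompt>',
    prediction='4',
    idx=0,
    api_prompts=[{'role': 'HUMAN', 'prompt': 'What is 2 + 2?'}],
    gold='4')

assert handler.results_dict['0'] == {
    'api_prompts': [{'role': 'HUMAN', 'prompt': 'What is 2 + 2?'}],
    'origin_prompt': '<rendered string prompt>',
    'prediction': '4',
    'gold': '4',
}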

View File

@@ -153,18 +153,25 @@ class GenInferencer(BaseInferencer):
                 results = self.model.generate_from_template(
                     entry, max_out_len=self.max_out_len, **extra_gen_kwargs)
                 generated = results
+                if isinstance(generated, tuple):
+                    api_prompts_list = parsed_entries
+                    prompts, generated = generated
+                else:
+                    api_prompts_list = [None] * len(generated)
+                    prompts = parsed_entries

             num_return_sequences = getattr(self.model, 'generation_kwargs',
                                            {}).get('num_return_sequences', 1)
             # 5-3. Save current output
-            for prompt, prediction, gold in zip(
-                    parsed_entries, batched(generated, num_return_sequences),
-                    golds):
+            for api_prompts, prompt, prediction, gold in zip(
+                    api_prompts_list, prompts,
+                    batched(generated, num_return_sequences), golds):
                 if num_return_sequences == 1:
                     prediction = prediction[0]
                 output_handler.save_results(prompt,
                                             prediction,
                                             index,
+                                            api_prompts=api_prompts,
                                             gold=gold)
                 index = index + 1
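
Because api_prompts_list is padded with None when the model does not return prompts, every tuple produced by the zip stays aligned and save_results simply omits the field (the if api_prompts check in the handler above). A self-contained sketch of that alignment together with the num_return_sequences grouping; batched here is a local stand-in for the inferencer's helper and the data is illustrative:

from itertools import islice

def batched(iterable, n):
    # Stand-in for the inferencer's helper: yield successive n-sized tuples.
    it = iter(iterable)
    while chunk := tuple(islice(it, n)):
        yield chunk

num_return_sequences = 2
generated = ['a1', 'a2', 'b1', 'b2']        # two sampled completions per prompt
prompts = ['prompt A', 'prompt B']          # rendered string prompts
golds = ['a', 'b']
api_prompts_list = [None] * len(generated)  # else-branch above; zip() stops at
                                            # the shortest input, so extra Nones
                                            # are harmless

for api_prompts, prompt, prediction, gold in zip(
        api_prompts_list, prompts,
        batched(generated, num_return_sequences), golds):
    print(api_prompts, prompt, prediction, gold)
# None prompt A ('a1', 'a2') a
# None prompt B ('b1', 'b2') b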