Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
Merge ea517e3377 into 8a5029b121
Commit: d755c7a834
@@ -444,6 +444,7 @@ class HuggingFacewithChatTemplate(BaseModel):
         tokenize_kwargs['add_special_tokens'] = False
         tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)

+        prompt_list = messages
         tokens = {k: v.to(self.model.device) for k, v in tokens.items()}

         if self.mode == 'mid':
@@ -486,7 +487,7 @@ class HuggingFacewithChatTemplate(BaseModel):
         for stop in stopping_criteria:
             decodeds = [t.split(stop)[0] for t in decodeds]

-        return decodeds
+        return prompt_list, decodeds

     def get_token_len(self, prompt: str) -> int:
         m = _convert_chat_messages([prompt])[0]
@@ -126,7 +126,7 @@ class VLLMwithChatTemplate(BaseModel):
             prompt_list.append(prompt)
             output_strs.append(generated_text)

-        return output_strs
+        return prompt_list, output_strs

     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized strings.
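With this change both the HuggingFace and vLLM chat-template wrappers return two lists instead of one: the chat-templated prompts that were actually sent to the backend, plus the decoded outputs. A minimal, self-contained sketch of the new return shape; the stub below only stands in for a configured wrapper and is not the repository's code:

from typing import List, Tuple

def generate_stub(inputs: List[str], max_out_len: int) -> Tuple[List[str], List[str]]:
    """Stand-in for the updated generate(): returns (templated prompts, outputs)."""
    prompt_list = [f'<|user|>{x}<|assistant|>' for x in inputs]  # fake chat template
    output_strs = [f'echo: {x}'[:max_out_len] for x in inputs]   # fake generations
    return prompt_list, output_strs

prompt_list, decodeds = generate_stub(['Hello?'], max_out_len=32)
print(prompt_list)  # the prompts that would have been fed to the model
print(decodeds)     # the decoded generations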
@@ -111,13 +111,20 @@ class GenInferencerOutputHandler:
         """Dump the result to a json file."""
         dump_results_dict(self.results_dict, Path(save_dir) / filename)

-    def save_results(self, origin_prompt, prediction, idx, gold=None):
-        self.results_dict[str(idx)] = {
-            'origin_prompt': origin_prompt,
-            'prediction': prediction,
-        }
+    def save_results(self,
+                     origin_prompt,
+                     prediction,
+                     idx,
+                     api_prompts=None,
+                     gold=None):
+        results = {}
+        if api_prompts:
+            results['api_prompts'] = api_prompts
+        results['origin_prompt'] = origin_prompt
+        results['prediction'] = prediction
         if gold:
-            self.results_dict[str(idx)]['gold'] = gold
+            results['gold'] = gold
+        self.results_dict[str(idx)] = results


 class PPLInferencerOutputHandler:
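For reference, a sketch of the entry the revised save_results writes into results_dict when both api_prompts and gold are supplied; all values below are made up, and the 'api_prompts' key is simply omitted when api_prompts is None or empty:

# Illustrative only: mirrors the dict built by the new save_results body.
idx = 0
api_prompts = [{'role': 'HUMAN', 'prompt': 'What is 2 + 2?'}]  # example raw messages
origin_prompt = '<prompt after the chat template is applied>'
prediction = '4'
gold = '4'

results = {}
if api_prompts:
    results['api_prompts'] = api_prompts
results['origin_prompt'] = origin_prompt
results['prediction'] = prediction
if gold:
    results['gold'] = gold

results_dict = {str(idx): results}
# results_dict['0'] now holds api_prompts, origin_prompt, prediction and gold.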
@@ -153,18 +153,25 @@ class GenInferencer(BaseInferencer):
                 results = self.model.generate_from_template(
                     entry, max_out_len=self.max_out_len, **extra_gen_kwargs)
                 generated = results
+                if isinstance(generated, tuple):
+                    api_prompts_list = parsed_entries
+                    prompts, generated = generated
+                else:
+                    api_prompts_list = [None] * len(generated)
+                    prompts = parsed_entries

             num_return_sequences = getattr(self.model, 'generation_kwargs',
                                            {}).get('num_return_sequences', 1)
             # 5-3. Save current output
-            for prompt, prediction, gold in zip(
-                    parsed_entries, batched(generated, num_return_sequences),
-                    golds):
+            for api_prompts, prompt, prediction, gold in zip(
+                    api_prompts_list, prompts,
+                    batched(generated, num_return_sequences), golds):
                 if num_return_sequences == 1:
                     prediction = prediction[0]
                 output_handler.save_results(prompt,
                                             prediction,
                                             index,
+                                            api_prompts=api_prompts,
                                             gold=gold)
                 index = index + 1
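The isinstance check above keeps older model wrappers working: a wrapper that still returns a bare list of strings falls into the else branch, and api_prompts stays None for every record. A standalone sketch of that dispatch, with invented return values standing in for generate_from_template output:

# Standalone sketch of the backward-compatible unpacking; data is invented.
parsed_entries = ['Q1?', 'Q2?']

def old_style():
    return ['A1', 'A2']                             # outputs only

def new_style():
    return (['<templated Q1>', '<templated Q2>'],   # prompts actually sent
            ['A1', 'A2'])                           # outputs

for fn in (old_style, new_style):
    generated = fn()
    if isinstance(generated, tuple):
        api_prompts_list = parsed_entries
        prompts, generated = generated
    else:
        api_prompts_list = [None] * len(generated)
        prompts = parsed_entries
    for api_prompts, prompt, prediction in zip(api_prompts_list, prompts, generated):
        print(api_prompts, prompt, prediction)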