mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

Commit d755c7a834: Merge ea517e3377 into 8a5029b121
@@ -444,6 +444,7 @@ class HuggingFacewithChatTemplate(BaseModel):
         tokenize_kwargs['add_special_tokens'] = False
         tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)

+        prompt_list = messages
         tokens = {k: v.to(self.model.device) for k, v in tokens.items()}

         if self.mode == 'mid':
@@ -486,7 +487,7 @@ class HuggingFacewithChatTemplate(BaseModel):
         for stop in stopping_criteria:
             decodeds = [t.split(stop)[0] for t in decodeds]

-        return decodeds
+        return prompt_list, decodeds

     def get_token_len(self, prompt: str) -> int:
         m = _convert_chat_messages([prompt])[0]
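Taken together with the `prompt_list = messages` line added above, `HuggingFacewithChatTemplate` now hands back the chat-formatted prompts alongside the decoded generations, i.e. a `(prompt_list, decodeds)` tuple instead of a bare list. A minimal, self-contained sketch of that return contract from a caller's point of view; `unpack_generation` and `GenerateReturn` are illustrative names, not part of the library:

from typing import List, Optional, Tuple, Union

GenerateReturn = Union[List[str], Tuple[List[str], List[str]]]

def unpack_generation(result: GenerateReturn) -> Tuple[List[Optional[str]], List[str]]:
    # Normalise either return shape to (prompts, outputs).
    if isinstance(result, tuple):
        # New behaviour in this diff: (prompt_list, decodeds).
        prompts, outputs = result
    else:
        # Legacy behaviour: a bare list of decoded strings.
        prompts, outputs = [None] * len(result), result
    return list(prompts), list(outputs)

# A wrapper that already follows the new contract:
prompts, outputs = unpack_generation((['Hi there'], ['Hello!']))
assert prompts == ['Hi there'] and outputs == ['Hello!']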
@@ -126,7 +126,7 @@ class VLLMwithChatTemplate(BaseModel):
             prompt_list.append(prompt)
             output_strs.append(generated_text)

-        return output_strs
+        return prompt_list, output_strs

     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized strings.
@@ -111,13 +111,20 @@ class GenInferencerOutputHandler:
         """Dump the result to a json file."""
         dump_results_dict(self.results_dict, Path(save_dir) / filename)

-    def save_results(self, origin_prompt, prediction, idx, gold=None):
-        self.results_dict[str(idx)] = {
-            'origin_prompt': origin_prompt,
-            'prediction': prediction,
-        }
+    def save_results(self,
+                     origin_prompt,
+                     prediction,
+                     idx,
+                     api_prompts=None,
+                     gold=None):
+        results = {}
+        if api_prompts:
+            results['api_prompts'] = api_prompts
+        results['origin_prompt'] = origin_prompt
+        results['prediction'] = prediction
         if gold:
-            self.results_dict[str(idx)]['gold'] = gold
+            results['gold'] = gold
+        self.results_dict[str(idx)] = results


 class PPLInferencerOutputHandler:
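`save_results` now assembles the per-sample record incrementally and only writes `api_prompts` and `gold` when they are actually supplied, so prediction files produced without API prompts keep their old shape. A hedged usage sketch; the import path and the no-argument constructor are assumptions about the surrounding module, not shown in this diff:

# Hedged usage sketch: the import path and the no-argument constructor are
# assumptions; only the save_results signature mirrors the diff above.
from opencompass.openicl.icl_inferencer.icl_gen_inferencer import \
    GenInferencerOutputHandler

handler = GenInferencerOutputHandler()
handler.save_results(
    origin_prompt='HUMAN: 1+1=?',                        # rendered prompt string
    prediction='2',
    idx=0,
    api_prompts=[{'role': 'HUMAN', 'prompt': '1+1=?'}],  # new optional field
    gold='2',
)
# results_dict['0'] now holds api_prompts/origin_prompt/prediction/gold;
# omitting api_prompts keeps the record identical to the old format.
print(handler.results_dict['0'])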
@@ -153,18 +153,25 @@ class GenInferencer(BaseInferencer):
                 results = self.model.generate_from_template(
                     entry, max_out_len=self.max_out_len, **extra_gen_kwargs)
             generated = results
+            if isinstance(generated, tuple):
+                api_prompts_list = parsed_entries
+                prompts, generated = generated
+            else:
+                api_prompts_list = [None] * len(generated)
+                prompts = parsed_entries

             num_return_sequences = getattr(self.model, 'generation_kwargs',
                                            {}).get('num_return_sequences', 1)
             # 5-3. Save current output
-            for prompt, prediction, gold in zip(
-                    parsed_entries, batched(generated, num_return_sequences),
-                    golds):
+            for api_prompts, prompt, prediction, gold in zip(
+                    api_prompts_list, prompts,
+                    batched(generated, num_return_sequences), golds):
                 if num_return_sequences == 1:
                     prediction = prediction[0]
                 output_handler.save_results(prompt,
                                             prediction,
                                             index,
+                                            api_prompts=api_prompts,
                                             gold=gold)
                 index = index + 1
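The loop stays backward compatible by checking whether `generate_from_template` returned a tuple; models that still return a plain list fall back to `parsed_entries` and a `None` placeholder for `api_prompts`. The grouping itself relies on the `batched` helper, sketched below with a local stand-in (assumed equivalent to the itertools-style recipe in `icl_gen_inferencer.py`, which this diff does not show):

from itertools import islice

def batched(iterable, n):
    # Yield successive chunks of size n, matching how the save loop groups
    # num_return_sequences generations per prompt.
    it = iter(iterable)
    while chunk := tuple(islice(it, n)):
        yield chunk

generated = ['a1', 'a2', 'b1', 'b2']   # two prompts, two samples each
prompts = ['prompt-a', 'prompt-b']
for prompt, prediction in zip(prompts, batched(generated, 2)):
    # With num_return_sequences == 1 the loop would unwrap prediction[0].
    print(prompt, list(prediction))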