Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

[Feature] Log gold answer in prediction output (#419)

* [Feature] Log gold answer in prediction output
* support clp golden ans
* minor fix

Co-authored-by: yingfhu <yingfhu@gmail.com>

commit 681d3013de (parent 97fdc51102)

@@ -1,4 +1,4 @@
-from abc import abstractclassmethod
+from abc import abstractmethod
 from copy import deepcopy
 from typing import Dict, List, Optional, Tuple, Union
 
@@ -37,7 +37,7 @@ class BaseModel:
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']
 
-    @abstractclassmethod
+    @abstractmethod
     def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
         """Generate results given a list of inputs.
 
@@ -48,8 +48,11 @@ class BaseModel:
         Returns:
             List[str]: A list of generated strings.
         """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' gen-based evaluation yet, try ppl-based '
+                                  'instead.')
 
-    @abstractclassmethod
+    @abstractmethod
     def get_ppl(self,
                 inputs: List[str],
                 mask_length: Optional[List[int]] = None) -> List[float]:
@@ -66,8 +69,11 @@ class BaseModel:
         Returns:
             List[float]: A list of perplexity scores.
         """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' ppl-based evaluation yet, try gen-based '
+                                  'instead.')
 
-    @abstractclassmethod
+    @abstractmethod
     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized strings.
 
@@ -192,7 +198,7 @@ class LMTemplateParser:
         Returns:
             str: The final string.
         """
-        assert isinstance(prompt_template, (str, list, PromptList))
+        assert isinstance(prompt_template, (str, list, PromptList, tuple))
         if not isinstance(prompt_template, (str, PromptList)):
             return [self.parse_template(p, mode=mode) for p in prompt_template]
 
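
Two things change in these base classes: the `abstractclassmethod` decorator (deprecated since Python 3.3) is replaced by `@abstractmethod`, and the abstract methods gain concrete bodies that raise `NotImplementedError`. Since `BaseModel` is a plain class with no `ABCMeta`, the decorator alone does not block instantiation, so the explicit raise is what produces a readable error when a subclass implements only one evaluation mode. A minimal, standalone sketch of that behaviour; `ToyBaseModel` and `PPLOnlyModel` are illustrative names, not OpenCompass classes:

from abc import abstractmethod
from typing import List


class ToyBaseModel:  # plain class, no ABCMeta, mirroring the base class above
    @abstractmethod
    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        raise NotImplementedError(f'{self.__class__.__name__} does not support'
                                  ' gen-based evaluation yet, try ppl-based '
                                  'instead.')

    @abstractmethod
    def get_ppl(self, inputs: List[str]) -> List[float]:
        raise NotImplementedError(f'{self.__class__.__name__} does not support'
                                  ' ppl-based evaluation yet, try gen-based '
                                  'instead.')


class PPLOnlyModel(ToyBaseModel):
    """Implements only get_ppl; generate() keeps the default raising body."""

    def get_ppl(self, inputs: List[str]) -> List[float]:
        return [1.0 for _ in inputs]  # dummy perplexities


model = PPLOnlyModel()           # instantiation succeeds without ABCMeta
print(model.get_ppl(['hello']))  # [1.0]
try:
    model.generate(['hello'], max_out_len=16)
except NotImplementedError as err:
    print(err)  # PPLOnlyModel does not support gen-based evaluation yet, ...
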
@@ -1,7 +1,7 @@
 import re
 import threading
 import warnings
-from abc import abstractclassmethod
+from abc import abstractmethod
 from copy import deepcopy
 from time import sleep
 from typing import Dict, List, Optional, Tuple, Union
@@ -46,7 +46,7 @@ class BaseAPIModel(BaseModel):
         self.template_parser = APITemplateParser(meta_template)
         self.logger = get_logger()
 
-    @abstractclassmethod
+    @abstractmethod
     def generate(self, inputs: List[PromptType],
                  max_out_len: int) -> List[str]:
         """Generate results given a list of inputs.
@@ -60,8 +60,11 @@ class BaseAPIModel(BaseModel):
         Returns:
             List[str]: A list of generated strings.
         """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' gen-based evaluation yet, try ppl-based '
+                                  'instead.')
 
-    @abstractclassmethod
+    @abstractmethod
     def get_ppl(self,
                 inputs: List[PromptType],
                 mask_length: Optional[List[int]] = None) -> List[float]:
@@ -78,6 +81,9 @@ class BaseAPIModel(BaseModel):
         Returns:
             List[float]: A list of perplexity scores.
         """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' ppl-based evaluation yet, try gen-based '
+                                  'instead.')
 
     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized string. Only English and Chinese
@@ -161,7 +167,7 @@ class APITemplateParser:
         Returns:
            List[str or PromptList]: The finalized prompt or a conversation.
         """
-        assert isinstance(prompt_template, (str, list, PromptList))
+        assert isinstance(prompt_template, (str, list, PromptList, tuple))
 
         if not isinstance(prompt_template, (str, PromptList)):
             return [self.parse_template(p, mode=mode) for p in prompt_template]
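
The only substantive change in the two template parsers is the widened isinstance assertion, which now also accepts `tuple`. The inferencer hunks below show why: once prompts are zipped with gold answers, each batch is split with `list(zip(*datum))`, so `parse_template` receives the prompts as a tuple rather than a list. A small standalone illustration with toy values (plain Python, not the real parser):

# One batch after prompts were zipped with gold answers (toy values).
datum = [('Q: 1+1?', '2'), ('Q: 2+2?', '4')]

entry, golds = list(zip(*datum))
print(entry)   # ('Q: 1+1?', 'Q: 2+2?')  <- a tuple, hence the widened assert
print(golds)   # ('2', '4')

# The parser's recursive branch then iterates over that tuple element by element.
assert isinstance(entry, (str, list, tuple))
if not isinstance(entry, str):
    parsed = [p.upper() for p in entry]  # stand-in for self.parse_template(p, ...)
print(parsed)
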
@@ -108,6 +108,12 @@ class AttackInferencer(BaseInferencer):
             ice_template=self.ice_template,
             prompt_template=self.prompt_template)
 
+        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = self.retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            prompt_list = list(zip(prompt_list, gold_ans))
+
         # Create tmp json file for saving intermediate results and future
         # resuming
         index = 0
@@ -124,7 +130,12 @@ class AttackInferencer(BaseInferencer):
 
         # 5. Inference for prompts in each batch
         logger.info('Starting inference process...')
-        for entry in tqdm(dataloader, disable=not self.is_main_process):
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            if ds_reader.output_column:
+                entry, golds = list(zip(*datum))
+            else:
+                entry = datum
+                golds = [None for _ in range(len(entry))]
             # 5-1. Inference with local model
             with torch.no_grad():
                 parsed_entries = self.model.parse_template(entry, mode='gen')
@@ -133,8 +144,12 @@ class AttackInferencer(BaseInferencer):
                 generated = results
 
             # 5-3. Save current output
-            for prompt, prediction in zip(parsed_entries, generated):
-                output_handler.save_results(prompt, prediction, index)
+            for prompt, prediction, gold in zip(parsed_entries, generated,
+                                                golds):
+                output_handler.save_results(prompt,
+                                            prediction,
+                                            index,
+                                            gold=gold)
                 index = index + 1
 
             # 5-4. Save intermediate results
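
The same fetch/zip/unzip pattern recurs in the Gen, SC and ToT inferencers further down, so it is sketched once here. The key detail is the `else` branch: when the dataset has no output column, `golds` is padded with `None` so the three-way zip over prompts, predictions and golds stays aligned. A self-contained sketch of one batch iteration, assuming the dataloader yields each batch as a plain list of items (identity collate); `fake_model_generate` and the `save_results` stub are stand-ins, not OpenCompass APIs:

def fake_model_generate(prompts, max_out_len=32):
    # Stand-in for the real model call.
    return [f'prediction for <{p}>' for p in prompts]


def save_results(results, prompt, prediction, index, gold=None):
    # Mimics the handler change shown in the next hunk group.
    results[str(index)] = {'origin_prompt': prompt, 'prediction': prediction}
    if gold:
        results[str(index)]['gold'] = gold


results = {}
index = 0
has_gold = False                              # dataset without an output column
datum = ['Q: capital of France?', 'Q: 2+2?']  # one batch from the dataloader

if has_gold:
    entry, golds = list(zip(*datum))
else:
    entry = datum
    golds = [None for _ in range(len(entry))]

generated = fake_model_generate(entry)
for prompt, prediction, gold in zip(entry, generated, golds):
    save_results(results, prompt, prediction, index, gold=gold)
    index += 1

print(results)  # indices, prompts and predictions line up; no 'gold' keys written
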
@@ -108,11 +108,13 @@ class GenInferencerOutputHandler:
         """Dump the result to a json file."""
         dump_results_dict(self.results_dict, Path(save_dir) / filename)
 
-    def save_results(self, origin_prompt, prediction, idx):
+    def save_results(self, origin_prompt, prediction, idx, gold=None):
         self.results_dict[str(idx)] = {
             'origin_prompt': origin_prompt,
             'prediction': prediction,
         }
+        if gold:
+            self.results_dict[str(idx)]['gold'] = gold
 
 
 class PPLInferencerOutputHandler:
@@ -147,6 +149,12 @@ class PPLInferencerOutputHandler:
         self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
         self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl
 
+    def save_golds(self, golds):
+        for idx, gold in enumerate(golds):
+            if str(idx) not in self.results_dict.keys():
+                self.results_dict[str(idx)] = {}
+            self.results_dict[str(idx)]['gold'] = gold
+
 
 class CLPInferencerOutputHandler:
     results_dict = {}
@@ -164,7 +172,13 @@ class CLPInferencerOutputHandler:
             self.results_dict[str(idx)] = {}
         self.results_dict[str(idx)]['in-context examples'] = example
 
-    def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
+    def save_prompt_and_condprob(self,
+                                 input,
+                                 prompt,
+                                 cond_prob,
+                                 idx,
+                                 choices,
+                                 gold=None):
         if str(idx) not in self.results_dict.keys():
             self.results_dict[str(idx)] = {}
         # TODO:
@@ -177,3 +191,4 @@ class CLPInferencerOutputHandler:
         self.results_dict[str(idx)]['prediction'] = cond_prob
         # set pred label in case needed
         self.results_dict[str(idx)]['pred_label'] = int(np.argmax(cond_prob))
+        self.results_dict[str(idx)]['gold'] = gold
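
With these handler changes, each record in the dumped prediction JSON can carry a `gold` field next to the prediction. Roughly, with illustrative values (the real records also include prompts and in-context examples):

# GenInferencerOutputHandler.save_results(prompt, prediction, idx, gold='Paris'):
gen_record = {
    'origin_prompt': 'Q: What is the capital of France?\nA:',
    'prediction': ' Paris',
    'gold': 'Paris',
}

# CLPInferencerOutputHandler.save_prompt_and_condprob(..., gold='B'):
clp_record = {
    'prediction': [0.12, 0.85, 0.03],  # one conditional prob per choice
    'pred_label': 1,                   # int(np.argmax(cond_prob))
    'gold': 'B',
}

Note that `save_results` writes the key only when `if gold:` is truthy, so a legitimately falsy gold such as `0` or an empty string is silently dropped, whereas the CLP handler assigns `gold` unconditionally, including `None`.
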
@@ -175,22 +175,35 @@ class CLPInferencer(BaseInferencer):
                 # minus the bos token
                 choice_target_ids.append(prompt_token_num - 1)
 
+            # 4.1 Fetch and zip prompt & gold answer if output column exists
+            ds_reader = retriever.dataset_reader
+            if ds_reader.output_column:
+                gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            else:
+                gold_ans = [None] * len(prompt_list)
+
             logger.info('Calculating conditional log probability for prompts.')
             for idx in trange(0,
                               len(prompt_list),
                               self.batch_size,
                               disable=not self.is_main_process):
                 sub_prompt_list = prompt_list[idx:idx + self.batch_size]
+                sub_golds = gold_ans[idx:idx + self.batch_size]
                 sub_choice_target_ids = choice_target_ids[idx:idx +
                                                           self.batch_size]
                 sub_res = self.__get_cond_prob(sub_prompt_list,
                                                sub_choice_target_ids,
                                                choice_ids)
 
-                for res, prompt in zip(sub_res, sub_prompt_list):
-                    output_handler.save_prompt_and_condprob(
-                        prompt.replace(ice[idx], ''), prompt, res, index,
-                        choices)
+                for res, prompt, gold in zip(sub_res, sub_prompt_list,
+                                             sub_golds):
+                    example_input = prompt.replace(ice[idx], '')
+                    output_handler.save_prompt_and_condprob(example_input,
+                                                            prompt,
+                                                            res,
+                                                            index,
+                                                            choices,
+                                                            gold=gold)
                     index = index + 1
 
             # 5. Output
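
Unlike the gen-style inferencers, the CLP path does not zip golds into `prompt_list`; it keeps a parallel `gold_ans` list (padded with `None` when there is no output column) and slices it with the same `idx` window as the prompts, so alignment holds even for a short final batch. A tiny sketch with toy lists:

prompt_list = ['p0', 'p1', 'p2', 'p3', 'p4']
gold_ans = ['g0', 'g1', 'g2', 'g3', 'g4']
batch_size = 2

for idx in range(0, len(prompt_list), batch_size):
    sub_prompt_list = prompt_list[idx:idx + batch_size]
    sub_golds = gold_ans[idx:idx + batch_size]  # same window, so pairs stay aligned
    print(list(zip(sub_prompt_list, sub_golds)))
# [('p0', 'g0'), ('p1', 'g1')]
# [('p2', 'g2'), ('p3', 'g3')]
# [('p4', 'g4')]               <- short final batch is still aligned
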
@@ -99,6 +99,12 @@ class GenInferencer(BaseInferencer):
             ice_template=ice_template,
             prompt_template=prompt_template)
 
+        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            prompt_list = list(zip(prompt_list, gold_ans))
+
         # Create tmp json file for saving intermediate results and future
         # resuming
         index = 0
@@ -115,7 +121,12 @@ class GenInferencer(BaseInferencer):
 
         # 5. Inference for prompts in each batch
         logger.info('Starting inference process...')
-        for entry in tqdm(dataloader, disable=not self.is_main_process):
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            if ds_reader.output_column:
+                entry, golds = list(zip(*datum))
+            else:
+                entry = datum
+                golds = [None for _ in range(len(entry))]
             # 5-1. Inference with local model
             with torch.no_grad():
                 parsed_entries = self.model.parse_template(entry, mode='gen')
@@ -124,8 +135,12 @@ class GenInferencer(BaseInferencer):
                 generated = results
 
             # 5-3. Save current output
-            for prompt, prediction in zip(parsed_entries, generated):
-                output_handler.save_results(prompt, prediction, index)
+            for prompt, prediction, gold in zip(parsed_entries, generated,
+                                                golds):
+                output_handler.save_results(prompt,
+                                            prediction,
+                                            index,
+                                            gold=gold)
                 index = index + 1
 
             # 5-4. Save intermediate results
@@ -200,7 +200,13 @@ class PPLInferencer(BaseInferencer):
             sub_predictions.append(labels[single_ppl.index(min(single_ppl))])
         output_handler.save_predictions(sub_predictions)
 
-        # 7. Output
+        # 7. Fetch gold answers if exist
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            golds = ds_reader.dataset['test'][ds_reader.output_column]
+            output_handler.save_golds(golds)
+
+        # 8. Output
         if self.is_main_process:
             os.makedirs(output_json_filepath, exist_ok=True)
             output_handler.write_to_json(output_json_filepath,
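
The PPL path attaches golds differently again: predictions are saved per index during inference, and the new `save_golds` back-fills the `gold` key in a single pass at the end. A minimal standalone imitation of that back-fill, using a plain dict instead of the real handler:

results_dict = {'0': {'prediction': 'B'}, '1': {'prediction': 'A'}}


def save_golds(golds):
    # Same logic as PPLInferencerOutputHandler.save_golds above.
    for idx, gold in enumerate(golds):
        if str(idx) not in results_dict:
            results_dict[str(idx)] = {}
        results_dict[str(idx)]['gold'] = gold


save_golds(['B', 'B'])
print(results_dict)
# {'0': {'prediction': 'B', 'gold': 'B'}, '1': {'prediction': 'A', 'gold': 'B'}}
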
@@ -105,6 +105,12 @@ class SCInferencer(BaseInferencer):
             ice_template=ice_template,
             prompt_template=prompt_template)
 
+        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            prompt_list = list(zip(prompt_list, gold_ans))
+
         # Create tmp json file for saving intermediate results and future
         # resuming
         index = 0
@@ -121,7 +127,12 @@ class SCInferencer(BaseInferencer):
 
         # 5. Inference for prompts in each batch
         logger.info('Starting inference process...')
-        for entry in tqdm(dataloader, disable=not self.is_main_process):
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            if ds_reader.output_column:
+                entry, golds = list(zip(*datum))
+            else:
+                entry = datum
+                golds = [None for _ in range(len(entry))]
             # TODO: add more types of CoT method
             # 5-1. Inference sc_size times with local model
             with torch.no_grad():
@@ -137,8 +148,12 @@ class SCInferencer(BaseInferencer):
                 generated = sc_prediction
 
             # 5-3. Save current output
-            for prompt, prediction in zip(parsed_entries, generated):
-                output_handler.save_results(prompt, prediction, index)
+            for prompt, prediction, gold in zip(parsed_entries, generated,
+                                                golds):
+                output_handler.save_results(prompt,
+                                            prediction,
+                                            index,
+                                            gold=gold)
                 index = index + 1
 
             # 5-4. Save intermediate results
@@ -333,6 +333,12 @@ class ToTInferencer(GenInferencer):
             ice_template=ice_template,
             prompt_template=prompt_template)
 
+        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            prompt_list = list(zip(prompt_list, gold_ans))
+
         # Create tmp json file for saving intermediate results and future
         # resuming
         index = 0
@@ -349,15 +355,24 @@ class ToTInferencer(GenInferencer):
 
         # 5. Inference for prompts in each batch
        logger.info('Starting ToT inference process...')
-        for entries in tqdm(dataloader, disable=not self.is_main_process):
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            if ds_reader.output_column:
+                entries, golds = list(zip(*datum))
+            else:
+                entries = datum
+                golds = [None for _ in range(len(entries))]
             # 5-1. Inference with ToT and local model
             with torch.no_grad():
                 parsed_entries = self.model.parse_template(entries, mode='gen')
                 generated = [self.tot_solve(entry) for entry in entries]
 
             # 5-2. Save current output
-            for prompt, prediction in zip(parsed_entries, generated):
-                output_handler.save_results(prompt, prediction, index)
+            for prompt, prediction, gold in zip(parsed_entries, generated,
+                                                golds):
+                output_handler.save_results(prompt,
+                                            prediction,
+                                            index,
+                                            gold=gold)
                 index = index + 1
 
             # 5-3. Save intermediate results