[Feature] Log gold answer in prediction output (#419)

* [Feature] Log gold answer in prediction output

* Support gold answers in the CLP (conditional log probability) inferencer

* minor fix

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>
Tong Gao 2023-09-22 12:44:40 +08:00 committed by GitHub
parent 97fdc51102
commit 681d3013de
9 changed files with 134 additions and 28 deletions
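
With this change, each saved prediction record can carry the dataset's gold answer next to the model output. A minimal illustration of the resulting record shape (the prompt, prediction and answer below are invented; the field names 'origin_prompt', 'prediction' and 'gold' follow the diff):

import json

record = {
    '0': {
        'origin_prompt': 'Question: 1 + 1 = ?\nAnswer:',
        'prediction': ' 2',
        'gold': '2',  # newly logged when the dataset reader has an output column
    },
}
print(json.dumps(record, indent=2, ensure_ascii=False))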

View File

@@ -1,4 +1,4 @@
from abc import abstractclassmethod
from abc import abstractmethod
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union
@@ -37,7 +37,7 @@ class BaseModel:
if meta_template and 'eos_token_id' in meta_template:
self.eos_token_id = meta_template['eos_token_id']
@abstractclassmethod
@abstractmethod
def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
"""Generate results given a list of inputs.
@@ -48,8 +48,11 @@ class BaseModel:
Returns:
List[str]: A list of generated strings.
"""
raise NotImplementedError(f'{self.__class__.__name__} does not support'
' gen-based evaluation yet, try ppl-based '
'instead.')
@abstractclassmethod
@abstractmethod
def get_ppl(self,
inputs: List[str],
mask_length: Optional[List[int]] = None) -> List[float]:
@@ -66,8 +69,11 @@ class BaseModel:
Returns:
List[float]: A list of perplexity scores.
"""
raise NotImplementedError(f'{self.__class__.__name__} does not support'
' ppl-based evaluation yet, try gen-based '
'instead.')
@abstractclassmethod
@abstractmethod
def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized strings.
@@ -192,7 +198,7 @@ class LMTemplateParser:
Returns:
str: The final string.
"""
assert isinstance(prompt_template, (str, list, PromptList))
assert isinstance(prompt_template, (str, list, PromptList, tuple))
if not isinstance(prompt_template, (str, PromptList)):
return [self.parse_template(p, mode=mode) for p in prompt_template]
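
Two things change in this file: abc.abstractclassmethod, deprecated since Python 3.3 and wrapping its target in a classmethod rather than leaving it an instance method, is replaced by abc.abstractmethod, and the abstract bodies now raise a descriptive NotImplementedError. A standalone sketch of the corrected pattern (TinyBaseModel and PPLOnlyModel are invented names; it assumes no ABCMeta, matching the plain `class BaseModel:` context shown in the hunk headers):

from abc import abstractmethod
from typing import List


class TinyBaseModel:
    """Sketch only; without ABCMeta the decorator documents intent and the
    raise below supplies the runtime error."""

    @abstractmethod
    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        # Reached when a subclass does not override generate().
        raise NotImplementedError(f'{self.__class__.__name__} does not support'
                                  ' gen-based evaluation yet, try ppl-based '
                                  'instead.')


class PPLOnlyModel(TinyBaseModel):
    pass  # deliberately does not implement generate()


try:
    PPLOnlyModel().generate(['hi'], max_out_len=8)
except NotImplementedError as err:
    print(err)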

View File

@@ -1,7 +1,7 @@
import re
import threading
import warnings
from abc import abstractclassmethod
from abc import abstractmethod
from copy import deepcopy
from time import sleep
from typing import Dict, List, Optional, Tuple, Union
@@ -46,7 +46,7 @@ class BaseAPIModel(BaseModel):
self.template_parser = APITemplateParser(meta_template)
self.logger = get_logger()
@abstractclassmethod
@abstractmethod
def generate(self, inputs: List[PromptType],
max_out_len: int) -> List[str]:
"""Generate results given a list of inputs.
@@ -60,8 +60,11 @@ class BaseAPIModel(BaseModel):
Returns:
List[str]: A list of generated strings.
"""
raise NotImplementedError(f'{self.__class__.__name__} does not support'
' gen-based evaluation yet, try ppl-based '
'instead.')
@abstractclassmethod
@abstractmethod
def get_ppl(self,
inputs: List[PromptType],
mask_length: Optional[List[int]] = None) -> List[float]:
@@ -78,6 +81,9 @@ class BaseAPIModel(BaseModel):
Returns:
List[float]: A list of perplexity scores.
"""
raise NotImplementedError(f'{self.__class__.__name__} does not support'
' ppl-based evaluation yet, try gen-based '
'instead.')
def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized string. Only English and Chinese
@@ -161,7 +167,7 @@ class APITemplateParser:
Returns:
List[str or PromptList]: The finalized prompt or a conversation.
"""
assert isinstance(prompt_template, (str, list, PromptList))
assert isinstance(prompt_template, (str, list, PromptList, tuple))
if not isinstance(prompt_template, (str, PromptList)):
return [self.parse_template(p, mode=mode) for p in prompt_template]
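
The widened asserts in LMTemplateParser.parse_template and APITemplateParser.parse_template exist because, in the inferencer changes further down, prompts get zipped with their gold answers and each batch is unzipped with zip(*datum), which yields the prompts as a tuple rather than a list. A minimal sketch of that shape (data invented, PromptList omitted):

prompt_list = ['p0', 'p1', 'p2']
gold_ans = ['g0', 'g1', 'g2']

datum = list(zip(prompt_list, gold_ans))  # a batch of (prompt, gold) pairs
entry, golds = list(zip(*datum))          # entry == ('p0', 'p1', 'p2'), a tuple

assert isinstance(entry, (str, list, tuple))  # passes with the widened check
assert not isinstance(entry, (str, list))     # the old check would have failed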

View File

@@ -108,6 +108,12 @@ class AttackInferencer(BaseInferencer):
ice_template=self.ice_template,
prompt_template=self.prompt_template)
# 3.1 Fetch and zip prompt & gold answer if output column exists
ds_reader = self.retriever.dataset_reader
if ds_reader.output_column:
gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
prompt_list = list(zip(prompt_list, gold_ans))
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
@@ -124,7 +130,12 @@
# 5. Inference for prompts in each batch
logger.info('Starting inference process...')
for entry in tqdm(dataloader, disable=not self.is_main_process):
for datum in tqdm(dataloader, disable=not self.is_main_process):
if ds_reader.output_column:
entry, golds = list(zip(*datum))
else:
entry = datum
golds = [None for _ in range(len(entry))]
# 5-1. Inference with local model
with torch.no_grad():
parsed_entries = self.model.parse_template(entry, mode='gen')
@@ -133,8 +144,12 @@
generated = results
# 5-3. Save current output
for prompt, prediction in zip(parsed_entries, generated):
output_handler.save_results(prompt, prediction, index)
for prompt, prediction, gold in zip(parsed_entries, generated,
golds):
output_handler.save_results(prompt,
prediction,
index,
gold=gold)
index = index + 1
# 5-4. Save intermediate results
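
The same pattern appears again in GenInferencer, SCInferencer and ToTInferencer below: when the dataset reader has an output column, prompts are zipped with their gold answers before batching, and each batch is unzipped so prompts, predictions and golds stay aligned. A standalone sketch with plain slicing standing in for the real DataLoader (all data invented):

prompt_list = ['prompt A', 'prompt B', 'prompt C', 'prompt D']
gold_ans = ['gold A', 'gold B', 'gold C', 'gold D']
output_column = True  # stands in for ds_reader.output_column
batch_size = 2

if output_column:
    prompt_list = list(zip(prompt_list, gold_ans))

index = 0
for start in range(0, len(prompt_list), batch_size):
    datum = prompt_list[start:start + batch_size]
    if output_column:
        entry, golds = list(zip(*datum))
    else:
        entry = datum
        golds = [None for _ in range(len(entry))]
    predictions = [p.upper() for p in entry]  # fake model output
    for prompt, prediction, gold in zip(entry, predictions, golds):
        print(index, prompt, prediction, gold)
        index += 1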

View File

@@ -108,11 +108,13 @@ class GenInferencerOutputHandler:
"""Dump the result to a json file."""
dump_results_dict(self.results_dict, Path(save_dir) / filename)
def save_results(self, origin_prompt, prediction, idx):
def save_results(self, origin_prompt, prediction, idx, gold=None):
self.results_dict[str(idx)] = {
'origin_prompt': origin_prompt,
'prediction': prediction,
}
if gold:
self.results_dict[str(idx)]['gold'] = gold
class PPLInferencerOutputHandler:
@@ -147,6 +149,12 @@ class PPLInferencerOutputHandler:
self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl
def save_golds(self, golds):
for idx, gold in enumerate(golds):
if str(idx) not in self.results_dict.keys():
self.results_dict[str(idx)] = {}
self.results_dict[str(idx)]['gold'] = gold
class CLPInferencerOutputHandler:
results_dict = {}
@@ -164,7 +172,13 @@ class CLPInferencerOutputHandler:
self.results_dict[str(idx)] = {}
self.results_dict[str(idx)]['in-context examples'] = example
def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
def save_prompt_and_condprob(self,
input,
prompt,
cond_prob,
idx,
choices,
gold=None):
if str(idx) not in self.results_dict.keys():
self.results_dict[str(idx)] = {}
# TODO:
@@ -177,3 +191,4 @@ class CLPInferencerOutputHandler:
self.results_dict[str(idx)]['prediction'] = cond_prob
# set pred label in case needed
self.results_dict[str(idx)]['pred_label'] = int(np.argmax(cond_prob))
self.results_dict[str(idx)]['gold'] = gold
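
A short usage sketch of the extended save_results signature (a plain function mimicking GenInferencerOutputHandler.save_results, not the real class); note that the gold field is only written when a gold answer is supplied:

results_dict = {}


def save_results(origin_prompt, prediction, idx, gold=None):
    # Mirrors the handler above: record the prompt/prediction pair, and attach
    # the gold answer only when one was provided.
    results_dict[str(idx)] = {
        'origin_prompt': origin_prompt,
        'prediction': prediction,
    }
    if gold:
        results_dict[str(idx)]['gold'] = gold


save_results('2 + 2 =', ' 4', 0, gold='4')
save_results('free-form prompt', 'some text', 1)  # dataset without a gold column
print(results_dict)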

View File

@@ -175,22 +175,35 @@ class CLPInferencer(BaseInferencer):
# minus the bos token
choice_target_ids.append(prompt_token_num - 1)
# 4.1 Fetch and zip prompt & gold answer if output column exists
ds_reader = retriever.dataset_reader
if ds_reader.output_column:
gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
else:
gold_ans = [None] * len(prompt_list)
logger.info('Calculating conditional log probability for prompts.')
for idx in trange(0,
len(prompt_list),
self.batch_size,
disable=not self.is_main_process):
sub_prompt_list = prompt_list[idx:idx + self.batch_size]
sub_golds = gold_ans[idx:idx + self.batch_size]
sub_choice_target_ids = choice_target_ids[idx:idx +
self.batch_size]
sub_res = self.__get_cond_prob(sub_prompt_list,
sub_choice_target_ids,
choice_ids)
for res, prompt in zip(sub_res, sub_prompt_list):
output_handler.save_prompt_and_condprob(
prompt.replace(ice[idx], ''), prompt, res, index,
choices)
for res, prompt, gold in zip(sub_res, sub_prompt_list,
sub_golds):
example_input = prompt.replace(ice[idx], '')
output_handler.save_prompt_and_condprob(example_input,
prompt,
res,
index,
choices,
gold=gold)
index = index + 1
# 5. Output
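
Unlike the gen-style inferencers, the CLP path does not zip golds into the prompt list; it slices gold_ans with the same [idx:idx + batch_size] window as the prompts and choice-target ids, so the three sequences stay aligned per batch (and gold_ans falls back to [None] * len(prompt_list) when there is no output column). A tiny sketch of that alignment with invented data:

prompt_list = ['p0', 'p1', 'p2', 'p3', 'p4']
gold_ans = ['A', 'B', 'A', 'C', 'B']  # or [None] * len(prompt_list)
choice_target_ids = [10, 11, 12, 13, 14]
batch_size = 2

for idx in range(0, len(prompt_list), batch_size):
    # All three slices use the same window, so zip() pairs them correctly.
    sub_prompt_list = prompt_list[idx:idx + batch_size]
    sub_golds = gold_ans[idx:idx + batch_size]
    sub_choice_target_ids = choice_target_ids[idx:idx + batch_size]
    for prompt, gold, target_id in zip(sub_prompt_list, sub_golds,
                                       sub_choice_target_ids):
        print(prompt, gold, target_id)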

View File

@@ -99,6 +99,12 @@ class GenInferencer(BaseInferencer):
ice_template=ice_template,
prompt_template=prompt_template)
# 3.1 Fetch and zip prompt & gold answer if output column exists
ds_reader = retriever.dataset_reader
if ds_reader.output_column:
gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
prompt_list = list(zip(prompt_list, gold_ans))
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
@@ -115,7 +121,12 @@
# 5. Inference for prompts in each batch
logger.info('Starting inference process...')
for entry in tqdm(dataloader, disable=not self.is_main_process):
for datum in tqdm(dataloader, disable=not self.is_main_process):
if ds_reader.output_column:
entry, golds = list(zip(*datum))
else:
entry = datum
golds = [None for _ in range(len(entry))]
# 5-1. Inference with local model
with torch.no_grad():
parsed_entries = self.model.parse_template(entry, mode='gen')
@@ -124,8 +135,12 @@
generated = results
# 5-3. Save current output
for prompt, prediction in zip(parsed_entries, generated):
output_handler.save_results(prompt, prediction, index)
for prompt, prediction, gold in zip(parsed_entries, generated,
golds):
output_handler.save_results(prompt,
prediction,
index,
gold=gold)
index = index + 1
# 5-4. Save intermediate results

View File

@@ -200,7 +200,13 @@ class PPLInferencer(BaseInferencer):
sub_predictions.append(labels[single_ppl.index(min(single_ppl))])
output_handler.save_predictions(sub_predictions)
# 7. Output
# 7. Fetch gold answers if exist
ds_reader = retriever.dataset_reader
if ds_reader.output_column:
golds = ds_reader.dataset['test'][ds_reader.output_column]
output_handler.save_golds(golds)
# 8. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
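
The PPL inferencer takes a different route: its results are already keyed by example index, so the gold answers are merged in afterwards via save_golds instead of being threaded through the batches. A small mimic of that merge (not the real handler, data invented):

results_dict = {
    '0': {'label: yes': {'PPL': 1.8}, 'label: no': {'PPL': 2.4}},
    '1': {'label: yes': {'PPL': 3.0}, 'label: no': {'PPL': 1.1}},
}


def save_golds(golds):
    # Same effect as PPLInferencerOutputHandler.save_golds: create the entry
    # if needed, then attach the gold answer under its index.
    for idx, gold in enumerate(golds):
        results_dict.setdefault(str(idx), {})['gold'] = gold


save_golds(['yes', 'no'])
print(results_dict['0']['gold'], results_dict['1']['gold'])  # yes no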

View File

@@ -105,6 +105,12 @@ class SCInferencer(BaseInferencer):
ice_template=ice_template,
prompt_template=prompt_template)
# 3.1 Fetch and zip prompt & gold answer if output column exists
ds_reader = retriever.dataset_reader
if ds_reader.output_column:
gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
prompt_list = list(zip(prompt_list, gold_ans))
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
@@ -121,7 +127,12 @@
# 5. Inference for prompts in each batch
logger.info('Starting inference process...')
for entry in tqdm(dataloader, disable=not self.is_main_process):
for datum in tqdm(dataloader, disable=not self.is_main_process):
if ds_reader.output_column:
entry, golds = list(zip(*datum))
else:
entry = datum
golds = [None for _ in range(len(entry))]
# TODO: add more types of CoT method
# 5-1. Inference sc_size times with local model
with torch.no_grad():
@@ -137,8 +148,12 @@
generated = sc_prediction
# 5-3. Save current output
for prompt, prediction in zip(parsed_entries, generated):
output_handler.save_results(prompt, prediction, index)
for prompt, prediction, gold in zip(parsed_entries, generated,
golds):
output_handler.save_results(prompt,
prediction,
index,
gold=gold)
index = index + 1
# 5-4. Save intermediate results

View File

@@ -333,6 +333,12 @@ class ToTInferencer(GenInferencer):
ice_template=ice_template,
prompt_template=prompt_template)
# 3.1 Fetch and zip prompt & gold answer if output column exists
ds_reader = retriever.dataset_reader
if ds_reader.output_column:
gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
prompt_list = list(zip(prompt_list, gold_ans))
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
@@ -349,15 +355,24 @@
# 5. Inference for prompts in each batch
logger.info('Starting ToT inference process...')
for entries in tqdm(dataloader, disable=not self.is_main_process):
for datum in tqdm(dataloader, disable=not self.is_main_process):
if ds_reader.output_column:
entries, golds = list(zip(*datum))
else:
entries = datum
golds = [None for _ in range(len(entries))]
# 5-1. Inference with ToT and local model
with torch.no_grad():
parsed_entries = self.model.parse_template(entries, mode='gen')
generated = [self.tot_solve(entry) for entry in entries]
# 5-2. Save current output
for prompt, prediction in zip(parsed_entries, generated):
output_handler.save_results(prompt, prediction, index)
for prompt, prediction, gold in zip(parsed_entries, generated,
golds):
output_handler.save_results(prompt,
prediction,
index,
gold=gold)
index = index + 1
# 5-3. Save intermediate results