Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Log gold answer in prediction output (#419)
* [Feature] Log gold answer in prediction output
* support clp golden ans
* minor fix

Co-authored-by: yingfhu <yingfhu@gmail.com>
This commit is contained in:
Parent commit: 97fdc51102
This commit: 681d3013de
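In short, the change threads each dataset's gold answer (the reader's output_column) through every inferencer and into the prediction JSON. Below is a minimal sketch, not taken from the diff, of what a per-sample record looks like afterwards; the key names follow the save_results changes in the hunks below, while the prompt and answer values are made up:

import json

# Hypothetical contents of a predictions JSON after this change: each record
# keeps 'origin_prompt' and 'prediction', and gains a 'gold' field whenever
# the dataset reader defines an output_column.
example_predictions = {
    '0': {
        'origin_prompt': 'Question: 1 + 1 = ?\nAnswer:',
        'prediction': ' 2',
        'gold': '2',
    },
}
print(json.dumps(example_predictions, indent=2))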
@@ -1,4 +1,4 @@
-from abc import abstractclassmethod
+from abc import abstractmethod
 from copy import deepcopy
 from typing import Dict, List, Optional, Tuple, Union
 
@@ -37,7 +37,7 @@ class BaseModel:
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']
 
-    @abstractclassmethod
+    @abstractmethod
     def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
         """Generate results given a list of inputs.
 
@@ -48,8 +48,11 @@ class BaseModel:
         Returns:
             List[str]: A list of generated strings.
         """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' gen-based evaluation yet, try ppl-based '
+                                  'instead.')
 
-    @abstractclassmethod
+    @abstractmethod
     def get_ppl(self,
                 inputs: List[str],
                 mask_length: Optional[List[int]] = None) -> List[float]:
@@ -66,8 +69,11 @@ class BaseModel:
         Returns:
             List[float]: A list of perplexity scores.
         """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' ppl-based evaluation yet, try gen-based '
+                                  'instead.')
 
-    @abstractclassmethod
+    @abstractmethod
     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized strings.
 
@@ -192,7 +198,7 @@ class LMTemplateParser:
         Returns:
             str: The final string.
         """
-        assert isinstance(prompt_template, (str, list, PromptList))
+        assert isinstance(prompt_template, (str, list, PromptList, tuple))
         if not isinstance(prompt_template, (str, PromptList)):
             return [self.parse_template(p, mode=mode) for p in prompt_template]
 
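Besides the gold-answer plumbing, the hunks above replace the deprecated abc.abstractclassmethod with abc.abstractmethod: the former additionally wraps the method as a classmethod, which these instance methods (generate, get_ppl, get_token_len) are not. As far as the hunk headers show, BaseModel is a plain class rather than an abc.ABC subclass, so the decorator is mostly documentation; the newly added NotImplementedError bodies are what actually report an unsupported evaluation mode at runtime. A small illustration under those assumptions (class names are hypothetical stand-ins):

from abc import abstractmethod
from typing import List


class BaseModelSketch:  # stand-in for BaseModel; intentionally not an ABC

    @abstractmethod
    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        # Mirrors the diff: a readable error instead of a bare abstract stub.
        raise NotImplementedError(f'{self.__class__.__name__} does not '
                                  'support gen-based evaluation yet.')

    @abstractmethod
    def get_ppl(self, inputs: List[str]) -> List[float]:
        raise NotImplementedError(f'{self.__class__.__name__} does not '
                                  'support ppl-based evaluation yet.')


class GenOnlyModel(BaseModelSketch):
    """Overrides generate only; get_ppl still raises a helpful error."""

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        return [f'echo: {x}'[:max_out_len] for x in inputs]


model = GenOnlyModel()          # instantiable because the base is not an ABC
print(model.generate(['hi'], max_out_len=16))
try:
    model.get_ppl(['hi'])
except NotImplementedError as err:
    print(err)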
@@ -1,7 +1,7 @@
 import re
 import threading
 import warnings
-from abc import abstractclassmethod
+from abc import abstractmethod
 from copy import deepcopy
 from time import sleep
 from typing import Dict, List, Optional, Tuple, Union
@@ -46,7 +46,7 @@ class BaseAPIModel(BaseModel):
         self.template_parser = APITemplateParser(meta_template)
         self.logger = get_logger()
 
-    @abstractclassmethod
+    @abstractmethod
     def generate(self, inputs: List[PromptType],
                  max_out_len: int) -> List[str]:
         """Generate results given a list of inputs.
@@ -60,8 +60,11 @@ class BaseAPIModel(BaseModel):
         Returns:
             List[str]: A list of generated strings.
         """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' gen-based evaluation yet, try ppl-based '
+                                  'instead.')
 
-    @abstractclassmethod
+    @abstractmethod
     def get_ppl(self,
                 inputs: List[PromptType],
                 mask_length: Optional[List[int]] = None) -> List[float]:
@@ -78,6 +81,9 @@ class BaseAPIModel(BaseModel):
         Returns:
             List[float]: A list of perplexity scores.
         """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' ppl-based evaluation yet, try gen-based '
+                                  'instead.')
 
     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized string. Only English and Chinese
@@ -161,7 +167,7 @@ class APITemplateParser:
         Returns:
            List[str or PromptList]: The finalized prompt or a conversation.
         """
-        assert isinstance(prompt_template, (str, list, PromptList))
+        assert isinstance(prompt_template, (str, list, PromptList, tuple))
 
         if not isinstance(prompt_template, (str, PromptList)):
             return [self.parse_template(p, mode=mode) for p in prompt_template]
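Why tuple joins the accepted types in both parse_template asserts: once prompts are zipped with gold answers, the inferencer loops unpack each batch with zip(*datum), which yields tuples rather than lists, and those tuples are then handed to model.parse_template. A standalone sketch of that shape change (all values illustrative; PromptList is omitted since it is not needed for the point):

# Pair prompts with golds, then unpack a batch the way the inferencers do.
prompt_list = ['prompt A', 'prompt B']
gold_ans = ['gold A', 'gold B']
paired = list(zip(prompt_list, gold_ans))      # [('prompt A', 'gold A'), ...]

entry, golds = list(zip(*paired))              # unzip one batch
print(type(entry), entry)                      # <class 'tuple'> ('prompt A', 'prompt B')
print(golds)                                   # ('gold A', 'gold B')

# parse_template-style dispatch: non-str collections are parsed element-wise,
# so the assert must now allow tuples as well as lists.
assert isinstance(entry, (str, list, tuple))
parsed = [p.upper() for p in entry]
print(parsed)                                  # ['PROMPT A', 'PROMPT B']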
@@ -108,6 +108,12 @@ class AttackInferencer(BaseInferencer):
             ice_template=self.ice_template,
             prompt_template=self.prompt_template)
 
+        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = self.retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            prompt_list = list(zip(prompt_list, gold_ans))
+
         # Create tmp json file for saving intermediate results and future
         # resuming
         index = 0
@@ -124,7 +130,12 @@ class AttackInferencer(BaseInferencer):
 
         # 5. Inference for prompts in each batch
         logger.info('Starting inference process...')
-        for entry in tqdm(dataloader, disable=not self.is_main_process):
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            if ds_reader.output_column:
+                entry, golds = list(zip(*datum))
+            else:
+                entry = datum
+                golds = [None for _ in range(len(entry))]
             # 5-1. Inference with local model
             with torch.no_grad():
                 parsed_entries = self.model.parse_template(entry, mode='gen')
@@ -133,8 +144,12 @@ class AttackInferencer(BaseInferencer):
                 generated = results
 
             # 5-3. Save current output
-            for prompt, prediction in zip(parsed_entries, generated):
-                output_handler.save_results(prompt, prediction, index)
+            for prompt, prediction, gold in zip(parsed_entries, generated,
+                                                golds):
+                output_handler.save_results(prompt,
+                                            prediction,
+                                            index,
+                                            gold=gold)
                 index = index + 1
 
             # 5-4. Save intermediate results
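The AttackInferencer hunks establish the pattern reused by the other gen-style inferencers below: fetch the golds, zip them into prompt_list, unzip each batch, and pass each gold through save_results. A self-contained sketch with a stub handler (the dataloader, model output, and handler here are stand-ins; only the control flow mirrors the diff):

class StubOutputHandler:
    """Minimal stand-in for GenInferencerOutputHandler.save_results."""

    def __init__(self):
        self.results_dict = {}

    def save_results(self, origin_prompt, prediction, idx, gold=None):
        self.results_dict[str(idx)] = {
            'origin_prompt': origin_prompt,
            'prediction': prediction,
        }
        if gold:
            self.results_dict[str(idx)]['gold'] = gold


output_column = 'answer'                               # None -> no gold logged
prompt_list = ['Q: 2+2?', 'Q: 3+3?']
gold_ans = ['4', '6']
if output_column:
    prompt_list = list(zip(prompt_list, gold_ans))

handler = StubOutputHandler()
index = 0
for datum in [prompt_list]:                            # a single fake batch
    if output_column:
        entry, golds = list(zip(*datum))
    else:
        entry, golds = datum, [None] * len(datum)
    generated = [f'pred for {p}' for p in entry]       # fake model output
    for prompt, prediction, gold in zip(entry, generated, golds):
        handler.save_results(prompt, prediction, index, gold=gold)
        index += 1

print(handler.results_dict)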
@@ -108,11 +108,13 @@ class GenInferencerOutputHandler:
         """Dump the result to a json file."""
         dump_results_dict(self.results_dict, Path(save_dir) / filename)
 
-    def save_results(self, origin_prompt, prediction, idx):
+    def save_results(self, origin_prompt, prediction, idx, gold=None):
         self.results_dict[str(idx)] = {
             'origin_prompt': origin_prompt,
             'prediction': prediction,
         }
+        if gold:
+            self.results_dict[str(idx)]['gold'] = gold
 
 
 class PPLInferencerOutputHandler:
@@ -147,6 +149,12 @@ class PPLInferencerOutputHandler:
         self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
         self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl
 
+    def save_golds(self, golds):
+        for idx, gold in enumerate(golds):
+            if str(idx) not in self.results_dict.keys():
+                self.results_dict[str(idx)] = {}
+            self.results_dict[str(idx)]['gold'] = gold
+
 
 class CLPInferencerOutputHandler:
     results_dict = {}
@@ -164,7 +172,13 @@ class CLPInferencerOutputHandler:
             self.results_dict[str(idx)] = {}
         self.results_dict[str(idx)]['in-context examples'] = example
 
-    def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
+    def save_prompt_and_condprob(self,
+                                 input,
+                                 prompt,
+                                 cond_prob,
+                                 idx,
+                                 choices,
+                                 gold=None):
         if str(idx) not in self.results_dict.keys():
             self.results_dict[str(idx)] = {}
         # TODO:
@@ -177,3 +191,4 @@ class CLPInferencerOutputHandler:
         self.results_dict[str(idx)]['prediction'] = cond_prob
         # set pred label in case needed
         self.results_dict[str(idx)]['pred_label'] = int(np.argmax(cond_prob))
+        self.results_dict[str(idx)]['gold'] = gold
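Two details of the handler changes above are worth noting. First, GenInferencerOutputHandler guards with a truthiness check (if gold:), so a falsy gold such as 0 or an empty string would not be written, whereas the PPL and CLP handlers store the gold unconditionally. Second, PPLInferencerOutputHandler.save_golds merges golds after the fact, keyed by the running sample index, so it assumes the golds list is ordered like the results. A small sketch of that merge (dictionary contents are illustrative):

# Stand-in results_dict, roughly as it might look after save_ppl and
# save_predictions have run; values are made up.
results_dict = {
    '0': {'label: A': {'prompt': '...A...', 'PPL': 3.1}, 'prediction': 'A'},
    '1': {'label: B': {'prompt': '...B...', 'PPL': 2.7}, 'prediction': 'B'},
}


def save_golds(results_dict, golds):
    # Mirrors the added method: align golds with samples by index.
    for idx, gold in enumerate(golds):
        results_dict.setdefault(str(idx), {})['gold'] = gold


save_golds(results_dict, ['A', 'B'])
print(results_dict['0']['gold'], results_dict['1']['gold'])  # A B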
@@ -175,22 +175,35 @@ class CLPInferencer(BaseInferencer):
                 # minus the bos token
                 choice_target_ids.append(prompt_token_num - 1)
 
+        # 4.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+        else:
+            gold_ans = [None] * len(prompt_list)
+
         logger.info('Calculating conditional log probability for prompts.')
         for idx in trange(0,
                           len(prompt_list),
                           self.batch_size,
                           disable=not self.is_main_process):
             sub_prompt_list = prompt_list[idx:idx + self.batch_size]
+            sub_golds = gold_ans[idx:idx + self.batch_size]
             sub_choice_target_ids = choice_target_ids[idx:idx +
                                                       self.batch_size]
             sub_res = self.__get_cond_prob(sub_prompt_list,
                                            sub_choice_target_ids,
                                            choice_ids)
 
-            for res, prompt in zip(sub_res, sub_prompt_list):
-                output_handler.save_prompt_and_condprob(
-                    prompt.replace(ice[idx], ''), prompt, res, index,
-                    choices)
+            for res, prompt, gold in zip(sub_res, sub_prompt_list,
+                                         sub_golds):
+                example_input = prompt.replace(ice[idx], '')
+                output_handler.save_prompt_and_condprob(example_input,
+                                                        prompt,
+                                                        res,
+                                                        index,
+                                                        choices,
+                                                        gold=gold)
                 index = index + 1
 
         # 5. Output
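CLPInferencer differs slightly from the gen-style inferencers: instead of zipping golds into prompt_list, it keeps a parallel gold_ans list (padded with None when there is no output column) and slices it batch by batch alongside the prompts, so prompt, conditional probability, and gold stay index-aligned. A sketch of that slicing (batch size and values illustrative):

prompt_list = ['p0', 'p1', 'p2', 'p3', 'p4']
output_column = None                      # e.g. a dataset without gold answers
gold_ans = (['g0', 'g1', 'g2', 'g3', 'g4'] if output_column
            else [None] * len(prompt_list))

batch_size = 2
for idx in range(0, len(prompt_list), batch_size):
    sub_prompt_list = prompt_list[idx:idx + batch_size]
    sub_golds = gold_ans[idx:idx + batch_size]          # stays aligned with prompts
    for prompt, gold in zip(sub_prompt_list, sub_golds):
        print(prompt, gold)                             # gold is None here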
@@ -99,6 +99,12 @@ class GenInferencer(BaseInferencer):
             ice_template=ice_template,
             prompt_template=prompt_template)
 
+        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            prompt_list = list(zip(prompt_list, gold_ans))
+
         # Create tmp json file for saving intermediate results and future
         # resuming
         index = 0
@@ -115,7 +121,12 @@ class GenInferencer(BaseInferencer):
 
         # 5. Inference for prompts in each batch
        logger.info('Starting inference process...')
-        for entry in tqdm(dataloader, disable=not self.is_main_process):
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            if ds_reader.output_column:
+                entry, golds = list(zip(*datum))
+            else:
+                entry = datum
+                golds = [None for _ in range(len(entry))]
             # 5-1. Inference with local model
             with torch.no_grad():
                 parsed_entries = self.model.parse_template(entry, mode='gen')
@@ -124,8 +135,12 @@ class GenInferencer(BaseInferencer):
                 generated = results
 
             # 5-3. Save current output
-            for prompt, prediction in zip(parsed_entries, generated):
-                output_handler.save_results(prompt, prediction, index)
+            for prompt, prediction, gold in zip(parsed_entries, generated,
+                                                golds):
+                output_handler.save_results(prompt,
+                                            prediction,
+                                            index,
+                                            gold=gold)
                 index = index + 1
 
             # 5-4. Save intermediate results
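One assumption behind the zip(*datum) unpacking in GenInferencer (and the other gen-style loops): each batch coming out of the dataloader must still be a plain list of (prompt, gold) pairs. PyTorch's default collate would instead transpose a batch of tuples, so the pattern implies a pass-through collate function somewhere upstream; the sketch below uses one explicitly (the collate_fn shown is an assumption, not taken from the diff):

from torch.utils.data import DataLoader

paired = [('prompt A', 'gold A'), ('prompt B', 'gold B'), ('prompt C', 'gold C')]

# Pass-through collate: each batch is the raw list of (prompt, gold) pairs.
dataloader = DataLoader(paired, batch_size=2, collate_fn=lambda batch: batch)

for datum in dataloader:
    entry, golds = list(zip(*datum))
    print(entry, golds)
# ('prompt A', 'prompt B') ('gold A', 'gold B')
# ('prompt C',) ('gold C',)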
@@ -200,7 +200,13 @@ class PPLInferencer(BaseInferencer):
             sub_predictions.append(labels[single_ppl.index(min(single_ppl))])
         output_handler.save_predictions(sub_predictions)
 
-        # 7. Output
+        # 7. Fetch gold answers if exist
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            golds = ds_reader.dataset['test'][ds_reader.output_column]
+            output_handler.save_golds(golds)
+
+        # 8. Output
         if self.is_main_process:
             os.makedirs(output_json_filepath, exist_ok=True)
             output_handler.write_to_json(output_json_filepath,
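For the PPL path the gold answers are not threaded through the batches at all: the label-wise perplexity loop runs as before, and the golds are fetched once at the end and merged in by save_golds before the JSON is written (the old "# 7. Output" step becomes step 8). The resulting record might look like the sketch below; the 'label: ...', 'PPL', and 'gold' keys come from the handler code above, while the 'prediction' key and all values are illustrative assumptions:

# Hypothetical PPL prediction record after the change.
ppl_record = {
    '0': {
        'label: A': {'prompt': '...prompt ending in A...', 'PPL': 3.1},
        'label: B': {'prompt': '...prompt ending in B...', 'PPL': 2.7},
        'prediction': 'B',   # label with the lowest PPL
        'gold': 'B',         # new: appended by save_golds
    },
}
print(ppl_record['0']['prediction'] == ppl_record['0']['gold'])  # True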
@@ -105,6 +105,12 @@ class SCInferencer(BaseInferencer):
             ice_template=ice_template,
             prompt_template=prompt_template)
 
+        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            prompt_list = list(zip(prompt_list, gold_ans))
+
         # Create tmp json file for saving intermediate results and future
         # resuming
         index = 0
@@ -121,7 +127,12 @@ class SCInferencer(BaseInferencer):
 
         # 5. Inference for prompts in each batch
         logger.info('Starting inference process...')
-        for entry in tqdm(dataloader, disable=not self.is_main_process):
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            if ds_reader.output_column:
+                entry, golds = list(zip(*datum))
+            else:
+                entry = datum
+                golds = [None for _ in range(len(entry))]
             # TODO: add more types of CoT method
             # 5-1. Inference sc_size times with local model
             with torch.no_grad():
@@ -137,8 +148,12 @@ class SCInferencer(BaseInferencer):
                 generated = sc_prediction
 
             # 5-3. Save current output
-            for prompt, prediction in zip(parsed_entries, generated):
-                output_handler.save_results(prompt, prediction, index)
+            for prompt, prediction, gold in zip(parsed_entries, generated,
+                                                golds):
+                output_handler.save_results(prompt,
+                                            prediction,
+                                            index,
+                                            gold=gold)
                 index = index + 1
 
             # 5-4. Save intermediate results
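In SCInferencer the saved prediction appears to be the per-example list of sc_size sampled answers (sc_prediction), while the gold stays a single value, so a record might look like the sketch below (illustrative values; key names follow save_results, and the majority-vote line is only an example of how such a record could be consumed):

# Hypothetical self-consistency record: several sampled answers, one gold.
sc_record = {
    '0': {
        'origin_prompt': 'Q: 17 + 25 = ?\nA:',
        'prediction': [' 42', ' 42', ' 41'],   # sc_size samples
        'gold': '42',
    },
}
votes = [p.strip() for p in sc_record['0']['prediction']]
print(max(set(votes), key=votes.count))  # majority vote -> '42'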
@@ -333,6 +333,12 @@ class ToTInferencer(GenInferencer):
             ice_template=ice_template,
             prompt_template=prompt_template)
 
+        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            prompt_list = list(zip(prompt_list, gold_ans))
+
         # Create tmp json file for saving intermediate results and future
         # resuming
         index = 0
@@ -349,15 +355,24 @@ class ToTInferencer(GenInferencer):
 
         # 5. Inference for prompts in each batch
         logger.info('Starting ToT inference process...')
-        for entries in tqdm(dataloader, disable=not self.is_main_process):
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            if ds_reader.output_column:
+                entries, golds = list(zip(*datum))
+            else:
+                entries = datum
+                golds = [None for _ in range(len(entries))]
             # 5-1. Inference with ToT and local model
             with torch.no_grad():
                 parsed_entries = self.model.parse_template(entries, mode='gen')
                 generated = [self.tot_solve(entry) for entry in entries]
 
             # 5-2. Save current output
-            for prompt, prediction in zip(parsed_entries, generated):
-                output_handler.save_results(prompt, prediction, index)
+            for prompt, prediction, gold in zip(parsed_entries, generated,
+                                                golds):
+                output_handler.save_results(prompt,
+                                            prediction,
+                                            index,
+                                            gold=gold)
                 index = index + 1
 
             # 5-3. Save intermediate results