[Fix] fix clp potential error and support bs>1 (#439)

* [Fix] fix clp potential error and support bs>1

* [Fix] fix clp potential error and support bs>1

* minor fix

* minor fix
This commit is contained in:
Hubert 2023-09-27 16:32:57 +08:00 committed by GitHub
parent 3bb3d330eb
commit d9f3e88dfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -119,7 +119,7 @@ class CLPInferencer(BaseInferencer):
if self.single_token:
index = 0
prompt_list = []
choice_target_ids = []
target_pos = []
# TODO: Hard code temperaily, need to modified here
choices = retriever.test_ds[0]['choices']
try:
@ -142,6 +142,13 @@ class CLPInferencer(BaseInferencer):
get_token_len = self.model.get_token_len
if hasattr(self.model.tokenizer, 'padding_side'):
# get padding_side for huggingface model
padding_side = self.model.tokenizer.padding_side
else:
# defaults to left for internal model
padding_side = 'left'
# prepare in context for each example and control the length
for idx in range(len(ice_idx_list)):
prompt = retriever.generate_prompt_for_generate_task(
@ -149,7 +156,7 @@ class CLPInferencer(BaseInferencer):
ice[idx],
ice_template=ice_template,
prompt_template=prompt_template)
prompt = self.model.parse_template(prompt, mode='ppl')
prompt = self.model.parse_template(prompt, mode='gen')
if self.max_seq_len is not None:
prompt_token_num = get_token_len(prompt)
# add one because additional token will be added in the end
@ -165,15 +172,19 @@ class CLPInferencer(BaseInferencer):
ice_template=ice_template,
prompt_template=prompt_template)
prompt_token_num = get_token_len(prompt)
# Add single token for prompt, this token can be any token
prompt += 'yes'
prompt_list.append(prompt)
# in case prompt token num reaches
# in case prompt token num reaches max
if self.max_seq_len is not None and \
prompt_token_num + 1 > self.max_seq_len:
prompt_token_num = self.max_seq_len - 1
# minus the bos token
choice_target_ids.append(prompt_token_num - 1)
# get the target position index
if padding_side == 'left':
# always the last position
target_pos.append(-1)
else:
# the last position of the original prompt
target_pos.append(prompt_token_num - 1)
# 4.1 Fetch and zip prompt & gold answer if output column exists
ds_reader = retriever.dataset_reader
@ -182,19 +193,36 @@ class CLPInferencer(BaseInferencer):
else:
gold_ans = [None] * len(prompt_list)
if hasattr(self.model, 'batch_padding'):
# get batch padding for huggingface model
batch_padding = self.model.batch_padding
else:
# defaults to False for internal model
batch_padding = False
logger.info('Calculating conditional log probability for prompts.')
for idx in trange(0,
len(prompt_list),
self.batch_size,
disable=not self.is_main_process):
# get batch data
sub_prompt_list = prompt_list[idx:idx + self.batch_size]
sub_golds = gold_ans[idx:idx + self.batch_size]
sub_choice_target_ids = choice_target_ids[idx:idx +
self.batch_size]
sub_res = self.__get_cond_prob(sub_prompt_list,
sub_choice_target_ids,
choice_ids)
sub_target_pos = target_pos[idx:idx + self.batch_size]
# get probability result
if batch_padding and self.batch_size > 1:
sub_res = self._get_cond_prob(sub_prompt_list,
sub_target_pos, choice_ids)
else:
sub_res = []
for prompt, position in zip(sub_prompt_list,
sub_target_pos):
sub_res.extend(
self._get_cond_prob([prompt], [position],
choice_ids))
# save all the result
for res, prompt, gold in zip(sub_res, sub_prompt_list,
sub_golds):
example_input = prompt.replace(ice[idx], '')
@ -217,22 +245,29 @@ class CLPInferencer(BaseInferencer):
for sample in output_handler.results_dict.values()
]
def __get_cond_prob(self,
input_texts: List[str],
sub_choice_target_ids,
choice_ids,
mask_length=None):
# TODO: support multiple tokens
def _get_cond_prob(self, input_texts: List[str], target_pos: List[int],
choice_ids: List[int]):
"""Get the condition probability of next token.
Args:
input_texts (List[str]): All the input prompt to be tested.
target_pos (List[int]): Target position of next token.
choice_ids (List[int]): Choice ids of target tokens.
"""
if hasattr(self.model, 'generator'):
outputs, _ = self.model.generator.get_logits(input_texts)
get_logits = self.model.generator.get_logits
else:
outputs, _ = self.model.get_logits(input_texts)
get_logits = self.model.get_logits
shift_logits = outputs[..., :-1, :].contiguous().float()
outputs, _ = get_logits(input_texts)
shift_logits = F.log_softmax(shift_logits, dim=-1)
# we want get the next token probability
# therefore no shift here
logits = outputs.contiguous().float()
logits = F.log_softmax(logits, dim=-1)
log_probs = []
for logits, target_ids in zip(shift_logits, sub_choice_target_ids):
for logit, target_ids in zip(logits, target_pos):
log_probs.append(
F.softmax(logits[target_ids, choice_ids], dim=-1).tolist())
F.softmax(logit[target_ids, choice_ids], dim=-1).tolist())
return log_probs