mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Fix] fix clp potential error and support bs>1 (#439)
* [Fix] fix clp potential error and support bs>1 * [Fix] fix clp potential error and support bs>1 * minor fix * minor fix
This commit is contained in:
parent
3bb3d330eb
commit
d9f3e88dfe
@ -119,7 +119,7 @@ class CLPInferencer(BaseInferencer):
|
||||
if self.single_token:
|
||||
index = 0
|
||||
prompt_list = []
|
||||
choice_target_ids = []
|
||||
target_pos = []
|
||||
# TODO: Hard code temperaily, need to modified here
|
||||
choices = retriever.test_ds[0]['choices']
|
||||
try:
|
||||
@ -142,6 +142,13 @@ class CLPInferencer(BaseInferencer):
|
||||
|
||||
get_token_len = self.model.get_token_len
|
||||
|
||||
if hasattr(self.model.tokenizer, 'padding_side'):
|
||||
# get padding_side for huggingface model
|
||||
padding_side = self.model.tokenizer.padding_side
|
||||
else:
|
||||
# defaults to left for internal model
|
||||
padding_side = 'left'
|
||||
|
||||
# prepare in context for each example and control the length
|
||||
for idx in range(len(ice_idx_list)):
|
||||
prompt = retriever.generate_prompt_for_generate_task(
|
||||
@ -149,7 +156,7 @@ class CLPInferencer(BaseInferencer):
|
||||
ice[idx],
|
||||
ice_template=ice_template,
|
||||
prompt_template=prompt_template)
|
||||
prompt = self.model.parse_template(prompt, mode='ppl')
|
||||
prompt = self.model.parse_template(prompt, mode='gen')
|
||||
if self.max_seq_len is not None:
|
||||
prompt_token_num = get_token_len(prompt)
|
||||
# add one because additional token will be added in the end
|
||||
@ -165,15 +172,19 @@ class CLPInferencer(BaseInferencer):
|
||||
ice_template=ice_template,
|
||||
prompt_template=prompt_template)
|
||||
prompt_token_num = get_token_len(prompt)
|
||||
# Add single token for prompt, this token can be any token
|
||||
prompt += 'yes'
|
||||
prompt_list.append(prompt)
|
||||
# in case prompt token num reaches
|
||||
# in case prompt token num reaches max
|
||||
if self.max_seq_len is not None and \
|
||||
prompt_token_num + 1 > self.max_seq_len:
|
||||
prompt_token_num = self.max_seq_len - 1
|
||||
# minus the bos token
|
||||
choice_target_ids.append(prompt_token_num - 1)
|
||||
|
||||
# get the target position index
|
||||
if padding_side == 'left':
|
||||
# always the last position
|
||||
target_pos.append(-1)
|
||||
else:
|
||||
# the last position of the original prompt
|
||||
target_pos.append(prompt_token_num - 1)
|
||||
|
||||
# 4.1 Fetch and zip prompt & gold answer if output column exists
|
||||
ds_reader = retriever.dataset_reader
|
||||
@ -182,19 +193,36 @@ class CLPInferencer(BaseInferencer):
|
||||
else:
|
||||
gold_ans = [None] * len(prompt_list)
|
||||
|
||||
if hasattr(self.model, 'batch_padding'):
|
||||
# get batch padding for huggingface model
|
||||
batch_padding = self.model.batch_padding
|
||||
else:
|
||||
# defaults to False for internal model
|
||||
batch_padding = False
|
||||
|
||||
logger.info('Calculating conditional log probability for prompts.')
|
||||
for idx in trange(0,
|
||||
len(prompt_list),
|
||||
self.batch_size,
|
||||
disable=not self.is_main_process):
|
||||
# get batch data
|
||||
sub_prompt_list = prompt_list[idx:idx + self.batch_size]
|
||||
sub_golds = gold_ans[idx:idx + self.batch_size]
|
||||
sub_choice_target_ids = choice_target_ids[idx:idx +
|
||||
self.batch_size]
|
||||
sub_res = self.__get_cond_prob(sub_prompt_list,
|
||||
sub_choice_target_ids,
|
||||
choice_ids)
|
||||
sub_target_pos = target_pos[idx:idx + self.batch_size]
|
||||
|
||||
# get probability result
|
||||
if batch_padding and self.batch_size > 1:
|
||||
sub_res = self._get_cond_prob(sub_prompt_list,
|
||||
sub_target_pos, choice_ids)
|
||||
else:
|
||||
sub_res = []
|
||||
for prompt, position in zip(sub_prompt_list,
|
||||
sub_target_pos):
|
||||
sub_res.extend(
|
||||
self._get_cond_prob([prompt], [position],
|
||||
choice_ids))
|
||||
|
||||
# save all the result
|
||||
for res, prompt, gold in zip(sub_res, sub_prompt_list,
|
||||
sub_golds):
|
||||
example_input = prompt.replace(ice[idx], '')
|
||||
@ -217,22 +245,29 @@ class CLPInferencer(BaseInferencer):
|
||||
for sample in output_handler.results_dict.values()
|
||||
]
|
||||
|
||||
def __get_cond_prob(self,
|
||||
input_texts: List[str],
|
||||
sub_choice_target_ids,
|
||||
choice_ids,
|
||||
mask_length=None):
|
||||
# TODO: support multiple tokens
|
||||
def _get_cond_prob(self, input_texts: List[str], target_pos: List[int],
|
||||
choice_ids: List[int]):
|
||||
"""Get the condition probability of next token.
|
||||
|
||||
Args:
|
||||
input_texts (List[str]): All the input prompt to be tested.
|
||||
target_pos (List[int]): Target position of next token.
|
||||
choice_ids (List[int]): Choice ids of target tokens.
|
||||
"""
|
||||
if hasattr(self.model, 'generator'):
|
||||
outputs, _ = self.model.generator.get_logits(input_texts)
|
||||
get_logits = self.model.generator.get_logits
|
||||
else:
|
||||
outputs, _ = self.model.get_logits(input_texts)
|
||||
get_logits = self.model.get_logits
|
||||
|
||||
shift_logits = outputs[..., :-1, :].contiguous().float()
|
||||
outputs, _ = get_logits(input_texts)
|
||||
|
||||
shift_logits = F.log_softmax(shift_logits, dim=-1)
|
||||
# we want get the next token probability
|
||||
# therefore no shift here
|
||||
logits = outputs.contiguous().float()
|
||||
|
||||
logits = F.log_softmax(logits, dim=-1)
|
||||
log_probs = []
|
||||
for logits, target_ids in zip(shift_logits, sub_choice_target_ids):
|
||||
for logit, target_ids in zip(logits, target_pos):
|
||||
log_probs.append(
|
||||
F.softmax(logits[target_ids, choice_ids], dim=-1).tolist())
|
||||
F.softmax(logit[target_ids, choice_ids], dim=-1).tolist())
|
||||
return log_probs
|
||||
|
Loading…
Reference in New Issue
Block a user