[Fix] fix clp potential error and support bs>1 (#439)
* [Fix] fix clp potential error and support bs>1
* [Fix] fix clp potential error and support bs>1
* minor fix
* minor fix
parent 3bb3d330eb
commit d9f3e88dfe
@@ -119,7 +119,7 @@ class CLPInferencer(BaseInferencer):
         if self.single_token:
             index = 0
             prompt_list = []
-            choice_target_ids = []
+            target_pos = []
             # TODO: Hard code temperaily, need to modified here
             choices = retriever.test_ds[0]['choices']
             try:
@@ -142,6 +142,13 @@ class CLPInferencer(BaseInferencer):
 
             get_token_len = self.model.get_token_len
 
+            if hasattr(self.model.tokenizer, 'padding_side'):
+                # get padding_side for huggingface model
+                padding_side = self.model.tokenizer.padding_side
+            else:
+                # defaults to left for internal model
+                padding_side = 'left'
+
             # prepare in context for each example and control the length
             for idx in range(len(ice_idx_list)):
                 prompt = retriever.generate_prompt_for_generate_task(
@@ -149,7 +156,7 @@ class CLPInferencer(BaseInferencer):
                     ice[idx],
                     ice_template=ice_template,
                     prompt_template=prompt_template)
-                prompt = self.model.parse_template(prompt, mode='ppl')
+                prompt = self.model.parse_template(prompt, mode='gen')
                 if self.max_seq_len is not None:
                     prompt_token_num = get_token_len(prompt)
                     # add one because additional token will be added in the end
@@ -165,15 +172,19 @@ class CLPInferencer(BaseInferencer):
                             ice_template=ice_template,
                             prompt_template=prompt_template)
                         prompt_token_num = get_token_len(prompt)
-                # Add single token for prompt, this token can be any token
-                prompt += 'yes'
                 prompt_list.append(prompt)
-                # in case prompt token num reaches
+                # in case prompt token num reaches max
                 if self.max_seq_len is not None and \
                         prompt_token_num + 1 > self.max_seq_len:
                     prompt_token_num = self.max_seq_len - 1
-                # minus the bos token
-                choice_target_ids.append(prompt_token_num - 1)
+
+                # get the target position index
+                if padding_side == 'left':
+                    # always the last position
+                    target_pos.append(-1)
+                else:
+                    # the last position of the original prompt
+                    target_pos.append(prompt_token_num - 1)
 
             # 4.1 Fetch and zip prompt & gold answer if output column exists
             ds_reader = retriever.dataset_reader
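The target_pos bookkeeping above hinges on how padding interacts with position indices: with left padding, the last real token of every prompt lands at the final position of the padded batch, so index -1 is always correct; with right padding it sits at the prompt's own unpadded length minus one. A minimal standalone sketch of that index arithmetic (the pad_batch helper below is hypothetical, for illustration only, and not part of OpenCompass):

from typing import List, Tuple


def pad_batch(token_lists: List[List[int]], pad_id: int,
              padding_side: str) -> Tuple[List[List[int]], List[int]]:
    """Pad token id lists to a common length and return, for each sequence,
    the index of its last real token (hypothetical helper, for illustration)."""
    max_len = max(len(toks) for toks in token_lists)
    padded, target_pos = [], []
    for toks in token_lists:
        pad = [pad_id] * (max_len - len(toks))
        if padding_side == 'left':
            padded.append(pad + toks)
            # the last position always holds the last real token
            target_pos.append(-1)
        else:
            padded.append(toks + pad)
            # the last real token sits just before the padding
            target_pos.append(len(toks) - 1)
    return padded, target_pos


# left padding:  ([[5, 6, 7], [0, 5, 6]], [-1, -1])
# right padding: ([[5, 6, 7], [5, 6, 0]], [2, 1])
print(pad_batch([[5, 6, 7], [5, 6]], pad_id=0, padding_side='left'))
print(pad_batch([[5, 6, 7], [5, 6]], pad_id=0, padding_side='right'))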
@@ -182,19 +193,36 @@ class CLPInferencer(BaseInferencer):
             else:
                 gold_ans = [None] * len(prompt_list)
 
+            if hasattr(self.model, 'batch_padding'):
+                # get batch padding for huggingface model
+                batch_padding = self.model.batch_padding
+            else:
+                # defaults to False for internal model
+                batch_padding = False
+
             logger.info('Calculating conditional log probability for prompts.')
             for idx in trange(0,
                               len(prompt_list),
                               self.batch_size,
                               disable=not self.is_main_process):
+                # get batch data
                 sub_prompt_list = prompt_list[idx:idx + self.batch_size]
                 sub_golds = gold_ans[idx:idx + self.batch_size]
-                sub_choice_target_ids = choice_target_ids[idx:idx +
-                                                          self.batch_size]
-                sub_res = self.__get_cond_prob(sub_prompt_list,
-                                               sub_choice_target_ids,
-                                               choice_ids)
+                sub_target_pos = target_pos[idx:idx + self.batch_size]
+
+                # get probability result
+                if batch_padding and self.batch_size > 1:
+                    sub_res = self._get_cond_prob(sub_prompt_list,
+                                                  sub_target_pos, choice_ids)
+                else:
+                    sub_res = []
+                    for prompt, position in zip(sub_prompt_list,
+                                                sub_target_pos):
+                        sub_res.extend(
+                            self._get_cond_prob([prompt], [position],
+                                                choice_ids))
 
+                # save all the result
                 for res, prompt, gold in zip(sub_res, sub_prompt_list,
                                              sub_golds):
                     example_input = prompt.replace(ice[idx], '')
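The batch loop in this hunk is essentially a dispatch: score a whole slice in one padded forward pass only when the backend supports batch padding and the batch size is greater than one, otherwise fall back to per-prompt calls so single-sample backends keep their old behaviour. A rough sketch of that control flow, with score_fn standing in for _get_cond_prob (the helper and its signature here are illustrative assumptions, not the OpenCompass API):

from typing import Callable, List, Sequence


def score_in_batches(prompts: Sequence[str], positions: Sequence[int],
                     choice_ids: List[int],
                     score_fn: Callable[[List[str], List[int], List[int]],
                                        List[List[float]]],
                     batch_size: int,
                     batch_padding: bool) -> List[List[float]]:
    """Slice the prompts into batches; use batched scoring only when the
    backend can pad a batch, otherwise score one prompt at a time."""
    results: List[List[float]] = []
    for start in range(0, len(prompts), batch_size):
        sub_prompts = list(prompts[start:start + batch_size])
        sub_pos = list(positions[start:start + batch_size])
        if batch_padding and batch_size > 1:
            # one forward pass over the whole padded batch
            results.extend(score_fn(sub_prompts, sub_pos, choice_ids))
        else:
            # per-prompt fallback keeps behaviour identical for bs=1 backends
            for prompt, pos in zip(sub_prompts, sub_pos):
                results.extend(score_fn([prompt], [pos], choice_ids))
    return results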
@@ -217,22 +245,29 @@ class CLPInferencer(BaseInferencer):
                 for sample in output_handler.results_dict.values()
             ]
 
-    def __get_cond_prob(self,
-                        input_texts: List[str],
-                        sub_choice_target_ids,
-                        choice_ids,
-                        mask_length=None):
-        # TODO: support multiple tokens
+    def _get_cond_prob(self, input_texts: List[str], target_pos: List[int],
+                       choice_ids: List[int]):
+        """Get the condition probability of next token.
+
+        Args:
+            input_texts (List[str]): All the input prompt to be tested.
+            target_pos (List[int]): Target position of next token.
+            choice_ids (List[int]): Choice ids of target tokens.
+        """
         if hasattr(self.model, 'generator'):
-            outputs, _ = self.model.generator.get_logits(input_texts)
+            get_logits = self.model.generator.get_logits
         else:
-            outputs, _ = self.model.get_logits(input_texts)
+            get_logits = self.model.get_logits
 
-        shift_logits = outputs[..., :-1, :].contiguous().float()
+        outputs, _ = get_logits(input_texts)
 
-        shift_logits = F.log_softmax(shift_logits, dim=-1)
+        # we want get the next token probability
+        # therefore no shift here
+        logits = outputs.contiguous().float()
+
+        logits = F.log_softmax(logits, dim=-1)
         log_probs = []
-        for logits, target_ids in zip(shift_logits, sub_choice_target_ids):
+        for logit, target_ids in zip(logits, target_pos):
             log_probs.append(
-                F.softmax(logits[target_ids, choice_ids], dim=-1).tolist())
+                F.softmax(logit[target_ids, choice_ids], dim=-1).tolist())
         return log_probs
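For reference, the probability extraction in _get_cond_prob reduces to: take the logits at each example's target position, normalise over the vocabulary with log_softmax, then renormalise just the candidate choice token ids against each other with a softmax. A self-contained sketch with plain torch tensors (shapes and variable names are assumed for illustration, not taken from the repo):

from typing import List

import torch
import torch.nn.functional as F


def choice_probs(logits: torch.Tensor, target_pos: List[int],
                 choice_ids: List[int]) -> List[List[float]]:
    """logits: [batch, seq_len, vocab]; returns, per example, a probability
    distribution over the candidate choice tokens at the target position."""
    # normalise over the full vocabulary first
    log_probs = F.log_softmax(logits.float(), dim=-1)
    results = []
    for example, pos in zip(log_probs, target_pos):
        # log-probabilities of each candidate token at the target position
        choice_scores = example[pos, choice_ids]
        # renormalise across the choices only
        results.append(F.softmax(choice_scores, dim=-1).tolist())
    return results


# usage: batch of 2 sequences, seq_len 4, vocab size 10, two candidate tokens
dummy_logits = torch.randn(2, 4, 10)
print(choice_probs(dummy_logits, target_pos=[-1, 2], choice_ids=[3, 7]))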