[Fix] fix clp inferencer (#44)

This commit is contained in:
Hubert 2023-07-11 14:54:39 +08:00 committed by GitHub
parent 50b658d234
commit c8f1d513b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,24 +5,22 @@ import os
from functools import partial
from typing import List, Optional
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from tqdm import trange
from opencompass.models import BaseModel
from opencompass.openicl import PromptTemplate
from opencompass.openicl.icl_inferencer.icl_base_inferencer import \
PPLInferencerOutputHandler
from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.openicl.utils.logging import get_logger
from opencompass.registry import ICL_INFERENCERS
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
logger = get_logger(__name__)
@ICL_INFERENCERS.register_module()
class CLPInferencer:
class CLPInferencer(BaseInferencer):
"""Conditional log probability based In-context Learning Inferencer.
Calculate the log probability of each choices according the logits.
@ -42,8 +40,6 @@ class CLPInferencer:
max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
the LM.
batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`
accelerator (:obj:`Accelerator`, optional): An instance of the
`Accelerator` class, used for multiprocessing.
output_json_filepath (:obj:`str`, optional): File path for output
`JSON` file.
output_json_filename (:obj:`str`, optional): File name for output
@ -57,29 +53,20 @@ class CLPInferencer:
model: BaseModel,
max_seq_len: Optional[int] = None,
batch_size: Optional[int] = 1,
accelerator: Optional[Accelerator] = None,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
fix_id_list: Optional[List[int]] = None,
single_token: bool = True,
**kwargs) -> None:
super().__init__(
model=model,
max_seq_len=max_seq_len,
batch_size=batch_size,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
**kwargs,
)
self.model = model
self.accelerator = accelerator
self.is_main_process = (True if self.accelerator is None
or self.accelerator.is_main_process else False)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
if self.model is not None:
self.model.to(self.device)
self.max_seq_len = max_seq_len
self.batch_size = batch_size
self.output_json_filepath = output_json_filepath
self.output_json_filename = output_json_filename
if not os.path.exists(self.output_json_filepath):
os.makedirs(self.output_json_filepath)
self.fix_id_list = fix_id_list
# TODO: support multiple token
assert single_token, 'Only support single token choice currently.'
@ -111,8 +98,8 @@ class CLPInferencer:
# 3. Generate in-context examples for testing inputs
for idx in range(len(ice_idx_list)):
ice.append(
retriever.generate_ice(ice_idx_list[idx],
ice_template=ice_template))
retriever.generate_ice(
ice_idx_list[idx], ice_template=ice_template))
output_handler.save_ice(ice)
# 4. Collect prompts and calculate conditional log probs
@ -129,6 +116,9 @@ class CLPInferencer:
]
except ValueError:
choice_ids = [self.model.tokenizer.encode(c) for c in choices]
if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer': # noqa
choice_ids = [c[2:] for c in choice_ids]
else:
if self.model.tokenizer.add_bos_token:
choice_ids = [c[1:] for c in choice_ids]
if self.model.tokenizer.add_eos_token:
@ -175,7 +165,8 @@ class CLPInferencer:
choice_target_ids.append(prompt_token_num - 1)
logger.info('Calculating conditional log probability for prompts.')
for idx in trange(0,
for idx in trange(
0,
len(prompt_list),
self.batch_size,
disable=not self.is_main_process):
@ -209,10 +200,11 @@ class CLPInferencer:
choice_ids,
mask_length=None):
# TODO: support multiple tokens
try:
if hasattr(self.model, 'generator'):
outputs, _ = self.model.generator.get_logits(input_texts)
except AttributeError:
else:
outputs, _ = self.model.get_logits(input_texts)
shift_logits = outputs[..., :-1, :].contiguous()
shift_logits = F.log_softmax(shift_logits, dim=-1)