diff --git a/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py b/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
index 9b266125..5369c9e3 100644
--- a/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
@@ -5,24 +5,22 @@ import os
 from functools import partial
 from typing import List, Optional
 
-import torch
 import torch.nn.functional as F
-from accelerate import Accelerator
 from tqdm import trange
 
 from opencompass.models import BaseModel
-from opencompass.openicl import PromptTemplate
-from opencompass.openicl.icl_inferencer.icl_base_inferencer import \
-    PPLInferencerOutputHandler
-from opencompass.openicl.icl_retriever import BaseRetriever
-from opencompass.openicl.utils.logging import get_logger
 from opencompass.registry import ICL_INFERENCERS
 
+from ..icl_prompt_template import PromptTemplate
+from ..icl_retriever import BaseRetriever
+from ..utils import get_logger
+from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
+
 logger = get_logger(__name__)
 
 
 @ICL_INFERENCERS.register_module()
-class CLPInferencer:
+class CLPInferencer(BaseInferencer):
     """Conditional log probability based In-context Learning Inferencer.
 
     Calculate the log probability of each choices according the logits.
@@ -42,8 +40,6 @@ class CLPInferencer:
         max_seq_len (:obj:`int`): Maximum number of tokenized words allowed
            by the LM.
         batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`
-        accelerator (:obj:`Accelerator`, optional): An instance of the
-            `Accelerator` class, used for multiprocessing.
         output_json_filepath (:obj:`str`, optional): File path for output
            `JSON` file.
         output_json_filename (:obj:`str`, optional): File name for output
@@ -57,29 +53,20 @@ class CLPInferencer:
             model: BaseModel,
             max_seq_len: Optional[int] = None,
             batch_size: Optional[int] = 1,
-            accelerator: Optional[Accelerator] = None,
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             fix_id_list: Optional[List[int]] = None,
             single_token: bool = True,
             **kwargs) -> None:
+        super().__init__(
+            model=model,
+            max_seq_len=max_seq_len,
+            batch_size=batch_size,
+            output_json_filename=output_json_filename,
+            output_json_filepath=output_json_filepath,
+            **kwargs,
+        )
 
-        self.model = model
-
-        self.accelerator = accelerator
-        self.is_main_process = (True if self.accelerator is None
-                                or self.accelerator.is_main_process else False)
-
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        if self.model is not None:
-            self.model.to(self.device)
-
-        self.max_seq_len = max_seq_len
-        self.batch_size = batch_size
-        self.output_json_filepath = output_json_filepath
-        self.output_json_filename = output_json_filename
-        if not os.path.exists(self.output_json_filepath):
-            os.makedirs(self.output_json_filepath)
         self.fix_id_list = fix_id_list
         # TODO: support multiple token
         assert single_token, 'Only support single token choice currently.'
@@ -111,8 +98,8 @@ class CLPInferencer:
         # 3. Generate in-context examples for testing inputs
         for idx in range(len(ice_idx_list)):
             ice.append(
-                retriever.generate_ice(ice_idx_list[idx],
-                                       ice_template=ice_template))
+                retriever.generate_ice(
+                    ice_idx_list[idx], ice_template=ice_template))
         output_handler.save_ice(ice)
 
         # 4. Collect prompts and calculate conditional log probs
@@ -129,10 +116,13 @@ class CLPInferencer:
                 ]
             except ValueError:
                 choice_ids = [self.model.tokenizer.encode(c) for c in choices]
-                if self.model.tokenizer.add_bos_token:
-                    choice_ids = [c[1:] for c in choice_ids]
-                if self.model.tokenizer.add_eos_token:
-                    choice_ids = [c[:-1] for c in choice_ids]
+                if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer':  # noqa
+                    choice_ids = [c[2:] for c in choice_ids]
+                else:
+                    if self.model.tokenizer.add_bos_token:
+                        choice_ids = [c[1:] for c in choice_ids]
+                    if self.model.tokenizer.add_eos_token:
+                        choice_ids = [c[:-1] for c in choice_ids]
             if isinstance(choice_ids[0], list):
                 # in case tokenizer returns list for single token
                 choice_ids = list(itertools.chain(*choice_ids))
@@ -175,10 +165,11 @@ class CLPInferencer:
                 choice_target_ids.append(prompt_token_num - 1)
 
             logger.info('Calculating conditional log probability for prompts.')
-            for idx in trange(0,
-                              len(prompt_list),
-                              self.batch_size,
-                              disable=not self.is_main_process):
+            for idx in trange(
+                    0,
+                    len(prompt_list),
+                    self.batch_size,
+                    disable=not self.is_main_process):
                 sub_prompt_list = prompt_list[idx:idx + self.batch_size]
                 sub_choice_target_ids = choice_target_ids[idx:idx +
                                                           self.batch_size]
@@ -209,10 +200,11 @@ class CLPInferencer:
                         choice_ids,
                         mask_length=None):
         # TODO: support multiple tokens
-        try:
+        if hasattr(self.model, 'generator'):
             outputs, _ = self.model.generator.get_logits(input_texts)
-        except AttributeError:
+        else:
             outputs, _ = self.model.get_logits(input_texts)
+
         shift_logits = outputs[..., :-1, :].contiguous()
 
         shift_logits = F.log_softmax(shift_logits, dim=-1)
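
The last two hunks are the behavioral core of the change: logits are log-softmaxed and the log probability of each single-token choice is read off at the position right after the prompt. The snippet below is a minimal standalone sketch of that scoring idea for readers who have not seen the full inferencer; the `pick_choice` helper, its argument names, and the fake tensors are illustrative assumptions, not part of the patched module.

# Sketch only: mirrors the log_softmax + gather-at-choice-token pattern from
# the diff above, with made-up names and shapes (not the OpenCompass API).
from typing import List

import torch
import torch.nn.functional as F


def pick_choice(logits: torch.Tensor, last_prompt_pos: int,
                choice_ids: List[int]) -> int:
    """Return the index of the best-scoring single-token choice.

    logits: [seq_len, vocab_size] logits for one prompt.
    last_prompt_pos: index of the final prompt token; the next-token
        distribution at this position scores each candidate choice token.
    choice_ids: one token id per candidate choice.
    """
    log_probs = F.log_softmax(logits, dim=-1)           # normalize over vocab
    next_token_log_probs = log_probs[last_prompt_pos]   # log P(next token | prompt)
    choice_scores = next_token_log_probs[choice_ids]    # log-prob of each choice
    return int(torch.argmax(choice_scores).item())


# Toy usage: a 5-token prompt over a 10-token vocabulary, two choices (ids 3, 7).
fake_logits = torch.randn(5, 10)
print(pick_choice(fake_logits, last_prompt_pos=4, choice_ids=[3, 7]))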