mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Fix] fix clp inferencer (#44)
This commit is contained in:
parent
50b658d234
commit
c8f1d513b2
@ -5,24 +5,22 @@ import os
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from tqdm import trange
|
||||
|
||||
from opencompass.models import BaseModel
|
||||
from opencompass.openicl import PromptTemplate
|
||||
from opencompass.openicl.icl_inferencer.icl_base_inferencer import \
|
||||
PPLInferencerOutputHandler
|
||||
from opencompass.openicl.icl_retriever import BaseRetriever
|
||||
from opencompass.openicl.utils.logging import get_logger
|
||||
from opencompass.registry import ICL_INFERENCERS
|
||||
|
||||
from ..icl_prompt_template import PromptTemplate
|
||||
from ..icl_retriever import BaseRetriever
|
||||
from ..utils import get_logger
|
||||
from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ICL_INFERENCERS.register_module()
|
||||
class CLPInferencer:
|
||||
class CLPInferencer(BaseInferencer):
|
||||
"""Conditional log probability based In-context Learning Inferencer.
|
||||
|
||||
Calculate the log probability of each choices according the logits.
|
||||
@ -42,8 +40,6 @@ class CLPInferencer:
|
||||
max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
|
||||
the LM.
|
||||
batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`
|
||||
accelerator (:obj:`Accelerator`, optional): An instance of the
|
||||
`Accelerator` class, used for multiprocessing.
|
||||
output_json_filepath (:obj:`str`, optional): File path for output
|
||||
`JSON` file.
|
||||
output_json_filename (:obj:`str`, optional): File name for output
|
||||
@ -57,29 +53,20 @@ class CLPInferencer:
|
||||
model: BaseModel,
|
||||
max_seq_len: Optional[int] = None,
|
||||
batch_size: Optional[int] = 1,
|
||||
accelerator: Optional[Accelerator] = None,
|
||||
output_json_filepath: Optional[str] = './icl_inference_output',
|
||||
output_json_filename: Optional[str] = 'predictions',
|
||||
fix_id_list: Optional[List[int]] = None,
|
||||
single_token: bool = True,
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
model=model,
|
||||
max_seq_len=max_seq_len,
|
||||
batch_size=batch_size,
|
||||
output_json_filename=output_json_filename,
|
||||
output_json_filepath=output_json_filepath,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.model = model
|
||||
|
||||
self.accelerator = accelerator
|
||||
self.is_main_process = (True if self.accelerator is None
|
||||
or self.accelerator.is_main_process else False)
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
if self.model is not None:
|
||||
self.model.to(self.device)
|
||||
|
||||
self.max_seq_len = max_seq_len
|
||||
self.batch_size = batch_size
|
||||
self.output_json_filepath = output_json_filepath
|
||||
self.output_json_filename = output_json_filename
|
||||
if not os.path.exists(self.output_json_filepath):
|
||||
os.makedirs(self.output_json_filepath)
|
||||
self.fix_id_list = fix_id_list
|
||||
# TODO: support multiple token
|
||||
assert single_token, 'Only support single token choice currently.'
|
||||
@ -111,8 +98,8 @@ class CLPInferencer:
|
||||
# 3. Generate in-context examples for testing inputs
|
||||
for idx in range(len(ice_idx_list)):
|
||||
ice.append(
|
||||
retriever.generate_ice(ice_idx_list[idx],
|
||||
ice_template=ice_template))
|
||||
retriever.generate_ice(
|
||||
ice_idx_list[idx], ice_template=ice_template))
|
||||
output_handler.save_ice(ice)
|
||||
|
||||
# 4. Collect prompts and calculate conditional log probs
|
||||
@ -129,6 +116,9 @@ class CLPInferencer:
|
||||
]
|
||||
except ValueError:
|
||||
choice_ids = [self.model.tokenizer.encode(c) for c in choices]
|
||||
if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer': # noqa
|
||||
choice_ids = [c[2:] for c in choice_ids]
|
||||
else:
|
||||
if self.model.tokenizer.add_bos_token:
|
||||
choice_ids = [c[1:] for c in choice_ids]
|
||||
if self.model.tokenizer.add_eos_token:
|
||||
@ -175,7 +165,8 @@ class CLPInferencer:
|
||||
choice_target_ids.append(prompt_token_num - 1)
|
||||
|
||||
logger.info('Calculating conditional log probability for prompts.')
|
||||
for idx in trange(0,
|
||||
for idx in trange(
|
||||
0,
|
||||
len(prompt_list),
|
||||
self.batch_size,
|
||||
disable=not self.is_main_process):
|
||||
@ -209,10 +200,11 @@ class CLPInferencer:
|
||||
choice_ids,
|
||||
mask_length=None):
|
||||
# TODO: support multiple tokens
|
||||
try:
|
||||
if hasattr(self.model, 'generator'):
|
||||
outputs, _ = self.model.generator.get_logits(input_texts)
|
||||
except AttributeError:
|
||||
else:
|
||||
outputs, _ = self.model.get_logits(input_texts)
|
||||
|
||||
shift_logits = outputs[..., :-1, :].contiguous()
|
||||
|
||||
shift_logits = F.log_softmax(shift_logits, dim=-1)
|
||||
|
Loading…
Reference in New Issue
Block a user