Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Fix] fix clp inferencer (#44)
parent 50b658d234
commit c8f1d513b2
@@ -5,24 +5,22 @@ import os
 from functools import partial
 from typing import List, Optional

-import torch
 import torch.nn.functional as F
-from accelerate import Accelerator
 from tqdm import trange

 from opencompass.models import BaseModel
-from opencompass.openicl import PromptTemplate
-from opencompass.openicl.icl_inferencer.icl_base_inferencer import \
-    PPLInferencerOutputHandler
-from opencompass.openicl.icl_retriever import BaseRetriever
-from opencompass.openicl.utils.logging import get_logger
 from opencompass.registry import ICL_INFERENCERS

+from ..icl_prompt_template import PromptTemplate
+from ..icl_retriever import BaseRetriever
+from ..utils import get_logger
+from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
+
 logger = get_logger(__name__)


 @ICL_INFERENCERS.register_module()
-class CLPInferencer:
+class CLPInferencer(BaseInferencer):
     """Conditional log probability based In-context Learning Inferencer.

     Calculate the log probability of each choices according the logits.
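The class docstring kept above describes the core idea: each answer choice is scored by the log probability the model assigns to its (single) token. The snippet below is a minimal, self-contained sketch of that scoring step, not OpenCompass code; the logits tensor, choice ids and target position are all made up for illustration.

import torch
import torch.nn.functional as F

# Toy setup: one prompt, 5 positions, a 10-token vocabulary.
logits = torch.randn(1, 5, 10)

# Hypothetical single-token ids for four answer choices.
choice_ids = [3, 7, 1, 9]

# Position whose next-token distribution is used to score the choices.
target_pos = 4

log_probs = F.log_softmax(logits, dim=-1)             # normalise over the vocabulary
choice_scores = log_probs[0, target_pos, choice_ids]  # one log probability per choice
prediction = int(torch.argmax(choice_scores))         # index of the best-scoring choice
print(choice_scores.tolist(), prediction)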
@@ -42,8 +40,6 @@ class CLPInferencer:
         max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
             the LM.
         batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`
-        accelerator (:obj:`Accelerator`, optional): An instance of the
-            `Accelerator` class, used for multiprocessing.
         output_json_filepath (:obj:`str`, optional): File path for output
             `JSON` file.
         output_json_filename (:obj:`str`, optional): File name for output
@@ -57,29 +53,20 @@ class CLPInferencer:
             model: BaseModel,
             max_seq_len: Optional[int] = None,
             batch_size: Optional[int] = 1,
-            accelerator: Optional[Accelerator] = None,
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             fix_id_list: Optional[List[int]] = None,
             single_token: bool = True,
             **kwargs) -> None:
+        super().__init__(
+            model=model,
+            max_seq_len=max_seq_len,
+            batch_size=batch_size,
+            output_json_filename=output_json_filename,
+            output_json_filepath=output_json_filepath,
+            **kwargs,
+        )

-        self.model = model
-
-        self.accelerator = accelerator
-        self.is_main_process = (True if self.accelerator is None
-                                or self.accelerator.is_main_process else False)
-
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        if self.model is not None:
-            self.model.to(self.device)
-
-        self.max_seq_len = max_seq_len
-        self.batch_size = batch_size
-        self.output_json_filepath = output_json_filepath
-        self.output_json_filename = output_json_filename
-        if not os.path.exists(self.output_json_filepath):
-            os.makedirs(self.output_json_filepath)
         self.fix_id_list = fix_id_list
         # TODO: support multiple token
         assert single_token, 'Only support single token choice currently.'
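With the change above, the constructor no longer stores the model, picks a device, or creates the output directory itself; that bookkeeping is delegated to BaseInferencer through super().__init__(). The toy classes below only sketch the pattern; the real BaseInferencer lives in icl_base_inferencer.py and does more than this, and the attribute names are taken from the removed lines.

import os
from typing import Optional


class ToyBaseInferencer:
    """Stand-in for BaseInferencer, covering only the arguments in the diff."""

    def __init__(self,
                 model,
                 max_seq_len: Optional[int] = None,
                 batch_size: Optional[int] = 1,
                 output_json_filepath: str = './icl_inference_output',
                 output_json_filename: str = 'predictions',
                 **kwargs) -> None:
        self.model = model
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size
        self.output_json_filepath = output_json_filepath
        self.output_json_filename = output_json_filename
        if not os.path.exists(output_json_filepath):
            os.makedirs(output_json_filepath)


class ToyCLPInferencer(ToyBaseInferencer):
    """Subclass keeps only its own state, mirroring the new __init__."""

    def __init__(self, model, single_token: bool = True, **kwargs) -> None:
        super().__init__(model=model, **kwargs)
        assert single_token, 'Only support single token choice currently.'

Instantiating ToyCLPInferencer(model) still creates ./icl_inference_output as a side effect, only now it happens in the base constructor rather than in the removed lines.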
@@ -111,8 +98,8 @@ class CLPInferencer:
         # 3. Generate in-context examples for testing inputs
         for idx in range(len(ice_idx_list)):
             ice.append(
-                retriever.generate_ice(ice_idx_list[idx],
-                                       ice_template=ice_template))
+                retriever.generate_ice(
+                    ice_idx_list[idx], ice_template=ice_template))
         output_handler.save_ice(ice)

         # 4. Collect prompts and calculate conditional log probs
@@ -129,6 +116,9 @@ class CLPInferencer:
                 ]
             except ValueError:
                 choice_ids = [self.model.tokenizer.encode(c) for c in choices]
-                if self.model.tokenizer.add_bos_token:
-                    choice_ids = [c[1:] for c in choice_ids]
-                if self.model.tokenizer.add_eos_token:
+                if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer':  # noqa
+                    choice_ids = [c[2:] for c in choice_ids]
+                else:
+                    if self.model.tokenizer.add_bos_token:
+                        choice_ids = [c[1:] for c in choice_ids]
+                    if self.model.tokenizer.add_eos_token:
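The new branch special-cases ChatGLMTokenizer, whose encode() output is trimmed by two leading ids, while other tokenizers keep the existing add_bos_token/add_eos_token handling. The sketch below replays that stripping logic with two invented tokenizer classes and invented token ids; the [:-1] body of the add_eos_token branch is not shown in the hunk and is assumed here.

class FakeGLMStyleTokenizer:
    """Invented tokenizer that puts two special ids in front of the text ids."""

    def encode(self, text):
        return [900, 901] + [ord(c) for c in text]


class FakeBosTokenizer:
    """Invented tokenizer that only prepends a BOS id."""

    add_bos_token = True
    add_eos_token = False

    def encode(self, text):
        return [1] + [ord(c) for c in text]


def single_token_choice_ids(tokenizer, choices):
    """Mirror the stripping logic above so each choice keeps exactly one id."""
    choice_ids = [tokenizer.encode(c) for c in choices]
    if tokenizer.__class__.__name__ == 'FakeGLMStyleTokenizer':
        choice_ids = [c[2:] for c in choice_ids]
    else:
        if getattr(tokenizer, 'add_bos_token', False):
            choice_ids = [c[1:] for c in choice_ids]
        if getattr(tokenizer, 'add_eos_token', False):
            choice_ids = [c[:-1] for c in choice_ids]
    return choice_ids


print(single_token_choice_ids(FakeGLMStyleTokenizer(), ['A', 'B']))  # [[65], [66]]
print(single_token_choice_ids(FakeBosTokenizer(), ['A', 'B']))       # [[65], [66]]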
@@ -175,7 +165,8 @@ class CLPInferencer:
                 choice_target_ids.append(prompt_token_num - 1)

             logger.info('Calculating conditional log probability for prompts.')
-            for idx in trange(0,
+            for idx in trange(
+                    0,
                     len(prompt_list),
                     self.batch_size,
                     disable=not self.is_main_process):
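The reformatted loop walks the prompt list in steps of self.batch_size, with the progress bar disabled on non-main processes. A minimal standalone sketch of the same batching pattern, with dummy values in place of the inferencer's state:

from tqdm import trange

prompt_list = [f'prompt {i}' for i in range(10)]  # dummy prompts
batch_size = 4
is_main_process = True  # stands in for self.is_main_process in the diff above

for idx in trange(0, len(prompt_list), batch_size,
                  disable=not is_main_process):
    sub_batch = prompt_list[idx:idx + batch_size]  # batches of 4, 4, then 2
    # the real code builds model inputs from sub_batch here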
@@ -209,10 +200,11 @@ class CLPInferencer:
                         choice_ids,
                         mask_length=None):
         # TODO: support multiple tokens
-        try:
+        if hasattr(self.model, 'generator'):
             outputs, _ = self.model.generator.get_logits(input_texts)
-        except AttributeError:
+        else:
             outputs, _ = self.model.get_logits(input_texts)
+
         shift_logits = outputs[..., :-1, :].contiguous()

         shift_logits = F.log_softmax(shift_logits, dim=-1)
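Replacing try/except AttributeError with an explicit hasattr check narrows what is being tested: the old form would also have swallowed an AttributeError raised inside get_logits itself and silently fallen through to the other branch. A small standalone sketch of the dispatch, with invented stand-in model classes:

class FakeGenerator:
    """Invented inner object exposing get_logits, like model.generator above."""

    def get_logits(self, texts):
        return [len(t) for t in texts], None  # fake "logits" plus a second output


class WrappedModel:
    """Invented model wrapper that routes get_logits through .generator."""

    def __init__(self):
        self.generator = FakeGenerator()


class PlainModel:
    """Invented model that exposes get_logits directly."""

    def get_logits(self, texts):
        return [len(t) for t in texts], None


def fetch_logits(model, texts):
    # Dispatch on the presence of the attribute, as in the new code above.
    if hasattr(model, 'generator'):
        outputs, _ = model.generator.get_logits(texts)
    else:
        outputs, _ = model.get_logits(texts)
    return outputs


print(fetch_logits(WrappedModel(), ['a', 'bb']))  # [1, 2]
print(fetch_logits(PlainModel(), ['a', 'bb']))    # [1, 2]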