mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
89 lines · 3.4 KiB · Python
from typing import Dict, Iterable, List, Optional, Union

import numpy as np
import torch.distributed as dist

from opencompass.models.base import BaseModel
from opencompass.models.base_api import APITemplateParser
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


class LLaMA2AccessoryModel(BaseModel):
    """LLaMA2-Accessory model wrapper.

    Project: https://github.com/Alpha-VLLM/LLaMA2-Accessory

    Args:
        tokenizer_only (bool): whether to load the tokenizer only
        meta_template (dict): meta template for the model
        additional_stop_symbols (Iterable[str]): additional symbols that
            mark the end of generation, e.g. the "###" symbol for
            separating turns in the chat template.
        from_pretrained_kwargs: kwargs that will be passed to
            `accessory.MetaModel.from_pretrained` for model instantiation.
    """

    def __init__(self,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None,
                 additional_stop_symbols: Iterable[str] = (),
                 **from_pretrained_kwargs):
        if tokenizer_only:
            self._load_tokenizer(from_pretrained_kwargs)
        else:
            self._load_model(from_pretrained_kwargs)

        self.additional_stop_symbols = additional_stop_symbols
        self.max_seq_len = from_pretrained_kwargs.get('max_seq_len', 4096)
        self.template_parser = APITemplateParser(meta_template)
        self.logger = get_logger()

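    # Building the full model assumes the LLaMA2-Accessory package
    # (`accessory`) is installed; the distributed process group is
    # initialized lazily, and the whole WORLD group is used as the
    # model-parallel group.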
    def _load_model(self, from_pretrained_kwargs):
        from accessory.model.meta import MetaModel
        from accessory.util.misc import init_distributed_mode
        if not dist.is_initialized():
            init_distributed_mode()

        model_parallel_group = dist.GroupMember.WORLD
        from_pretrained_kwargs['mp_group'] = model_parallel_group

        self.model = MetaModel.from_pretrained(**from_pretrained_kwargs)
        self.tokenizer = self.model.tokenizer
        self.logger = get_logger()

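    # In tokenizer-only mode, the tokenizer path is taken from
    # `tokenizer_path` if given, and otherwise probed from the last entry
    # of `pretrained_path`.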
    def _load_tokenizer(self, from_pretrained_kwargs):
        from accessory.model.tokenizer import (
            Tokenizer, probe_tokenizer_path_from_pretrained)
        if 'tokenizer_path' in from_pretrained_kwargs:
            tokenizer_path = from_pretrained_kwargs['tokenizer_path']
        else:
            pretrained_path = from_pretrained_kwargs['pretrained_path']
            if isinstance(pretrained_path, str):
                pretrained_path = [pretrained_path]
            tokenizer_path = probe_tokenizer_path_from_pretrained(
                pretrained_path[-1])

        self.tokenizer = Tokenizer(tokenizer_path)

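    # Decoding uses temperature 0, i.e. effectively greedy decoding, and
    # honors the configured additional stop symbols.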
    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        results = self.model.generate(
            prompts=inputs,
            max_gen_len=max_out_len,
            temperature=0.,
            additional_stop_symbols=self.additional_stop_symbols)
        return results

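    # Perplexity is delegated to LLaMA2-Accessory's own example evaluator;
    # per-token masking via `mask_length` is not supported there.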
    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None):
        assert mask_length is None, 'mask_length is not supported'
        evaluation_results = self.model.evaluate_examples(examples=inputs)
        ppl = evaluation_results['ppl']
        return np.array(ppl, dtype=np.float32)

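    # The reported token length includes the BOS and EOS tokens (the two
    # boolean flags passed to `tokenizer.encode`).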
    def get_token_len(self, prompt: str) -> int:
        return len(self.tokenizer.encode(prompt, True, True))
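

# A minimal usage sketch, not part of the original module: the checkpoint
# path below is hypothetical, and running this requires the `accessory`
# package plus a LLaMA2-Accessory checkpoint on disk.
if __name__ == '__main__':
    model = LLaMA2AccessoryModel(
        additional_stop_symbols=('###', ),
        # hypothetical checkpoint directory; forwarded (with the other
        # kwargs) to `accessory.MetaModel.from_pretrained`
        pretrained_path='path/to/llama2-accessory-checkpoint',
        max_seq_len=4096)
    print(model.generate(['The capital of France is'], max_out_len=32))
    print(model.get_token_len('The capital of France is'))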