from abc import abstractmethod
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union

from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


class BaseModel:
    """Base class for model wrapper.

    Args:
        path (str): The path to the model.
        max_seq_len (int): The maximum sequence length of the model. Defaults
            to 2048.
        tokenizer_only (bool): If True, only the tokenizer will be
            initialized. Defaults to False.
        meta_template (Dict, optional): The model's meta prompt template, used
            to inject or wrap prompts with model-specific meta instructions
            when required.
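
    Example:
        A minimal meta template (a sketch; the role names and the special
        begin/end strings below are illustrative, not mandated by this
        class):

            meta_template = dict(
                round=[
                    dict(role='HUMAN', begin='<|User|>:', end='<eoh>'),
                    dict(role='BOT', begin='<|Bot|>:', end='<eom>',
                         generate=True),
                ],
                eos_token_id=2,
            )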
    """

    is_api: bool = False

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None):
        self.path = path
        self.max_seq_len = max_seq_len
        self.tokenizer_only = tokenizer_only

        # meta template
        self.template_parser = LMTemplateParser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']

    @abstractmethod
    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """

    @abstractmethod
    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer are
                not needed.

        Returns:
            List[float]: A list of perplexity scores.
        """

    @abstractmethod
    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized strings.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens.
        """

    def parse_template(self, prompt_template: PromptType, mode: str) -> str:
        """Parse a prompt template, and wrap it with meta template if
        applicable.

        Args:
            prompt_template (List[str or PromptList]): A prompt template
                (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            str: The final string.
        """
        return self.template_parser.parse_template(prompt_template, mode)

    def get_ppl_from_template(self,
                              templates: List[PromptType],
                              mask_length=None):
        """Get perplexity given a list of templates.

        Args:
            templates (List[PromptType]): A list of templates.
            mask_length (List[int]): A list of mask lengths. If provided,
                the perplexity will be calculated only on the unmasked
                tokens.

        Returns:
            List[float]: A list of perplexity scores.
        """
        inputs = self.parse_template(templates, mode='ppl')
        return self.get_ppl(inputs, mask_length)

    def generate_from_template(self, templates: List[PromptType],
                               max_out_len: int, **kwargs):
        """Generate completion from a list of templates.

        Args:
            templates (List[PromptType]): A list of templates.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        inputs = self.parse_template(templates, mode='gen')
        return self.generate(inputs, max_out_len=max_out_len, **kwargs)

    def get_token_len_from_template(
            self,
            templates: Union[PromptType, List[PromptType]],
            mode: str = 'ppl') -> Union[List[int], int]:
        """Get lengths given a list of templates.

        Args:
            templates (Union[PromptType, List[PromptType]]): Input
                template(s).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            Union[List[int], int]: Length(s) of the input tokens. If the
            input is a list, a list of lengths will be returned. Otherwise,
            an int will be returned.
        """
        prompts = self.parse_template(templates, mode=mode)
        assert isinstance(prompts, (list, str)), 'tokens must be list or str'
        # A plain list is a batch of prompts; a PromptList is a single
        # (multi-part) prompt.
        is_batched = isinstance(
            prompts, list) and not isinstance(prompts, PromptList)
        if not is_batched:
            prompts = [prompts]
        prompts = [str(prompt) for prompt in prompts]
        token_lens = [self.get_token_len(prompt) for prompt in prompts]
        return token_lens[0] if not is_batched else token_lens

    def to(self, device):
        # Assumes the concrete subclass has set ``self.model``; BaseModel
        # itself does not create one.
        self.model.to(device)


class LMTemplateParser:
    """Intermediate prompt template parser, specifically for language models.

    Args:
        meta_template (Dict): The meta template for the model.
    """

    def __init__(self, meta_template: Optional[Dict] = None):
        self.meta_template = meta_template
        if meta_template:
            assert 'round' in meta_template, 'round is required in meta' \
                ' template'
            assert isinstance(meta_template['round'], list)
            keys_to_check = ['round']

            if 'reserved_roles' in meta_template:
                assert isinstance(meta_template['reserved_roles'], list)
                keys_to_check.append('reserved_roles')

            self.roles: Dict[str, dict] = dict()  # maps role name to config
            for meta_key in keys_to_check:
                for item in meta_template[meta_key]:
                    assert isinstance(item, (str, dict))
                    if isinstance(item, dict):
                        assert item['role'] not in self.roles, \
                            'role in meta prompt must be unique!'
                        self.roles[item['role']] = item.copy()
                        # convert list of string and int into a raw string
                        # for the ease of future prompt processing
                        for key in ['begin', 'end']:
                            value = self.roles[item['role']].get(key, '')
                            if isinstance(value, list):
                                self.roles[item['role']][
                                    key] = self._encode_special_tokens(value)

    def parse_template(self, prompt_template: PromptType, mode: str) -> str:
        """Parse a prompt template, and wrap it with meta template if
        applicable.

        Args:
            prompt_template (List[str or PromptList]): A prompt template
                (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            str: The final string.
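
        Example:
            A sketch of a typical input: a PromptList mixing role dicts and
            section markers (the prompts themselves are illustrative):

                template = PromptList([
                    dict(section='round', pos='begin'),
                    dict(role='HUMAN', prompt='1+1=?'),
                    dict(role='BOT', prompt='2'),
                    dict(section='round', pos='end'),
                ])
                parser.parse_template(template, mode='gen')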
""" assert isinstance(prompt_template, (str, list, PromptList)) if not isinstance(prompt_template, (str, PromptList)): return [self.parse_template(p, mode=mode) for p in prompt_template] assert mode in ['ppl', 'gen'] if isinstance(prompt_template, str): return prompt_template if self.meta_template: prompt = '' # Whether to keep generating the prompt generate = True section_stack = [] # stores tuples: (section_name, start_idx) for i, item in enumerate(prompt_template): if not generate: break if isinstance(item, str): prompt += item elif isinstance(item, dict) and 'section' in item: if item['pos'] == 'end': section_name, start_idx = section_stack.pop(-1) assert section_name == item['section'] if section_name in ['round', 'ice']: dialogue = prompt_template[start_idx:i] round_ranges = self._split_rounds( dialogue, self.meta_template['round']) # Consider inserting multiple round examples into # template for i in range(len(round_ranges) - 1): start = round_ranges[i] end = round_ranges[i + 1] round_template = dialogue[start:end] role_dict = self._update_role_dict( round_template) new_str, generate = self._prompt2str( self.meta_template['round'], role_dict, # Start generating only when the mode is in # generation and the template reaches the # last round for_gen=mode == 'gen' and i == len(round_ranges) - 2 and section_name == 'round') prompt += new_str elif item['pos'] == 'begin': assert item['section'] in [ 'begin', 'round', 'end', 'ice' ] section_stack.append((item['section'], i + 1)) else: raise ValueError(f'Invalid pos {item["pos"]}') # if in "begin" or "end" section elif section_stack[-1][0] in ['begin', 'end']: role_dict = self._update_role_dict(item) new_str, generate = self._prompt2str( item, role_dict, # never stop generation for_gen=False) prompt += new_str prompt = self.meta_template.get('begin', '') + prompt if generate: prompt += self.meta_template.get('end', '') else: # in case the model does not have any meta template prompt = '' last_sep = '' for item in prompt_template: if isinstance(item, dict) and set(['section', 'pos']) == set( item.keys()): continue if isinstance(item, str): if item: prompt += last_sep + item elif item.get('prompt', ''): # it's a dict prompt += last_sep + item.get('prompt', '') last_sep = '\n' return prompt def _split_rounds( self, prompt_template: List[Union[str, Dict]], single_round_template: List[Union[str, Dict]]) -> List[int]: """Split the prompt template into rounds, based on single round template. Return the index ranges of each round. Specifically, prompt_template[res[i]:res[i+1]] represents the i-th round in the template. 
""" role_idxs = { role_cfg['role']: i for i, role_cfg in enumerate(single_round_template) if not isinstance(role_cfg, str) } last_role_idx = -1 cutoff_idxs = [0] for idx, template in enumerate(prompt_template): if isinstance(template, str): continue role_idx = role_idxs[template['role']] if role_idx <= last_role_idx: cutoff_idxs.append(idx) last_role_idx = role_idx cutoff_idxs.append(len(prompt_template)) return cutoff_idxs def _update_role_dict(self, prompt: Union[List, str, Dict]) -> Dict[str, Dict]: """Update the default role dict with the given prompt(s).""" assert isinstance(prompt, (str, list, dict)) role_dict = deepcopy(self.roles) if isinstance(prompt, str): return role_dict if isinstance(prompt, dict): prompt = [prompt] for p in prompt: if isinstance(p, dict): role = p['role'] if role not in self.roles: role = p.get('fallback_role', None) if not role: print(f'{p} neither has an appropriate role nor ' 'a fallback role.') role_dict[role].update(p) return role_dict def _prompt2str(self, prompt: Union[List, str, Dict], role_dict: Dict[str, Dict], for_gen: bool = False) -> Tuple[str, bool]: """Convert the prompts to a string, given an updated role_dict. Args: prompts (Union[List, str, dict]): The prompt(s) to be converted. role_dict (Dict[str, Dict]): The updated role dict. for_gen (bool): If True, the prompts will be converted for generation tasks. The conversion stops before the first role whose "generate" is set to True. Returns: Tuple[str, bool]: The converted string, and whether the follow-up conversion should be proceeded. """ assert isinstance(prompt, (list, str, dict)) if isinstance(prompt, str): return prompt, True if isinstance(prompt, dict): return self._role2str(prompt, role_dict, for_gen) res = '' for p in prompt: new_str, cont = self._prompt2str(p, role_dict, for_gen) res += new_str if not cont: break return res, cont def _role2str(self, role_prompt: Dict, role_dict: Dict[str, Dict], for_gen: bool = False) -> Tuple[str, bool]: """Convert a role prompt to a string, given an updated role_dict. Args: role_prompt (Dict): The role prompt to be converted. role_dict (Dict[str, Dict]): The updated role dict. for_gen (bool): If True, the prompts will be converted for generation tasks. The conversion stops before the first role whose "generate" is set to True. Returns: Tuple[str, bool]: The converted string, and whether the follow-up conversion should be proceeded. """ merged_prompt = role_dict.get( role_prompt['role'], role_dict.get(role_prompt.get('fallback_role'))) res = merged_prompt.get('begin', '') if for_gen and merged_prompt.get('generate', False): return res, False # res += merged_prompt.get('prompt', '') + merged_prompt.get('end', '') res += merged_prompt.get('prompt', '') + merged_prompt.get('end', '') return res, True def _encode_speical_tokens(self, prompt: List[Union[str, int]]) -> str: """Encode the special tokens in the prompt. Now this is left for the future work """ raise NotImplementedError('Using List[str|int] is as the begin or end' 'of a prompt is not supported yet.') res = '' for item in prompt: if isinstance(item, str): res += item else: res += f'' return res