From 7199acc25dace594455acd163f79807a5e2ea929 Mon Sep 17 00:00:00 2001
From: Yuan Feng
Date: Tue, 21 Nov 2023 17:51:30 +0800
Subject: [PATCH] Add support for DataCanvas Alaya LM (#612)

* Support for Alaya

* Remove useless requirements
---
 configs/eval_alaya.py          |  11 +++
 configs/models/alaya/alaya.py  |  19 ++++
 opencompass/models/__init__.py |   1 +
 opencompass/models/alaya.py    | 159 +++++++++++++++++++++++++++++++++
 requirements/runtime.txt       |   1 +
 5 files changed, 191 insertions(+)
 create mode 100644 configs/eval_alaya.py
 create mode 100644 configs/models/alaya/alaya.py
 create mode 100644 opencompass/models/alaya.py

diff --git a/configs/eval_alaya.py b/configs/eval_alaya.py
new file mode 100644
index 00000000..f9e6cc44
--- /dev/null
+++ b/configs/eval_alaya.py
@@ -0,0 +1,11 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .datasets.ceval.ceval_gen import ceval_datasets
+    from .datasets.cmmlu.cmmlu_gen import cmmlu_datasets
+    from .datasets.agieval.agieval_gen import agieval_datasets
+    from .datasets.bbh.bbh_gen import bbh_datasets
+    from .datasets.mmlu.mmlu_gen import mmlu_datasets
+    from .models.alaya.alaya import models
+
+datasets = [*bbh_datasets, *ceval_datasets, *cmmlu_datasets, *agieval_datasets, *mmlu_datasets]
diff --git a/configs/models/alaya/alaya.py b/configs/models/alaya/alaya.py
new file mode 100644
index 00000000..70f43bb3
--- /dev/null
+++ b/configs/models/alaya/alaya.py
@@ -0,0 +1,19 @@
+from opencompass.models import AlayaLM
+
+
+models = [
+    dict(
+        type=AlayaLM,
+        abbr='alaya-7b-hf',
+        path='DataCanvas/Alaya-7B-Base',
+        tokenizer_path='DataCanvas/Alaya-7B-Base',
+        tokenizer_kwargs=dict(padding_side='left',
+                              truncation_side='left',
+                              trust_remote_code=True,
+                              use_fast=False),
+        max_out_len=100,
+        max_seq_len=2048,
+        batch_size=8,
+        model_kwargs=dict(device_map='auto', trust_remote_code=True),
+        run_cfg=dict(num_gpus=1, num_procs=1))
+]
diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index b727e72c..313162ed 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -1,3 +1,4 @@
+from .alaya import AlayaLM  # noqa: F401
 from .base import BaseModel, LMTemplateParser  # noqa
 from .base_api import APITemplateParser, BaseAPIModel  # noqa
 from .claude_api import Claude  # noqa: F401
diff --git a/opencompass/models/alaya.py b/opencompass/models/alaya.py
new file mode 100644
index 00000000..d25108f6
--- /dev/null
+++ b/opencompass/models/alaya.py
@@ -0,0 +1,159 @@
+from typing import Dict, List, Optional, Union
+
+import torch
+from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
+                          pipeline)
+
+from opencompass.utils.prompt import PromptList
+
+from .base import BaseModel, LMTemplateParser
+
+PromptType = Union[PromptList, str]
+
+
+class AlayaLM(BaseModel):
+    """Model wrapper for the Alaya model.
+
+    Args:
+        path (str): The name or path of the Alaya model, either a local
+            path or a Hugging Face model tag.
+        max_seq_len (int): The maximum length of the input sequence.
+            Defaults to 2048.
+        tokenizer_only (bool): If True, only the tokenizer is initialized.
+            Defaults to False.
+        meta_template (Dict, optional): The model's meta prompt template,
+            used when meta instructions need to be injected or wrapped.
+
+    Note:
+        Alaya requires some arguments to stay fixed, such as eos_token_id
+        and bad_words_ids.
+        The model config is loaded from the model's own config file.
+        Triton is supported to accelerate inference.
+        This class supports both the Alaya Base and Alaya Chat models.
+    """
+
+    def __init__(self,
+                 path: str,
+                 max_seq_len: int = 2048,
+                 tokenizer_only: bool = False,
+                 meta_template: Optional[Dict] = None,
+                 **kwargs):
+
+        self.template_parser = LMTemplateParser(meta_template)
+
+        self.max_seq_len = max_seq_len
+        self.tokenizer_only = tokenizer_only
+        self.meta_template = meta_template
+
+        self.name = path
+        self.eos_token_id = 2
+        self.bad_words_ids = 3
+
+        self.gpu_id = '0'
+
+        # local_files_only requires the weights to be cached locally.
+        self.config = AutoConfig.from_pretrained(self.name,
+                                                 trust_remote_code=True,
+                                                 local_files_only=True)
+        self.config.attn_config['attn_impl'] = 'triton'
+        self.config.init_device = 'cuda:' + self.gpu_id
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.name,
+            config=self.config,
+            torch_dtype=torch.bfloat16,  # Load model weights in bfloat16
+            trust_remote_code=True,
+        )
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.name,
+                                                       local_files_only=True,
+                                                       padding_side='left')
+
+        self.pipe = pipeline('text-generation',
+                             model=self.model,
+                             tokenizer=self.tokenizer,
+                             bad_words_ids=[[self.bad_words_ids]],
+                             eos_token_id=self.eos_token_id,
+                             pad_token_id=self.eos_token_id,
+                             device='cuda:' + self.gpu_id)
+
+    def do_inference(self, instruction, history=[]):
+        PROMPT_FORMAT = '### Instruction:\t\n{instruction}\n\n'
+        OUTPUT_FORMAT = '### Output:\t\n{output} '
+
+        prompt = PROMPT_FORMAT.format(instruction=instruction)
+
+        history2llm = []
+
+        for i, msg in enumerate(history):
+            if i % 2 == 0:  # user
+                msg2llm = PROMPT_FORMAT.format(instruction=msg)
+            else:  # alaya
+                msg2llm = OUTPUT_FORMAT.format(output=msg)
+            history2llm.append(msg2llm)
+
+        flag = '### Output:\t\n'
+        prompt2LLM = ''.join(history2llm) + prompt
+
+        if len(prompt2LLM) >= 1500:
+            prompt2LLM = prompt2LLM[-1500:]
+
+        result = self.pipe(prompt2LLM,
+                           max_new_tokens=100,
+                           max_length=1900,
+                           do_sample=True,
+                           use_cache=True,
+                           eos_token_id=self.eos_token_id,
+                           pad_token_id=self.eos_token_id)
+
+        try:
+            # Drop the echoed prompt, then a leading '### Output:' marker
+            # if the model produced one.
+            output = result[0]['generated_text'][len(prompt2LLM):]
+            if output.startswith(flag):
+                output = output[len(flag):]
+        except Exception:
+            output = result[0]['generated_text']
+        return output
+
+    def generate(
+        self,
+        inputs,
+        max_out_len: int = 1000,
+    ) -> List[str]:
+        """Generate results given a list of inputs."""
+        outputs = []
+        for instruction in inputs:
+            output = self.do_inference(instruction=instruction)
+            outputs.append(output)
+        return outputs
+
+    def get_token_len(self, prompt: str) -> int:
+        """Get the length of the tokenized string."""
+        return len(self.tokenizer.encode(prompt))
+
+    def get_ppl(self,
+                inputs: List[str],
+                mask_length: Optional[List[int]] = None) -> List[float]:
+        """Adapted from .huggingface.py."""
+
+        assert mask_length is None, 'mask_length is not supported'
+        bsz = len(inputs)
+        # tokenize
+        prompt_tokens = [self.tokenizer.encode(x) for x in inputs]
+        max_prompt_size = max(len(t) for t in prompt_tokens)
+        total_len = min(self.max_seq_len, max_prompt_size)
+        tokens = torch.zeros((bsz, total_len)).cuda().long()
+        for k, t in enumerate(prompt_tokens):
+            num_token = min(total_len, len(t))
+            tokens[k, :num_token] = torch.tensor(t[-num_token:]).long()
+        # forward
+        logits = self.model(tokens).logits
+        # compute ppl
+        shift_logits = logits[..., :-1, :].contiguous().float()
+        shift_labels = tokens[..., 1:].contiguous()
+        shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+        shift_labels = shift_labels.view(-1)
+        loss_fct = torch.nn.CrossEntropyLoss(reduction='none',
+                                             ignore_index=0)
+        loss = loss_fct(shift_logits, shift_labels).view(bsz, -1)
+        lens = (tokens != 0).sum(-1).cpu().numpy()
+        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
+        return ce_loss
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 3ea57737..3e905ab6 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -5,6 +5,7 @@ cn2an
 colossalai
 cpm_kernels
 datasets>=2.12.0
+einops==0.5.0
 evaluate>=0.3.0
 fairscale
 fuzzywuzzy
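
For reference, the new wrapper can be exercised outside the full evaluation pipeline roughly as follows. This is a minimal sketch, not part of the patch: it assumes the Alaya weights referenced in configs/models/alaya/alaya.py are already cached locally (the wrapper loads with local_files_only=True) and that a CUDA device is available; the example prompt is illustrative only.

    from opencompass.models import AlayaLM

    # Mirror the settings used in configs/models/alaya/alaya.py.
    model = AlayaLM(path='DataCanvas/Alaya-7B-Base', max_seq_len=2048)

    # generate() loops over the prompts and calls do_inference() on each one,
    # which wraps the prompt in the '### Instruction:' template defined above.
    outputs = model.generate(['Briefly introduce the DataCanvas Alaya model.'])
    print(outputs[0])

The bundled configs/eval_alaya.py can likewise be run through the usual OpenCompass entry point (for example, python run.py configs/eval_alaya.py), which evaluates the model on the BBH, C-Eval, CMMLU, AGIEval, and MMLU dataset groups imported there.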