Add support for DataCanvas Alaya LM (#612)

* Support for Alaya

* Remove useless requirements
Yuan Feng 2023-11-21 17:51:30 +08:00 committed by GitHub
parent dbacd36379
commit 7199acc25d
5 changed files with 191 additions and 0 deletions

configs/eval_alaya.py Normal file

@@ -0,0 +1,11 @@
from mmengine.config import read_base
with read_base():
from .datasets.ceval.ceval_gen import ceval_datasets
from .datasets.cmmlu.cmmlu_gen import cmmlu_datasets
from .datasets.agieval.agieval_gen import agieval_datasets
from .datasets.bbh.bbh_gen import bbh_datasets
from .datasets.mmlu.mmlu_gen import mmlu_datasets
from .models.alaya.alaya import models
datasets = [*bbh_datasets, *ceval_datasets, *cmmlu_datasets, *agieval_datasets, *mmlu_datasets]
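
This config can then be launched with the standard OpenCompass entry point, e.g. python run.py configs/eval_alaya.py; exact launch flags depend on the local setup.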

configs/models/alaya/alaya.py Normal file

@@ -0,0 +1,19 @@
from opencompass.models import AlayaLM
models = [
dict(
type=AlayaLM,
abbr='alaya-7b-hf',
path="DataCanvas/Alaya-7B-Base",
tokenizer_path='DataCanvas/Alaya-7B-Base',
tokenizer_kwargs=dict(padding_side='left',
truncation_side='left',
trust_remote_code=True,
use_fast=False,),
max_out_len=100,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(device_map='auto', trust_remote_code=True),
run_cfg=dict(num_gpus=1, num_procs=1))
]
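
Here type points at the AlayaLM wrapper added below in opencompass/models/alaya.py; broadly speaking, keys such as path and max_seq_len end up in its constructor, while entries like abbr, batch_size and run_cfg are consumed by the OpenCompass runner rather than by the model class itself.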

opencompass/models/__init__.py

@@ -1,3 +1,4 @@
from .alaya import AlayaLM # noqa: F401
from .base import BaseModel, LMTemplateParser # noqa
from .base_api import APITemplateParser, BaseAPIModel # noqa
from .claude_api import Claude # noqa: F401

opencompass/models/alaya.py Normal file

@@ -0,0 +1,159 @@
from typing import Dict, List, Optional, Union
import torch
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
pipeline)
from opencompass.utils.prompt import PromptList
from .base import BaseModel, LMTemplateParser
PromptType = Union[PromptList, str]
class AlayaLM(BaseModel):
"""Model wrapper for Alaya model.
Args:
path (str): The name or path to Alaya model, could be a local path
or a Huggingface model tag of Alaya.
max_seq_len (int): The maximum length of the input sequence. Defaults
to 2048.
tokenizer_only (bool): If True, only the tokenizer will be initialized.
Defaults to False.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case meta instructions need to be
            injected into or wrapped around the prompt.

    Note:
        Some of Alaya's generation arguments, such as eos_token_id and
        bad_words_ids, are fixed by this wrapper.
        The model config is loaded from the model's config file.
        Triton attention is used to accelerate inference.
        Both the Alaya Base and the Alaya Chat models are supported.
"""
def __init__(self,
path: str,
max_seq_len: int = 2048,
tokenizer_only: bool = False,
meta_template: Optional[Dict] = None,
**kwargs):
self.template_parser = LMTemplateParser(meta_template)
self.max_seq_len = max_seq_len
self.tokenizer_only = tokenizer_only
self.meta_template = meta_template
self.name = path
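        # Generation settings fixed for Alaya: token id 2 doubles as the
        # EOS/PAD token, token id 3 is suppressed via bad_words_ids, and
        # inference defaults to GPU 0.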
self.eos_token_id = 2
self.bad_words_ids = 3
self.gpu_id = '0'
        self.config = AutoConfig.from_pretrained(self.name,
                                                 trust_remote_code=True,
                                                 local_files_only=True)
self.config.attn_config['attn_impl'] = 'triton'
self.config.init_device = 'cuda:' + self.gpu_id
self.model = AutoModelForCausalLM.from_pretrained(
self.name,
config=self.config,
torch_dtype=torch.bfloat16, # Load model weights in bfloat16
trust_remote_code=True,
)
        self.tokenizer = AutoTokenizer.from_pretrained(self.name,
                                                       local_files_only=True,
                                                       padding_side='left')
self.pipe = pipeline('text-generation',
model=self.model,
tokenizer=self.tokenizer,
bad_words_ids=[[self.bad_words_ids]],
eos_token_id=self.eos_token_id,
pad_token_id=self.eos_token_id,
device='cuda:' + self.gpu_id)
def do_inference(self, instruction, history=[]):
PROMPT_FORMAT = '### Instruction:\t\n{instruction}\n\n'
OUTPUT_FORMAT = '### Output:\t\n{output} </s>'
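        # A single-turn prompt therefore looks like
        #     '### Instruction:\t\n<instruction>\n\n'
        # with earlier turns replayed in front of it as alternating
        # Instruction/Output blocks.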
prompt = PROMPT_FORMAT.format(instruction=instruction)
history2llm = []
for i, msg in enumerate(history):
if i % 2 == 0: # user
msg2llm = PROMPT_FORMAT.format(instruction=msg)
else: # alaya
msg2llm = OUTPUT_FORMAT.format(output=msg)
history2llm.append(msg2llm)
flag = '### Output:\t\n'
prompt2LLM = ''.join(history2llm) + prompt
if len(prompt2LLM) >= 1500:
prompt2LLM = prompt2LLM[-1500:]
result = self.pipe(prompt2LLM,
max_new_tokens=100,
max_length=1900,
do_sample=True,
use_cache=True,
eos_token_id=self.eos_token_id,
pad_token_id=self.eos_token_id)
        try:
            output = result[0]['generated_text'][len(prompt2LLM):]
            # str.lstrip(flag) would strip any of the flag's characters,
            # not the prefix itself, so remove the marker explicitly.
            if output.startswith(flag):
                output = output[len(flag):]
        except Exception:
            output = result[0]['generated_text']
return output
def generate(
self,
        inputs: List[str],
max_out_len: int = 1000,
) -> List[str]:
"""Generate results given a list of inputs."""
outputs = []
for instruction in inputs:
output = self.do_inference(instruction=instruction)
outputs.append(output)
return outputs
def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized string."""
return len(self.tokenizer.encode(prompt))
def get_ppl(self,
inputs: List[str],
mask_length: Optional[List[int]] = None) -> List[float]:
"""Copied from .huggingface.py."""
assert mask_length is None, 'mask_length is not supported'
bsz = len(inputs)
params = self.model.params
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
# tokenize
prompt_tokens = [self.tokenizer.encode(x, True, False) for x in inputs]
max_prompt_size = max([len(t) for t in prompt_tokens])
total_len = min(params.max_seq_len, max_prompt_size)
tokens = torch.zeros((bsz, total_len)).cuda().long()
for k, t in enumerate(prompt_tokens):
num_token = min(total_len, len(t))
tokens[k, :num_token] = torch.tensor(t[-num_token:]).long()
# forward
outputs = self.model.forward(tokens, 0)
# compute ppl
shift_logits = outputs[..., :-1, :].contiguous().float()
shift_labels = tokens[..., 1:].contiguous()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=0)
loss = loss_fct(shift_logits, shift_labels).view(bsz, -1)
lens = (tokens != 0).sum(-1).cpu().numpy()
ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
return ce_loss
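
For a quick sanity check outside the evaluation pipeline, the wrapper can also be exercised directly. A minimal sketch (assuming a CUDA device and that the Alaya weights are available locally or in the Huggingface cache; the prompt strings are illustrative only):

from opencompass.models import AlayaLM

# Construct the wrapper with the same model tag as the config above.
model = AlayaLM(path='DataCanvas/Alaya-7B-Base', max_seq_len=2048)

# Token length of a prompt, via the underlying tokenizer.
print(model.get_token_len('Hello, Alaya!'))

# Generate a completion for a single instruction (max_out_len is accepted
# for interface compatibility; do_inference caps new tokens internally).
print(model.generate(['Briefly introduce yourself.'], max_out_len=100)[0])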

(requirements file)

@@ -5,6 +5,7 @@ cn2an
colossalai
cpm_kernels
datasets>=2.12.0
einops==0.5.0
evaluate>=0.3.0
fairscale
fuzzywuzzy