mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
Add support for DataCanvas Alaya LM (#612)
* Support for Alaya * Remove useless requirements
This commit is contained in:
parent
dbacd36379
commit
7199acc25d
11
configs/eval_alaya.py
Normal file
11
configs/eval_alaya.py
Normal file
@ -0,0 +1,11 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .datasets.ceval.ceval_gen import ceval_datasets
|
||||
from .datasets.cmmlu.cmmlu_gen import cmmlu_datasets
|
||||
from .datasets.agieval.agieval_gen import agieval_datasets
|
||||
from .datasets.bbh.bbh_gen import bbh_datasets
|
||||
from .datasets.mmlu.mmlu_gen import mmlu_datasets
|
||||
from .models.alaya.alaya import models
|
||||
|
||||
datasets = [*bbh_datasets, *ceval_datasets, *cmmlu_datasets, *agieval_datasets, *mmlu_datasets]
|
19
configs/models/alaya/alaya.py
Normal file
19
configs/models/alaya/alaya.py
Normal file
@ -0,0 +1,19 @@
|
||||
from opencompass.models import AlayaLM
|
||||
|
||||
|
||||
models = [
|
||||
dict(
|
||||
type=AlayaLM,
|
||||
abbr='alaya-7b-hf',
|
||||
path="DataCanvas/Alaya-7B-Base",
|
||||
tokenizer_path='DataCanvas/Alaya-7B-Base',
|
||||
tokenizer_kwargs=dict(padding_side='left',
|
||||
truncation_side='left',
|
||||
trust_remote_code=True,
|
||||
use_fast=False,),
|
||||
max_out_len=100,
|
||||
max_seq_len=2048,
|
||||
batch_size=8,
|
||||
model_kwargs=dict(device_map='auto', trust_remote_code=True),
|
||||
run_cfg=dict(num_gpus=1, num_procs=1))
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from .alaya import AlayaLM # noqa: F401
|
||||
from .base import BaseModel, LMTemplateParser # noqa
|
||||
from .base_api import APITemplateParser, BaseAPIModel # noqa
|
||||
from .claude_api import Claude # noqa: F401
|
||||
|
159
opencompass/models/alaya.py
Normal file
159
opencompass/models/alaya.py
Normal file
@ -0,0 +1,159 @@
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
||||
pipeline)
|
||||
|
||||
from opencompass.utils.prompt import PromptList
|
||||
|
||||
from .base import BaseModel, LMTemplateParser
|
||||
|
||||
PromptType = Union[PromptList, str]
|
||||
|
||||
|
||||
class AlayaLM(BaseModel):
|
||||
"""Model wrapper for Alaya model.
|
||||
|
||||
Args:
|
||||
path (str): The name or path to Alaya model, could be a local path
|
||||
or a Huggingface model tag of Alaya.
|
||||
max_seq_len (int): The maximum length of the input sequence. Defaults
|
||||
to 2048.
|
||||
tokenizer_only (bool): If True, only the tokenizer will be initialized.
|
||||
Defaults to False.
|
||||
meta_template (Dict, optional): The model's meta prompt
|
||||
template if needed, in case the requirement of injecting or
|
||||
wrapping of any meta instructions.
|
||||
|
||||
Note:
|
||||
Alaya has some arguments which should be fixed such as
|
||||
eos_token_id and bad_words_ids.
|
||||
Model config should be loaded from a model config file.
|
||||
Triton is supported to accelerate the inference process.
|
||||
This class supports both Alaya Base model and Alaya Chat model.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
path: str,
|
||||
max_seq_len: int = 2048,
|
||||
tokenizer_only: bool = False,
|
||||
meta_template: Optional[Dict] = None,
|
||||
**kwargs):
|
||||
|
||||
self.template_parser = LMTemplateParser(meta_template)
|
||||
|
||||
self.max_seq_len = max_seq_len
|
||||
self.tokenizer_only = tokenizer_only
|
||||
self.meta_template = meta_template
|
||||
|
||||
self.name = path
|
||||
self.eos_token_id = 2
|
||||
self.bad_words_ids = 3
|
||||
|
||||
self.gpu_id = '0'
|
||||
|
||||
self.config = AutoConfig.from_pretrained(self.name,
|
||||
trust_remote_code=True,
|
||||
local_file_only=True)
|
||||
self.config.attn_config['attn_impl'] = 'triton'
|
||||
self.config.init_device = 'cuda:' + self.gpu_id
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.name,
|
||||
config=self.config,
|
||||
torch_dtype=torch.bfloat16, # Load model weights in bfloat16
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.name,
|
||||
local_file_only=True,
|
||||
padding_side='left')
|
||||
|
||||
self.pipe = pipeline('text-generation',
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
bad_words_ids=[[self.bad_words_ids]],
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.eos_token_id,
|
||||
device='cuda:' + self.gpu_id)
|
||||
|
||||
def do_inference(self, instruction, history=[]):
|
||||
PROMPT_FORMAT = '### Instruction:\t\n{instruction}\n\n'
|
||||
OUTPUT_FORMAT = '### Output:\t\n{output} </s>'
|
||||
|
||||
prompt = PROMPT_FORMAT.format(instruction=instruction)
|
||||
|
||||
history2llm = []
|
||||
|
||||
for i, msg in enumerate(history):
|
||||
if i % 2 == 0: # user
|
||||
msg2llm = PROMPT_FORMAT.format(instruction=msg)
|
||||
else: # alaya
|
||||
msg2llm = OUTPUT_FORMAT.format(output=msg)
|
||||
history2llm.append(msg2llm)
|
||||
|
||||
flag = '### Output:\t\n'
|
||||
prompt2LLM = ''.join(history2llm) + prompt
|
||||
|
||||
if len(prompt2LLM) >= 1500:
|
||||
prompt2LLM = prompt2LLM[-1500:]
|
||||
|
||||
result = self.pipe(prompt2LLM,
|
||||
max_new_tokens=100,
|
||||
max_length=1900,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.eos_token_id)
|
||||
|
||||
try:
|
||||
output = result[0]['generated_text'][len(prompt2LLM):].lstrip(flag)
|
||||
except Exception:
|
||||
output = result[0]['generated_text']
|
||||
return output
|
||||
|
||||
def generate(
|
||||
self,
|
||||
inputs,
|
||||
max_out_len: int = 1000,
|
||||
) -> List[str]:
|
||||
"""Generate results given a list of inputs."""
|
||||
outputs = []
|
||||
for instruction in inputs:
|
||||
output = self.do_inference(instruction=instruction)
|
||||
outputs.append(output)
|
||||
return outputs
|
||||
|
||||
def get_token_len(self, prompt: str) -> int:
|
||||
"""Get lengths of the tokenized string."""
|
||||
return len(self.tokenizer.encode(prompt))
|
||||
|
||||
def get_ppl(self,
|
||||
inputs: List[str],
|
||||
mask_length: Optional[List[int]] = None) -> List[float]:
|
||||
"""Copied from .huggingface.py."""
|
||||
|
||||
assert mask_length is None, 'mask_length is not supported'
|
||||
bsz = len(inputs)
|
||||
params = self.model.params
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
# tokenize
|
||||
prompt_tokens = [self.tokenizer.encode(x, True, False) for x in inputs]
|
||||
max_prompt_size = max([len(t) for t in prompt_tokens])
|
||||
total_len = min(params.max_seq_len, max_prompt_size)
|
||||
tokens = torch.zeros((bsz, total_len)).cuda().long()
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
num_token = min(total_len, len(t))
|
||||
tokens[k, :num_token] = torch.tensor(t[-num_token:]).long()
|
||||
# forward
|
||||
outputs = self.model.forward(tokens, 0)
|
||||
# compute ppl
|
||||
shift_logits = outputs[..., :-1, :].contiguous().float()
|
||||
shift_labels = tokens[..., 1:].contiguous()
|
||||
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
||||
shift_labels = shift_labels.view(-1)
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=0)
|
||||
loss = loss_fct(shift_logits, shift_labels).view(bsz, -1)
|
||||
lens = (tokens != 0).sum(-1).cpu().numpy()
|
||||
ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
|
||||
return ce_loss
|
@ -5,6 +5,7 @@ cn2an
|
||||
colossalai
|
||||
cpm_kernels
|
||||
datasets>=2.12.0
|
||||
einops==0.5.0
|
||||
evaluate>=0.3.0
|
||||
fairscale
|
||||
fuzzywuzzy
|
||||
|
Loading…
Reference in New Issue
Block a user