Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Support intern language model (#51)
* support internLM
* simplify intern model files
* update storage_manager
* Modify the file organization structure
* change some details
This commit is contained in:
parent 8a4d0867ab
commit 57fcfc975a
configs/eval_internLM.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from mmengine.config import read_base

with read_base():
    # choose a list of datasets
    from .datasets.collections.base_medium import datasets
    # choose a model of interest
    from .models.internlm_7b import models
    # and output the results in a chosen format
    from .summarizers.medium import summarizer
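For readers new to mmengine's lazy-import configs, here is a minimal sketch of loading and inspecting this file from the repository root. The key names come from the imports above; the values hinted at in the comments are illustrative assumptions, not checked against this revision.

# Minimal inspection sketch (assumes an opencompass checkout with mmengine installed,
# run from the repository root).
from mmengine.config import Config

cfg = Config.fromfile('configs/eval_internLM.py')

# read_base() resolves the lazy imports, so the three imported symbols
# become top-level keys of the parsed config.
print(sorted(cfg.keys()))        # expected to include: datasets, models, summarizer
print(len(cfg['datasets']))      # number of dataset configs pulled in by base_medium
print(cfg['models'][0]['type'])  # the InternLM wrapper configured in configs/models/internlm_7b.py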
configs/models/internlm_7b.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from opencompass.models import InternLM


models = [
    dict(
        type=InternLM,
        path="./internData/",
        tokenizer_path='./internData/V7.model',
        model_config="./internData/model_config.py",
        max_out_len=100,
        max_seq_len=2048,
        batch_size=16,
        run_cfg=dict(num_gpus=1, num_procs=1))
]
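A hedged sketch of how OpenCompass turns a dict like this into a live model: the `type` key selects the wrapper class, most remaining keys become constructor arguments, and a few (`run_cfg`, `max_out_len`, `batch_size`) are bookkeeping for the partitioner, runner and inference task rather than `__init__` parameters. The pop list below is an assumption about this revision, not a copy of the actual builder code.

# Hedged sketch of the kwargs flow only; the real OpenCompass builder does more bookkeeping.
from mmengine.config import Config

model_cfg = dict(Config.fromfile('configs/models/internlm_7b.py')['models'][0])
model_cls = model_cfg.pop('type')    # -> opencompass.models.InternLM

# Keys assumed to be consumed by the runner / inference task rather than by
# InternLM.__init__ (an assumption, not verbatim builder code).
for key in ('run_cfg', 'max_out_len', 'batch_size'):
    model_cfg.pop(key, None)

model = model_cls(**model_cfg)       # InternLM(path=..., tokenizer_path=..., model_config=...)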
opencompass/models/__init__.py
@@ -3,5 +3,6 @@ from .base_api import APITemplateParser, BaseAPIModel # noqa
 from .glm import GLM130B # noqa: F401, F403
 from .huggingface import HuggingFace # noqa: F401, F403
 from .huggingface import HuggingFaceCausalLM # noqa: F401, F403
+from .intern_model import InternLM # noqa: F401, F403
 from .llama2 import Llama2Chat # noqa: F401, F403
 from .openai_api import OpenAI # noqa: F401
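With this re-export in place, the import used at the top of configs/models/internlm_7b.py resolves to the new wrapper; a minimal check, assuming the package is importable in the current environment:

from opencompass.models import InternLM
from opencompass.models.intern_model import InternLM as _InternLM
assert InternLM is _InternLM  # the public re-export and the defining module agree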
opencompass/models/intern_model.py (new file, 127 lines)
@@ -0,0 +1,127 @@
from typing import Dict, List, Optional

import numpy as np
import torch

from opencompass.models.base import BaseModel, LMTemplateParser


class InternLM(BaseModel):

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 tokenizer_path: Optional[str] = None,
                 model_config: Optional[str] = None,
                 tokenizer_type: Optional[str] = 'v7',
                 meta_template: Optional[Dict] = None):
        if tokenizer_only:
            self._load_tokenizer(tokenizer_path=tokenizer_path,
                                 tokenizer_type=tokenizer_type,
                                 max_seq_len=max_seq_len)
        else:
            self._load_model(path=path,
                             max_seq_len=max_seq_len,
                             tokenizer_path=tokenizer_path,
                             tokenizer_type=tokenizer_type,
                             model_config=model_config)
        self.template_parser = LMTemplateParser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']

    def _load_model(self,
                    path: str,
                    max_seq_len: int,
                    tokenizer_path: Optional[str] = None,
                    tokenizer_type: Optional[str] = None,
                    model_config: Optional[str] = None):

        from internlm.load.load_model import load_llm
        from internlm.model import build_model_with_cfg

        self.model, self.tokenizer, self.generator, _ = load_llm(
            path,
            max_seq_len,
            tokenizer_path=tokenizer_path,
            tokenizer_type=tokenizer_type,
            module=build_model_with_cfg,
            model_config_path=model_config)

    def _load_tokenizer(self, tokenizer_path: str, tokenizer_type: str,
                        max_seq_len: int):
        from internlm.load.tokenizer import LLMTokenizer
        from sentencepiece import SentencePieceProcessor
        tokenizer = SentencePieceProcessor()
        tokenizer.load(tokenizer_path)
        tokenizer = LLMTokenizer(tokenizer,
                                 max_seq_len=max_seq_len,
                                 tokenizer_type=tokenizer_type)
        self.tokenizer = tokenizer

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized strings.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """
        tokens = self.tokenizer([prompt], truncation=False)['tokens']
        return len(tokens[0])

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        return self.generator.generate(inputs,
                                       generation_kwargs={
                                           'max_gen_len': max_out_len,
                                           'eos_token_id': self.eos_token_id
                                       })

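    # Perplexity helper below: logits at position t predict the token at
    # t + 1, so logits are shifted left and labels right before a per-token
    # cross entropy (reduction='none', padding ignored). When mask_length is
    # given, positions before mask_length[i] - 1 of the shifted sequence
    # (i.e. the prompt/context of sample i) are zeroed out of the loss, and
    # the prompt lengths are also subtracted from the per-sample token
    # counts, so the returned score reflects only the continuation.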
    def get_ppl(self,
                input_texts: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            input_texts (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out.

        Returns:
            List[float]: A list of perplexity scores.
        """
        outputs, inputs = self.generator.get_logits(input_texts)

        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = inputs['tokens'][..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(
            reduction='none', ignore_index=self.tokenizer.pad_token_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)).view(shift_labels.size())

        if mask_length is not None:
            mask = torch.zeros_like(shift_labels)  # [batch, seqlen]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask

        lens = (inputs['tokens'] !=
                self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
        if mask_length is not None:
            lens -= np.array(mask_length)
        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
        return ce_loss
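For completeness, a hedged sketch of how this wrapper is exercised once the internlm package and a checkpoint are available. The paths mirror configs/models/internlm_7b.py, and the outputs hinted at in the comments are illustrative, not captured from a real run.

from opencompass.models import InternLM

# Assumes the layout from configs/models/internlm_7b.py; adjust the paths to
# wherever the InternLM weights, tokenizer and model_config actually live.
model = InternLM(path='./internData/',
                 tokenizer_path='./internData/V7.model',
                 model_config='./internData/model_config.py',
                 max_seq_len=2048)

print(model.get_token_len('Hello, InternLM!'))              # token count for one prompt
print(model.generate(['The capital of France is'],
                     max_out_len=32))                       # list with one continuation
print(model.get_ppl(['Paris is the capital of France.']))   # per-sample loss scores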