diff --git a/configs/eval_internLM.py b/configs/eval_internLM.py
new file mode 100644
index 00000000..c6fa425c
--- /dev/null
+++ b/configs/eval_internLM.py
@@ -0,0 +1,9 @@
+from mmengine.config import read_base
+
+with read_base():
+    # choose a list of datasets
+    from .datasets.collections.base_medium import datasets
+    # choose a model of interest
+    from .models.internlm_7b import models
+    # and output the results in a chosen format
+    from .summarizers.medium import summarizer
\ No newline at end of file
diff --git a/configs/models/internlm_7b.py b/configs/models/internlm_7b.py
new file mode 100644
index 00000000..5d358be9
--- /dev/null
+++ b/configs/models/internlm_7b.py
@@ -0,0 +1,14 @@
+from opencompass.models import InternLM
+
+
+models = [
+    dict(
+        type=InternLM,
+        path="./internData/",
+        tokenizer_path='./internData/V7.model',
+        model_config="./internData/model_config.py",
+        max_out_len=100,
+        max_seq_len=2048,
+        batch_size=16,
+        run_cfg=dict(num_gpus=1, num_procs=1))
+]
diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index 6df976be..1f404cf7 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -3,5 +3,6 @@ from .base_api import APITemplateParser, BaseAPIModel  # noqa
 from .glm import GLM130B  # noqa: F401, F403
 from .huggingface import HuggingFace  # noqa: F401, F403
 from .huggingface import HuggingFaceCausalLM  # noqa: F401, F403
+from .intern_model import InternLM  # noqa: F401, F403
 from .llama2 import Llama2Chat  # noqa: F401, F403
 from .openai_api import OpenAI  # noqa: F401
diff --git a/opencompass/models/intern_model.py b/opencompass/models/intern_model.py
new file mode 100644
index 00000000..824db008
--- /dev/null
+++ b/opencompass/models/intern_model.py
@@ -0,0 +1,127 @@
+from typing import Dict, List, Optional
+
+import numpy as np
+import torch
+
+from opencompass.models.base import BaseModel, LMTemplateParser
+
+
+class InternLM(BaseModel):
+
+    def __init__(self,
+                 path: str,
+                 max_seq_len: int = 2048,
+                 tokenizer_only: bool = False,
+                 tokenizer_path: Optional[str] = None,
+                 model_config: Optional[str] = None,
+                 tokenizer_type: Optional[str] = 'v7',
+                 meta_template: Optional[Dict] = None):
+        if tokenizer_only:
+            self._load_tokenizer(tokenizer_path=tokenizer_path,
+                                 tokenizer_type=tokenizer_type,
+                                 max_seq_len=max_seq_len)
+        else:
+            self._load_model(path=path,
+                             max_seq_len=max_seq_len,
+                             tokenizer_path=tokenizer_path,
+                             tokenizer_type=tokenizer_type,
+                             model_config=model_config)
+        self.template_parser = LMTemplateParser(meta_template)
+        self.eos_token_id = None
+        if meta_template and 'eos_token_id' in meta_template:
+            self.eos_token_id = meta_template['eos_token_id']
+
+    def _load_model(self,
+                    path: str,
+                    max_seq_len: int,
+                    tokenizer_path: Optional[str] = None,
+                    tokenizer_type: Optional[str] = None,
+                    model_config: Optional[str] = None):
+
+        from internlm.load.load_model import load_llm
+        from internlm.model import build_model_with_cfg
+
+        self.model, self.tokenizer, self.generator, _ = load_llm(
+            path,
+            max_seq_len,
+            tokenizer_path=tokenizer_path,
+            tokenizer_type=tokenizer_type,
+            module=build_model_with_cfg,
+            model_config_path=model_config)
+
+    def _load_tokenizer(self, tokenizer_path: str, tokenizer_type: str,
+                        max_seq_len: int):
+        from internlm.load.tokenizer import LLMTokenizer
+        from sentencepiece import SentencePieceProcessor
+        tokenizer = SentencePieceProcessor()
+        tokenizer.load(tokenizer_path)
+        tokenizer = LLMTokenizer(tokenizer,
+                                 max_seq_len=max_seq_len,
+                                 tokenizer_type=tokenizer_type)
+        self.tokenizer = tokenizer
+
+    def get_token_len(self, prompt: str) -> int:
+        """Get lengths of the tokenized strings.
+
+        Args:
+            prompt (str): Input string.
+
+        Returns:
+            int: Length of the input tokens
+        """
+        tokens = self.tokenizer([prompt], truncation=False)['tokens']
+        return len(tokens[0])
+
+    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str]): A list of strings.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        return self.generator.generate(inputs,
+                                       generation_kwargs={
+                                           'max_gen_len': max_out_len,
+                                           'eos_token_id': self.eos_token_id
+                                       })
+
+    def get_ppl(self,
+                input_texts: List[str],
+                mask_length: Optional[List[int]] = None) -> List[float]:
+        """Get perplexity scores given a list of inputs.
+
+        Args:
+            input_texts (List[str]): A list of strings.
+            mask_length (Optional[List[int]]): A list of mask lengths. If
+                provided, the perplexity scores will be calculated with the
+                first mask_length[i] tokens masked out.
+
+        Returns:
+            List[float]: A list of perplexity scores.
+        """
+        outputs, inputs = self.generator.get_logits(input_texts)
+
+        shift_logits = outputs[..., :-1, :].contiguous()
+        shift_labels = inputs['tokens'][..., 1:].contiguous()
+
+        loss_fct = torch.nn.CrossEntropyLoss(
+            reduction='none', ignore_index=self.tokenizer.pad_token_id)
+        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                        shift_labels.view(-1)).view(shift_labels.size())
+
+        if mask_length is not None:
+            mask = torch.zeros_like(shift_labels)  # [batch, seqlen]
+            for i in range(len(mask)):
+                for j in range(mask_length[i] - 1, len(mask[i])):
+                    mask[i][j] = 1
+            loss = loss * mask
+
+        lens = (inputs['tokens'] !=
+                self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
+        if mask_length is not None:
+            lens -= np.array(mask_length)
+        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
+        return ce_loss
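
Not part of the diff: a minimal standalone sketch of the loss arithmetic that `get_ppl` performs, with dummy tensors standing in for `self.generator.get_logits(...)` and an assumed `pad_token_id` of 0 (the real value comes from the InternLM tokenizer); the optional `mask_length` handling is omitted.

```python
# Sketch of the shifted cross-entropy computed by InternLM.get_ppl,
# using dummy tensors; pad_token_id = 0 is an assumption for illustration.
import torch

pad_token_id = 0
batch, seq_len, vocab = 2, 6, 32

logits = torch.randn(batch, seq_len, vocab)          # model output logits
tokens = torch.randint(1, vocab, (batch, seq_len))   # tokenized inputs
tokens[1, 4:] = pad_token_id                         # second sample is shorter

# Predict token t+1 from position t: drop the last logit and the first label.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = tokens[..., 1:].contiguous()

loss_fct = torch.nn.CrossEntropyLoss(reduction='none',
                                     ignore_index=pad_token_id)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)).view(shift_labels.size())

# Normalise the summed token losses by each sample's unpadded length,
# mirroring the `lens` computation in get_ppl (mask_length omitted here).
lens = (tokens != pad_token_id).sum(-1)
ce_loss = loss.sum(-1) / lens
print(ce_loss)  # one average cross-entropy per input; lower = more likely
```

Each returned value is a per-sample average token cross-entropy, which the PPL-style evaluators use to rank answer candidates, so only relative magnitudes between candidates matter.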