From eea8b04417e5ec00f13898fbc5ab008837dc5d02 Mon Sep 17 00:00:00 2001
From: Leymore
Date: Wed, 19 Jul 2023 19:51:29 +0800
Subject: [PATCH] [Feature] Add llama-2 models (#81)

* add llama-2 models

* update docs

---------

Co-authored-by: gaotongxiao
---
 README.md                         |   1 +
 README_zh-CN.md                   |   5 ++
 configs/models/hf_llama2_13b.py   |  21 ++++++
 configs/models/hf_llama2_70b.py   |  21 ++++++
 configs/models/hf_llama2_7b.py    |  21 ++++++
 configs/models/llama2_13b_chat.py |  31 ++++++++
 configs/models/llama2_70b_chat.py |  31 ++++++++
 configs/models/llama2_7b_chat.py  |  31 ++++++++
 docs/en/get_started.md            |  63 +++++++++++------
 docs/zh_cn/get_started.md         |  63 +++++++++++------
 opencompass/models/__init__.py    |   1 +
 opencompass/models/llama2.py      | 113 ++++++++++++++++++++++++++++++
 12 files changed, 358 insertions(+), 44 deletions(-)
 create mode 100644 configs/models/hf_llama2_13b.py
 create mode 100644 configs/models/hf_llama2_70b.py
 create mode 100644 configs/models/hf_llama2_7b.py
 create mode 100644 configs/models/llama2_13b_chat.py
 create mode 100644 configs/models/llama2_70b_chat.py
 create mode 100644 configs/models/llama2_7b_chat.py
 create mode 100644 opencompass/models/llama2.py

diff --git a/README.md b/README.md
index e1e662e4..70953b87 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ Just like a compass guides us on our journey, OpenCompass will guide you through
 
 ## News
 
+- **\[2023.07.19\]** We now support [Llama 2](https://ai.meta.com/llama/)! Its performance report will be available soon. \[[doc](./docs/en/get_started.md#Installation)\]
 - **\[2023.07.13\]** We release [MMBench](https://opencompass.org.cn/MMBench), a meticulously curated dataset to comprehensively evaluate different abilities of multimodality models 🔥🔥🔥.
 
 ## Introduction
diff --git a/README_zh-CN.md b/README_zh-CN.md
index f7656433..fae6e28f 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -25,6 +25,11 @@
 
 就像指南针在我们的旅程中为我们导航一样,我们希望OpenCompass能够帮助你穿越评估大型语言模型的重重迷雾。OpenCompass提供丰富的算法和功能支持,期待OpenCompass能够帮助社区更便捷地对NLP模型的性能进行公平全面的评估。
 
+## 更新
+
+- **\[2023.07.19\]** 新增了 [Llama 2](https://ai.meta.com/llama/)!我们近期将会公布其评测结果。\[[文档](./docs/zh_cn/get_started.md#安装)\]
+- **\[2023.07.13\]** 发布了 [MMBench](https://opencompass.org.cn/MMBench),该数据集经过细致整理,用于评测多模态模型全方位能力 🔥🔥🔥。
+
 ## 介绍
 
 OpenCompass 是面向大模型评测的一站式平台。其主要特点如下:
diff --git a/configs/models/hf_llama2_13b.py b/configs/models/hf_llama2_13b.py
new file mode 100644
index 00000000..4103c874
--- /dev/null
+++ b/configs/models/hf_llama2_13b.py
@@ -0,0 +1,21 @@
+from opencompass.models import HuggingFaceCausalLM
+
+
+models = [
+    dict(
+        type=HuggingFaceCausalLM,
+        abbr='llama-2-13b-hf',
+        path="meta-llama/Llama-2-13b-hf",
+        tokenizer_path='meta-llama/Llama-2-13b-hf',
+        tokenizer_kwargs=dict(padding_side='left',
+                              truncation_side='left',
+                              use_fast=False,
+                              ),
+        max_out_len=100,
+        max_seq_len=2048,
+        batch_size=8,
+        model_kwargs=dict(device_map='auto'),
+        batch_padding=False,  # if False, infer samples one by one without batch padding
+        run_cfg=dict(num_gpus=2, num_procs=1),
+    )
+]
diff --git a/configs/models/hf_llama2_70b.py b/configs/models/hf_llama2_70b.py
new file mode 100644
index 00000000..44078cf0
--- /dev/null
+++ b/configs/models/hf_llama2_70b.py
@@ -0,0 +1,21 @@
+from opencompass.models import HuggingFaceCausalLM
+
+
+models = [
+    dict(
+        type=HuggingFaceCausalLM,
+        abbr='llama-2-70b-hf',
+        path="meta-llama/Llama-2-70b-hf",
+        tokenizer_path='meta-llama/Llama-2-70b-hf',
+        tokenizer_kwargs=dict(padding_side='left',
+                              truncation_side='left',
+                              use_fast=False,
+                              ),
+        max_out_len=100,
+        max_seq_len=2048,
+        batch_size=8,
+        model_kwargs=dict(device_map='auto'),
+        batch_padding=False,  # if False, infer samples one by one without batch padding
+        run_cfg=dict(num_gpus=8, num_procs=1),
+    )
+]
diff --git a/configs/models/hf_llama2_7b.py b/configs/models/hf_llama2_7b.py
new file mode 100644
index 00000000..3d00990e
--- /dev/null
+++ b/configs/models/hf_llama2_7b.py
@@ -0,0 +1,21 @@
+from opencompass.models import HuggingFaceCausalLM
+
+
+models = [
+    dict(
+        type=HuggingFaceCausalLM,
+        abbr='llama-2-7b-hf',
+        path="meta-llama/Llama-2-7b-hf",
+        tokenizer_path='meta-llama/Llama-2-7b-hf',
+        tokenizer_kwargs=dict(padding_side='left',
+                              truncation_side='left',
+                              use_fast=False,
+                              ),
+        max_out_len=100,
+        max_seq_len=2048,
+        batch_size=8,
+        model_kwargs=dict(device_map='auto'),
+        batch_padding=False,  # if False, infer samples one by one without batch padding
+        run_cfg=dict(num_gpus=1, num_procs=1),
+    )
+]
diff --git a/configs/models/llama2_13b_chat.py b/configs/models/llama2_13b_chat.py
new file mode 100644
index 00000000..ef77148c
--- /dev/null
+++ b/configs/models/llama2_13b_chat.py
@@ -0,0 +1,31 @@
+from opencompass.models import Llama2Chat
+
+# Please follow the instructions on the Meta AI website (https://github.com/facebookresearch/llama)
+# and download the LLaMA-2-Chat model and tokenizer to the path './models/llama2/llama/'.
+#
+# The llama package also needs to be installed:
+#
+#   git clone https://github.com/facebookresearch/llama.git
+#   cd llama
+#   pip install -e .
+
+api_meta_template = dict(
+    round=[
+        dict(role="HUMAN", api_role="HUMAN"),
+        dict(role="BOT", api_role="BOT", generate=True),
+    ],
+)
+
+models = [
+    dict(
+        abbr="llama-2-13b-chat",
+        type=Llama2Chat,
+        path="./models/llama2/llama/llama-2-13b-chat/",
+        tokenizer_path="./models/llama2/llama/tokenizer.model",
+        meta_template=api_meta_template,
+        max_out_len=100,
+        max_seq_len=2048,
+        batch_size=16,
+        run_cfg=dict(num_gpus=2, num_procs=2),
+    ),
+]
diff --git a/configs/models/llama2_70b_chat.py b/configs/models/llama2_70b_chat.py
new file mode 100644
index 00000000..94a94b5c
--- /dev/null
+++ b/configs/models/llama2_70b_chat.py
@@ -0,0 +1,31 @@
+from opencompass.models import Llama2Chat
+
+# Please follow the instructions on the Meta AI website (https://github.com/facebookresearch/llama)
+# and download the LLaMA-2-Chat model and tokenizer to the path './models/llama2/llama/'.
+#
+# The llama package also needs to be installed:
+#
+#   git clone https://github.com/facebookresearch/llama.git
+#   cd llama
+#   pip install -e .
+
+api_meta_template = dict(
+    round=[
+        dict(role="HUMAN", api_role="HUMAN"),
+        dict(role="BOT", api_role="BOT", generate=True),
+    ],
+)
+
+models = [
+    dict(
+        abbr="llama-2-70b-chat",
+        type=Llama2Chat,
+        path="./models/llama2/llama/llama-2-70b-chat/",
+        tokenizer_path="./models/llama2/llama/tokenizer.model",
+        meta_template=api_meta_template,
+        max_out_len=100,
+        max_seq_len=2048,
+        batch_size=16,
+        run_cfg=dict(num_gpus=8, num_procs=8),
+    ),
+]
diff --git a/configs/models/llama2_7b_chat.py b/configs/models/llama2_7b_chat.py
new file mode 100644
index 00000000..f3cb571e
--- /dev/null
+++ b/configs/models/llama2_7b_chat.py
@@ -0,0 +1,31 @@
+from opencompass.models import Llama2Chat
+
+# Please follow the instructions on the Meta AI website (https://github.com/facebookresearch/llama)
+# and download the LLaMA-2-Chat model and tokenizer to the path './models/llama2/llama/'.
+#
+# The llama package also needs to be installed:
+#
+#   git clone https://github.com/facebookresearch/llama.git
+#   cd llama
+#   pip install -e .
+
+api_meta_template = dict(
+    round=[
+        dict(role="HUMAN", api_role="HUMAN"),
+        dict(role="BOT", api_role="BOT", generate=True),
+    ],
+)
+
+models = [
+    dict(
+        abbr="llama-2-7b-chat",
+        type=Llama2Chat,
+        path="./models/llama2/llama/llama-2-7b-chat/",
+        tokenizer_path="./models/llama2/llama/tokenizer.model",
+        meta_template=api_meta_template,
+        max_out_len=100,
+        max_seq_len=2048,
+        batch_size=16,
+        run_cfg=dict(num_gpus=1, num_procs=1),
+    ),
+]
diff --git a/docs/en/get_started.md b/docs/en/get_started.md
index 9c8660a5..b77afb1f 100644
--- a/docs/en/get_started.md
+++ b/docs/en/get_started.md
@@ -2,39 +2,58 @@
 
 1. Set up the OpenCompass environment:
 
-```bash
-conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
-conda activate opencompass
-```
+   ```bash
+   conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
+   conda activate opencompass
+   ```
 
-If you want to customize the PyTorch version or related CUDA version, please refer to the [official documentation](https://pytorch.org/get-started/locally/) to set up the PyTorch environment. Note that OpenCompass requires `pytorch>=1.13`.
+   If you want to customize the PyTorch version or related CUDA version, please refer to the [official documentation](https://pytorch.org/get-started/locally/) to set up the PyTorch environment. Note that OpenCompass requires `pytorch>=1.13`.
 
 2. Install OpenCompass:
 
-```bash
-git clone https://github.com/InternLM/opencompass.git
-cd opencompass
-pip install -e .
-```
+   ```bash
+   git clone https://github.com/InternLM/opencompass.git
+   cd opencompass
+   pip install -e .
+   ```
 
 3. Install humaneval (Optional)
 
-If you want to **evaluate your models coding ability on the humaneval dataset**, execute this step otherwise skip it.
+   If you want to **evaluate your model's coding ability on the humaneval dataset**, follow this step.
 
-<details>
-<summary>click to show the details</summary>
 
-```bash
-git clone https://github.com/openai/human-eval.git
-cd human-eval
-pip install -r requirements.txt
-pip install -e .
-cd ..
-```
+   <details>
+   <summary>click to show the details</summary>
+
+   ```bash
+   git clone https://github.com/openai/human-eval.git
+   cd human-eval
+   pip install -r requirements.txt
+   pip install -e .
+   cd ..
+   ```
 
-Please read the comments in `human_eval/execution.py` **lines 48-57** to understand the potential risks of executing the model generation code. If you accept these risks, uncomment **line 58** to enable code execution evaluation.
+   Please read the comments in `human_eval/execution.py` **lines 48-57** to understand the potential risks of executing the model generation code. If you accept these risks, uncomment **line 58** to enable code execution evaluation.
 
-</details>
+   </details>
+
+4. Install Llama (Optional)
+
+   If you want to **evaluate Llama / Llama-2 / Llama-2-chat with its official implementation**, follow this step.
+
+   <details>
+   <summary>click to show the details</summary>
+
+   ```bash
+   git clone https://github.com/facebookresearch/llama.git
+   cd llama
+   pip install -r requirements.txt
+   pip install -e .
+   cd ..
+   ```
+
+   You can find example configs in `configs/models`. ([example](https://github.com/InternLM/opencompass/blob/eb4822a94d624a4e16db03adeb7a59bbd10c2012/configs/models/llama2_7b_chat.py))
+
+   </details>
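+
+   For instance, a minimal evaluation config along the following lines pairs the Llama-2-7B-Chat model config with a dataset (a sketch: the file name `configs/eval_llama2_7b_chat.py` is hypothetical, and `siqa_gen` stands in for any bundled dataset config); `python run.py configs/eval_llama2_7b_chat.py` would then launch the evaluation:
+
+   ```python
+   from mmengine.config import read_base
+
+   with read_base():
+       # reuse the model definition shipped in configs/models
+       from .models.llama2_7b_chat import models
+       # any dataset config can be substituted here
+       from .datasets.siqa.siqa_gen import siqa_datasets as datasets
+   ```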
 
 # Dataset Preparation
diff --git a/docs/zh_cn/get_started.md b/docs/zh_cn/get_started.md
index 14e79b92..77634500 100644
--- a/docs/zh_cn/get_started.md
+++ b/docs/zh_cn/get_started.md
@@ -2,39 +2,58 @@
 
 1. 准备 OpenCompass 运行环境:
 
-```bash
-conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
-conda activate opencompass
-```
+   ```bash
+   conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
+   conda activate opencompass
+   ```
 
-如果你希望自定义 PyTorch 版本或相关的 CUDA 版本,请参考 [官方文档](https://pytorch.org/get-started/locally/) 准备 PyTorch 环境。需要注意的是,OpenCompass 要求 `pytorch>=1.13`。
+   如果你希望自定义 PyTorch 版本或相关的 CUDA 版本,请参考 [官方文档](https://pytorch.org/get-started/locally/) 准备 PyTorch 环境。需要注意的是,OpenCompass 要求 `pytorch>=1.13`。
 
 2. 安装 OpenCompass:
 
-```bash
-git clone https://github.com/InternLM/opencompass.git
-cd opencompass
-pip install -e .
-```
+   ```bash
+   git clone https://github.com/InternLM/opencompass.git
+   cd opencompass
+   pip install -e .
+   ```
 
 3. 安装 humaneval(可选):
 
-如果你需要**在 humaneval 数据集上评估模型代码能力**,请执行此步骤,否则忽略这一步。
+   如果你需要**在 humaneval 数据集上评估模型代码能力**,请执行此步骤,否则忽略这一步。
 
-<details>
-<summary>点击查看详细</summary>
 
-```bash
-git clone https://github.com/openai/human-eval.git
-cd human-eval
-pip install -r requirements.txt
-pip install -e .
-cd ..
-```
+   <details>
+   <summary>点击查看详细</summary>
+
+   ```bash
+   git clone https://github.com/openai/human-eval.git
+   cd human-eval
+   pip install -r requirements.txt
+   pip install -e .
+   cd ..
+   ```
 
-请仔细阅读 `human_eval/execution.py` **第48-57行**的注释,了解执行模型生成的代码可能存在的风险,如果接受这些风险,请取消**第58行**的注释,启用代码执行评测。
+   请仔细阅读 `human_eval/execution.py` **第48-57行**的注释,了解执行模型生成的代码可能存在的风险,如果接受这些风险,请取消**第58行**的注释,启用代码执行评测。
 
-</details>
+   </details>
+
+4. 安装 Llama(可选):
+
+   如果你需要**使用官方实现评测 Llama / Llama-2 / Llama-2-chat 模型**,请执行此步骤,否则忽略这一步。
+
+   <details>
+   <summary>点击查看详细</summary>
+
+   ```bash
+   git clone https://github.com/facebookresearch/llama.git
+   cd llama
+   pip install -r requirements.txt
+   pip install -e .
+   cd ..
+   ```
+
+   你可以在 `configs/models` 下找到所有 Llama / Llama-2 / Llama-2-chat 模型的配置文件示例。([示例](https://github.com/InternLM/opencompass/blob/eb4822a94d624a4e16db03adeb7a59bbd10c2012/configs/models/llama2_7b_chat.py))
+
+   </details>
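+
+   例如,下面的最小评测配置示意(示意性质:文件名 `configs/eval_llama2_7b_chat.py` 为假设,`siqa_gen` 仅代表任意一个仓库自带的数据集配置)将 Llama-2-7B-Chat 模型配置与数据集组合在一起,之后可通过 `python run.py configs/eval_llama2_7b_chat.py` 启动评测:
+
+   ```python
+   from mmengine.config import read_base
+
+   with read_base():
+       # 复用 configs/models 中的模型配置
+       from .models.llama2_7b_chat import models
+       # 此处可替换为任意数据集配置
+       from .datasets.siqa.siqa_gen import siqa_datasets as datasets
+   ```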
 
 # 数据集准备
diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index fa46042e..6df976be 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -3,4 +3,5 @@ from .base_api import APITemplateParser, BaseAPIModel  # noqa
 from .glm import GLM130B  # noqa: F401, F403
 from .huggingface import HuggingFace  # noqa: F401, F403
 from .huggingface import HuggingFaceCausalLM  # noqa: F401, F403
+from .llama2 import Llama2Chat  # noqa: F401, F403
 from .openai_api import OpenAI  # noqa: F401
diff --git a/opencompass/models/llama2.py b/opencompass/models/llama2.py
new file mode 100644
index 00000000..4ed076a0
--- /dev/null
+++ b/opencompass/models/llama2.py
@@ -0,0 +1,113 @@
+from typing import Dict, List, Optional, Union
+
+from opencompass.models.base import BaseModel
+from opencompass.models.base_api import APITemplateParser
+from opencompass.utils.logging import get_logger
+from opencompass.utils.prompt import PromptList
+
+PromptType = Union[PromptList, str]
+
+
+class Llama2Chat(BaseModel):
+    """LLaMA-2 chat model wrapper
+    https://github.com/facebookresearch/llama/tree/main.
+
+    Args:
+        path (str): path to the model directory
+        max_seq_len (int): max sequence length
+        max_batch_size (int): max batch size
+        tokenizer_only (bool): whether to load tokenizer only
+        tokenizer_path (str): path to the tokenizer directory
+        meta_template (dict): meta template for the model
+    """
+
+    def __init__(
+        self,
+        path: str,
+        max_seq_len: int = 2048,
+        max_batch_size: int = 16,
+        tokenizer_only: bool = False,
+        tokenizer_path: Optional[str] = None,
+        meta_template: Optional[Dict] = None,
+    ):  # noqa
+        if tokenizer_only:
+            self._load_tokenizer(tokenizer_path=tokenizer_path)
+        else:
+            self._load_model(path=path,
+                             max_seq_len=max_seq_len,
+                             max_batch_size=max_batch_size,
+                             tokenizer_path=tokenizer_path)
+        self.max_seq_len = max_seq_len
+        self.template_parser = APITemplateParser(meta_template)
+        self.logger = get_logger()
+
+    def _load_model(self,
+                    path: str,
+                    max_seq_len: int,
+                    max_batch_size: int,
+                    tokenizer_path: Optional[str] = None):
+        from llama import Llama
+        self.generator = Llama.build(path, tokenizer_path, max_seq_len,
+                                     max_batch_size)
+        self.tokenizer = self.generator.tokenizer
+        self.model = self.generator.model
+
+    def _load_tokenizer(self, tokenizer_path: str):
+        from llama import Tokenizer
+        self.tokenizer = Tokenizer(tokenizer_path)
+
+    def generate(self,
+                 inputs: List[PromptType],
+                 max_out_len: int = 512,
+                 temperature: float = 0.6) -> List[str]:
+        """Generate responses from the input prompts.
+
+        Args:
+            inputs (list): input prompts
+            max_out_len (int): max output length
+            temperature (float): temperature for sampling
+        """
+        dialogs = []
+        for input in inputs:
+            assert isinstance(input, (str, PromptList))
+            if isinstance(input, str):
+                dialog = [{'role': 'user', 'content': input}]
+            else:
+                # map OpenCompass roles onto Llama chat roles
+                dialog = []
+                for item in input:
+                    msg = {'content': item['prompt']}
+                    if item['role'] == 'HUMAN':
+                        msg['role'] = 'user'
+                    elif item['role'] == 'BOT':
+                        msg['role'] = 'assistant'
+                    elif item['role'] == 'SYSTEM':
+                        msg['role'] = 'system'
+                    dialog.append(msg)
+            dialogs.append(dialog)
+
+        try:
+            results = self.generator.chat_completion(
+                dialogs,  # type: ignore
+                max_gen_len=max_out_len,
+                temperature=temperature,
+            )
+            return [r['generation']['content'] for r in results]
+        except AssertionError:
+            self.logger.warning('Batched data max token limit exceeded, '
+                                'try to run one by one...')
+
+        results = []
+        for dialog in dialogs:
+            try:
+                result = self.generator.chat_completion(
+                    [dialog],  # type: ignore
+                    max_gen_len=max_out_len,
+                    temperature=temperature,
+                )[0]
+                results.append(result['generation']['content'])
+            except AssertionError:
+                results.append('')
+        return results
+
+    def get_token_len(self, prompt: str) -> int:
+        return len(self.tokenizer.encode(prompt, bos=True, eos=True)) + 100
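
For reference, a minimal usage sketch of `Llama2Chat` (the file name and checkpoint paths are placeholders following the docs above; `Llama.build` initializes a distributed process group, so the script must be launched with `torchrun` rather than plain `python`):

```python
# demo_llama2chat.py -- a usage sketch, assuming checkpoints were downloaded
# as described in the installation docs above; launch with:
#   torchrun --nproc_per_node 1 demo_llama2chat.py
from opencompass.models import Llama2Chat
from opencompass.utils.prompt import PromptList

model = Llama2Chat(
    path='./models/llama2/llama/llama-2-7b-chat/',
    tokenizer_path='./models/llama2/llama/tokenizer.model',
    meta_template=dict(round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ]),
)

# Plain strings become single-turn 'user' messages; PromptList items are
# mapped HUMAN -> user, BOT -> assistant, SYSTEM -> system by generate().
prompt = PromptList([
    dict(role='SYSTEM', prompt='You are a concise assistant.'),
    dict(role='HUMAN', prompt='Name the capital of France.'),
])
print(model.generate([prompt], max_out_len=64))
```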