diff --git a/README.md b/README.md
index e1e662e4..70953b87 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ Just like a compass guides us on our journey, OpenCompass will guide you through
## News
+- **\[2023.07.19\]** We now support [Llama 2](https://ai.meta.com/llama/)! Its performance report will be available soon. \[[doc](./docs/en/get_started.md#Installation)\]
- **\[2023.07.13\]** We release [MMBench](https://opencompass.org.cn/MMBench), a meticulously curated dataset to comprehensively evaluate different abilities of multimodality models 🔥🔥🔥.
## Introduction
diff --git a/README_zh-CN.md b/README_zh-CN.md
index f7656433..fae6e28f 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -25,6 +25,11 @@
就像指南针在我们的旅程中为我们导航一样,我们希望OpenCompass能够帮助你穿越评估大型语言模型的重重迷雾。OpenCompass提供丰富的算法和功能支持,期待OpenCompass能够帮助社区更便捷地对NLP模型的性能进行公平全面的评估。
+## 更新
+
+- **\[2023.07.19\]** 新增了 [Llama 2](https://ai.meta.com/llama/)!我们近期将会公布其评测结果。\[[文档](./docs/zh_cn/get_started.md#安装)\]
+- **\[2023.07.13\]** 发布了 [MMBench](https://opencompass.org.cn/MMBench),该数据集经过细致整理,用于评测多模态模型全方位能力 🔥🔥🔥。
+
## 介绍
OpenCompass 是面向大模型评测的一站式平台。其主要特点如下:
diff --git a/configs/models/hf_llama2_13b.py b/configs/models/hf_llama2_13b.py
new file mode 100644
index 00000000..4103c874
--- /dev/null
+++ b/configs/models/hf_llama2_13b.py
@@ -0,0 +1,21 @@
+from opencompass.models import HuggingFaceCausalLM
+
+
+models = [
+ dict(
+ type=HuggingFaceCausalLM,
+ abbr='llama-2-13b-hf',
+ path="meta-llama/Llama-2-13b-hf",
+ tokenizer_path='meta-llama/Llama-2-13b-hf',
+ tokenizer_kwargs=dict(padding_side='left',
+ truncation_side='left',
+ use_fast=False,
+ ),
+ max_out_len=100,
+ max_seq_len=2048,
+ batch_size=8,
+ model_kwargs=dict(device_map='auto'),
+ batch_padding=False, # if False, inference is run in a for-loop without batch padding
+ run_cfg=dict(num_gpus=2, num_procs=1),
+ )
+]
diff --git a/configs/models/hf_llama2_70b.py b/configs/models/hf_llama2_70b.py
new file mode 100644
index 00000000..44078cf0
--- /dev/null
+++ b/configs/models/hf_llama2_70b.py
@@ -0,0 +1,21 @@
+from opencompass.models import HuggingFaceCausalLM
+
+
+models = [
+ dict(
+ type=HuggingFaceCausalLM,
+ abbr='llama-2-70b-hf',
+ path="meta-llama/Llama-2-70b-hf",
+ tokenizer_path='meta-llama/Llama-2-70b-hf',
+ tokenizer_kwargs=dict(padding_side='left',
+ truncation_side='left',
+ use_fast=False,
+ ),
+ max_out_len=100,
+ max_seq_len=2048,
+ batch_size=8,
+ model_kwargs=dict(device_map='auto'),
+ batch_padding=False, # if False, inference is run in a for-loop without batch padding
+ run_cfg=dict(num_gpus=8, num_procs=1),
+ )
+]
diff --git a/configs/models/hf_llama2_7b.py b/configs/models/hf_llama2_7b.py
new file mode 100644
index 00000000..3d00990e
--- /dev/null
+++ b/configs/models/hf_llama2_7b.py
@@ -0,0 +1,21 @@
+from opencompass.models import HuggingFaceCausalLM
+
+
+models = [
+ dict(
+ type=HuggingFaceCausalLM,
+ abbr='llama-2-7b-hf',
+ path="meta-llama/Llama-2-7b-hf",
+ tokenizer_path='meta-llama/Llama-2-7b-hf',
+ tokenizer_kwargs=dict(padding_side='left',
+ truncation_side='left',
+ use_fast=False,
+ ),
+ max_out_len=100,
+ max_seq_len=2048,
+ batch_size=8,
+ model_kwargs=dict(device_map='auto'),
+ batch_padding=False, # if False, inference is run in a for-loop without batch padding
+ run_cfg=dict(num_gpus=1, num_procs=1),
+ )
+]
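The HuggingFace configs above only declare the models; to run an evaluation they have to be combined with at least one dataset config. Below is a minimal sketch of such a top-level config, assuming the usual OpenCompass `read_base` pattern; the file name `eval_llama2_7b.py` and the SIQA dataset choice are illustrative, not part of this change.

```python
# configs/eval_llama2_7b.py -- hypothetical top-level evaluation config
from mmengine.config import read_base

with read_base():
    # dataset config assumed to exist under configs/datasets; swap in any other dataset
    from .datasets.siqa.siqa_gen import siqa_datasets
    # pulls in the `models` list defined in configs/models/hf_llama2_7b.py above
    from .models.hf_llama2_7b import models

datasets = [*siqa_datasets]

# then launch the evaluation with: python run.py configs/eval_llama2_7b.py
```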
diff --git a/configs/models/llama2_13b_chat.py b/configs/models/llama2_13b_chat.py
new file mode 100644
index 00000000..ef77148c
--- /dev/null
+++ b/configs/models/llama2_13b_chat.py
@@ -0,0 +1,31 @@
+from opencompass.models import Llama2Chat
+
+# Please follow the instructions in Meta's official repository https://github.com/facebookresearch/llama
+# and download the LLaMA-2-Chat model and tokenizer to the path './models/llama2/llama/'.
+#
+# The llama package itself also needs to be installed:
+#
+# git clone https://github.com/facebookresearch/llama.git
+# cd llama
+# pip install -e .
+
+api_meta_template = dict(
+ round=[
+ dict(role="HUMAN", api_role="HUMAN"),
+ dict(role="BOT", api_role="BOT", generate=True),
+ ],
+)
+
+models = [
+ dict(
+ abbr="llama-2-13b-chat",
+ type=Llama2Chat,
+ path="./models/llama2/llama/llama-2-13b-chat/",
+ tokenizer_path="./models/llama2/llama/tokenizer.model",
+ meta_template=api_meta_template,
+ max_out_len=100,
+ max_seq_len=2048,
+ batch_size=16,
+ run_cfg=dict(num_gpus=2, num_procs=2),
+ ),
+]
diff --git a/configs/models/llama2_70b_chat.py b/configs/models/llama2_70b_chat.py
new file mode 100644
index 00000000..94a94b5c
--- /dev/null
+++ b/configs/models/llama2_70b_chat.py
@@ -0,0 +1,31 @@
+from opencompass.models import Llama2Chat
+
+# Please follow the instructions in Meta's official repository https://github.com/facebookresearch/llama
+# and download the LLaMA-2-Chat model and tokenizer to the path './models/llama2/llama/'.
+#
+# The llama package itself also needs to be installed:
+#
+# git clone https://github.com/facebookresearch/llama.git
+# cd llama
+# pip install -e .
+
+api_meta_template = dict(
+ round=[
+ dict(role="HUMAN", api_role="HUMAN"),
+ dict(role="BOT", api_role="BOT", generate=True),
+ ],
+)
+
+models = [
+ dict(
+ abbr="llama-2-70b-chat",
+ type=Llama2Chat,
+ path="./models/llama2/llama/llama-2-70b-chat/",
+ tokenizer_path="./models/llama2/llama/tokenizer.model",
+ meta_template=api_meta_template,
+ max_out_len=100,
+ max_seq_len=2048,
+ batch_size=16,
+ run_cfg=dict(num_gpus=8, num_procs=8),
+ ),
+]
diff --git a/configs/models/llama2_7b_chat.py b/configs/models/llama2_7b_chat.py
new file mode 100644
index 00000000..f3cb571e
--- /dev/null
+++ b/configs/models/llama2_7b_chat.py
@@ -0,0 +1,31 @@
+from opencompass.models import Llama2Chat
+
+# Please follow the instructions in Meta's official repository https://github.com/facebookresearch/llama
+# and download the LLaMA-2-Chat model and tokenizer to the path './models/llama2/llama/'.
+#
+# The llama package itself also needs to be installed:
+#
+# git clone https://github.com/facebookresearch/llama.git
+# cd llama
+# pip install -e .
+
+api_meta_template = dict(
+ round=[
+ dict(role="HUMAN", api_role="HUMAN"),
+ dict(role="BOT", api_role="BOT", generate=True),
+ ],
+)
+
+models = [
+ dict(
+ abbr="llama-2-7b-chat",
+ type=Llama2Chat,
+ path="./models/llama2/llama/llama-2-7b-chat/",
+ tokenizer_path="./models/llama2/llama/tokenizer.model",
+ meta_template=api_meta_template,
+ max_out_len=100,
+ max_seq_len=2048,
+ batch_size=16,
+ run_cfg=dict(num_gpus=1, num_procs=1),
+ ),
+]
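The chat configs above point at local checkpoints rather than the HuggingFace Hub, so the download step described in their comments matters. A quick, hypothetical sanity check for the expected layout (file names follow Meta's standard distribution: `consolidated.*.pth` shards plus `params.json` per model, with a shared `tokenizer.model`):

```python
from pathlib import Path

# Path used by the chat configs above.
base = Path('./models/llama2/llama')
print('tokenizer.model found:', (base / 'tokenizer.model').is_file())
for name in ('llama-2-7b-chat', 'llama-2-13b-chat', 'llama-2-70b-chat'):
    ckpt_dir = base / name
    shards = sorted(ckpt_dir.glob('consolidated.*.pth'))
    print(f'{name}: {len(shards)} checkpoint shard(s), '
          f'params.json found: {(ckpt_dir / "params.json").is_file()}')
```

The shard count is also why `run_cfg` differs per model: the official loader expects one process per checkpoint shard, hence `num_procs=1/2/8` for the 7B/13B/70B chat models.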
diff --git a/docs/en/get_started.md b/docs/en/get_started.md
index 9c8660a5..b77afb1f 100644
--- a/docs/en/get_started.md
+++ b/docs/en/get_started.md
@@ -2,39 +2,58 @@
1. Set up the OpenCompass environment:
-```bash
-conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
-conda activate opencompass
-```
+ ```bash
+ conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
+ conda activate opencompass
+ ```
-If you want to customize the PyTorch version or related CUDA version, please refer to the [official documentation](https://pytorch.org/get-started/locally/) to set up the PyTorch environment. Note that OpenCompass requires `pytorch>=1.13`.
+ If you want to customize the PyTorch version or related CUDA version, please refer to the [official documentation](https://pytorch.org/get-started/locally/) to set up the PyTorch environment. Note that OpenCompass requires `pytorch>=1.13`.
2. Install OpenCompass:
-```bash
-git clone https://github.com/InternLM/opencompass.git
-cd opencompass
-pip install -e .
-```
+ ```bash
+ git clone https://github.com/InternLM/opencompass.git
+ cd opencompass
+ pip install -e .
+ ```
3. Install humaneval (Optional)
-If you want to **evaluate your models coding ability on the humaneval dataset**, execute this step otherwise skip it.
+ If you want to **evaluate your model's coding ability on the humaneval dataset**, follow this step.
-<details>
-<summary>click to show the details</summary>
+ <details>
+ <summary>click to show the details</summary>
-```bash
-git clone https://github.com/openai/human-eval.git
-cd human-eval
-pip install -r requirements.txt
-pip install -e .
-cd ..
-```
+ ```bash
+ git clone https://github.com/openai/human-eval.git
+ cd human-eval
+ pip install -r requirements.txt
+ pip install -e .
+ cd ..
+ ```
-Please read the comments in `human_eval/execution.py` **lines 48-57** to understand the potential risks of executing the model generation code. If you accept these risks, uncomment **line 58** to enable code execution evaluation.
+ Please read the comments in `human_eval/execution.py` **lines 48-57** to understand the potential risks of executing model-generated code. If you accept these risks, uncomment **line 58** to enable code execution evaluation.
-</details>
+ </details>
+
+4. Install Llama (Optional)
+
+ If you want to **evaluate Llama / Llama-2 / Llama-2-chat with their official implementation**, follow this step.
+
+ <details>
+ <summary>click to show the details</summary>
+
+ ```bash
+ git clone https://github.com/facebookresearch/llama.git
+ cd llama
+ pip install -r requirements.txt
+ pip install -e .
+ cd ..
+ ```
+
+ You can find example configs in `configs/models`. ([example](https://github.com/InternLM/opencompass/blob/eb4822a94d624a4e16db03adeb7a59bbd10c2012/configs/models/llama2_7b_chat.py))
+
+ </details>
# Dataset Preparation
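After step 4 above, it can be worth smoke-testing the official implementation outside of OpenCompass before launching a full evaluation. A hedged sketch follows: the checkpoint paths mirror the chat configs in `configs/models` and are assumptions about where the weights were downloaded, and the script has to be launched with `torchrun` (e.g. `torchrun --nproc_per_node 1 smoke_test.py` for the single-shard 7B model).

```python
# smoke_test.py -- minimal check that the llama package and the downloaded weights work
from llama import Llama

generator = Llama.build(
    ckpt_dir='./models/llama2/llama/llama-2-7b-chat/',
    tokenizer_path='./models/llama2/llama/tokenizer.model',
    max_seq_len=512,
    max_batch_size=1,
)
result = generator.chat_completion(
    [[{'role': 'user', 'content': 'Say hello in one short sentence.'}]],
    max_gen_len=64,
    temperature=0.6,
)
print(result[0]['generation']['content'])
```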
diff --git a/docs/zh_cn/get_started.md b/docs/zh_cn/get_started.md
index 14e79b92..77634500 100644
--- a/docs/zh_cn/get_started.md
+++ b/docs/zh_cn/get_started.md
@@ -2,39 +2,58 @@
1. 准备 OpenCompass 运行环境:
-```bash
-conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
-conda activate opencompass
-```
+ ```bash
+ conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
+ conda activate opencompass
+ ```
-如果你希望自定义 PyTorch 版本或相关的 CUDA 版本,请参考 [官方文档](https://pytorch.org/get-started/locally/) 准备 PyTorch 环境。需要注意的是,OpenCompass 要求 `pytorch>=1.13`。
+ 如果你希望自定义 PyTorch 版本或相关的 CUDA 版本,请参考 [官方文档](https://pytorch.org/get-started/locally/) 准备 PyTorch 环境。需要注意的是,OpenCompass 要求 `pytorch>=1.13`。
2. 安装 OpenCompass:
-```bash
-git clone https://github.com/InternLM/opencompass.git
-cd opencompass
-pip install -e .
-```
+ ```bash
+ git clone https://github.com/InternLM/opencompass.git
+ cd opencompass
+ pip install -e .
+ ```
3. 安装 humaneval(可选):
-如果你需要**在 humaneval 数据集上评估模型代码能力**,请执行此步骤,否则忽略这一步。
+ 如果你需要**在 humaneval 数据集上评估模型代码能力**,请执行此步骤,否则忽略这一步。
-<details>
-<summary>点击查看详细</summary>
+ <details>
+ <summary>点击查看详细</summary>
-```bash
-git clone https://github.com/openai/human-eval.git
-cd human-eval
-pip install -r requirements.txt
-pip install -e .
-cd ..
-```
+ ```bash
+ git clone https://github.com/openai/human-eval.git
+ cd human-eval
+ pip install -r requirements.txt
+ pip install -e .
+ cd ..
+ ```
-请仔细阅读 `human_eval/execution.py` **第48-57行**的注释,了解执行模型生成的代码可能存在的风险,如果接受这些风险,请取消**第58行**的注释,启用代码执行评测。
+ 请仔细阅读 `human_eval/execution.py` **第48-57行**的注释,了解执行模型生成的代码可能存在的风险,如果接受这些风险,请取消**第58行**的注释,启用代码执行评测。
-</details>
+ </details>
+
+4. 安装 Llama(可选):
+
+ 如果你需要**使用官方实现评测 Llama / Llama-2 / Llama-2-chat 模型**,请执行此步骤,否则忽略这一步。
+
+ <details>
+ <summary>点击查看详细</summary>
+
+ ```bash
+ git clone https://github.com/facebookresearch/llama.git
+ cd llama
+ pip install -r requirements.txt
+ pip install -e .
+ cd ..
+ ```
+
+ 你可以在 `configs/models` 下找到所有 Llama / Llama-2 / Llama-2-chat 模型的配置文件示例。([示例](https://github.com/InternLM/opencompass/blob/eb4822a94d624a4e16db03adeb7a59bbd10c2012/configs/models/llama2_7b_chat.py))
+
+ </details>
# 数据集准备
diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index fa46042e..6df976be 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -3,4 +3,5 @@ from .base_api import APITemplateParser, BaseAPIModel # noqa
from .glm import GLM130B # noqa: F401, F403
from .huggingface import HuggingFace # noqa: F401, F403
from .huggingface import HuggingFaceCausalLM # noqa: F401, F403
+from .llama2 import Llama2Chat # noqa: F401, F403
from .openai_api import OpenAI # noqa: F401
diff --git a/opencompass/models/llama2.py b/opencompass/models/llama2.py
new file mode 100644
index 00000000..4ed076a0
--- /dev/null
+++ b/opencompass/models/llama2.py
@@ -0,0 +1,113 @@
+from typing import Dict, List, Optional, Union
+
+from opencompass.models.base import BaseModel
+from opencompass.models.base_api import APITemplateParser
+from opencompass.utils.logging import get_logger
+from opencompass.utils.prompt import PromptList
+
+PromptType = Union[PromptList, str]
+
+
+class Llama2Chat(BaseModel):
+ """LLaMA-2 chat model wrapper
+ https://github.com/facebookresearch/llama/tree/main.
+
+ Args:
+ path (str): path to the model directory
+ max_seq_len (int): max sequence length
+ max_batch_size (int): max batch size
+ tokenizer_only (bool): whether to load tokenizer only
+ tokenizer_path (str): path to the tokenizer directory
+ meta_template (dict): meta template for the model
+ """
+
+ def __init__(
+ self,
+ path: str,
+ max_seq_len: int = 2048,
+ max_batch_size: int = 16,
+ tokenizer_only: bool = False,
+ tokenizer_path: Optional[str] = None,
+ meta_template: Optional[Dict] = None,
+ ): # noqa
+ if tokenizer_only:
+ self._load_tokenizer(tokenizer_path=tokenizer_path)
+ else:
+ self._load_model(path=path,
+ max_seq_len=max_seq_len,
+ max_batch_size=max_batch_size,
+ tokenizer_path=tokenizer_path)
+ self.max_seq_len = max_seq_len
+ self.template_parser = APITemplateParser(meta_template)
+ self.logger = get_logger()
+
+ def _load_model(self,
+ path: str,
+ max_seq_len: int,
+ max_batch_size: int,
+ tokenizer_path: Optional[str] = None):
+ from llama import Llama
+ self.generator = Llama.build(path, tokenizer_path, max_seq_len,
+ max_batch_size)
+ self.tokenizer = self.generator.tokenizer
+ self.model = self.generator.model
+
+ def _load_tokenizer(self, tokenizer_path: str):
+ from llama import Tokenizer
+ self.tokenizer = Tokenizer(tokenizer_path)
+
+ def generate(self,
+ inputs: List[PromptType],
+ max_out_len: int = 512,
+ temperature: float = 0.6) -> List[str]:
+ """Generate response from input prompt.
+
+ Args:
+ inputs (list): input prompt
+ max_out_len (int): max output length
+ temperature (float): temperature for sampling
+ """
+ dialogs = []
+ for input in inputs:
+ assert isinstance(input, (str, PromptList))
+ if isinstance(input, str):
+ dialog = [{'role': 'user', 'content': input}]
+ else:
+ dialog = []
+ for item in input:
+ msg = {'content': item['prompt']}
+ if item['role'] == 'HUMAN':
+ msg['role'] = 'user'
+ elif item['role'] == 'BOT':
+ msg['role'] = 'assistant'
+ elif item['role'] == 'SYSTEM':
+ msg['role'] = 'system'
+ dialog.append(msg)
+ dialogs.append(dialog)
+
+ try:
+ results = self.generator.chat_completion(
+ dialogs, # type: ignore
+ max_gen_len=max_out_len,
+ temperature=temperature,
+ )
+ return [r['generation']['content'] for r in results]
+ except AssertionError:
+ self.logger.warning('Batched data max token limit exceeded, '
+ 'try to run one by one...')
+
+ results = []
+ for dialog in dialogs:
+ try:
+ result = self.generator.chat_completion(
+ [dialog], # type: ignore
+ max_gen_len=max_out_len,
+ temperature=temperature,
+ )[0]
+ results.append(result['generation']['content'])
+ except AssertionError:
+ results.append('')
+ return results
+
+ def get_token_len(self, prompt: str) -> int:
+ # length of the encoded prompt plus a fixed 100-token margin
+ return len(self.tokenizer.encode(prompt, bos=True, eos=True)) + 100
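For reference, a hedged usage sketch of the wrapper defined above, used outside the config system: the paths mirror `configs/models/llama2_7b_chat.py` and are assumptions, and, like the underlying `Llama.build`, this needs a `torchrun`-style distributed environment with one process per checkpoint shard.

```python
from opencompass.models import Llama2Chat

model = Llama2Chat(
    path='./models/llama2/llama/llama-2-7b-chat/',
    tokenizer_path='./models/llama2/llama/tokenizer.model',
    max_seq_len=2048,
    max_batch_size=2,
    meta_template=dict(round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ]),
)

# A plain string becomes a single-turn user dialog; a PromptList with
# SYSTEM/HUMAN/BOT roles is mapped to system/user/assistant messages.
outputs = model.generate(['Name the largest planet in the solar system.'],
                         max_out_len=64)
print(outputs[0])
```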