mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Enhancement] Update API Interface and Mixtral (#681)
* [Enhancement] Update API interface
* [Enhancement] Update API interface
* Update mixtral
* Update readme
This commit is contained in:
parent 1bf85949ef
commit e25c5f9525
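The thread running through the diffs below: each API wrapper (AI360GPT, BaiChuan, ERNIEBot, ByteDance) gains a `generation_kwargs` argument whose contents are forwarded to the vendor API, replacing sampling parameters that were previously hardcoded in the request payload. A minimal sketch of the resulting config pattern, modeled on the 360 example below; the `abbr` label and the placeholder key are illustrative additions, not taken from the diff:

from opencompass.models import AI360GPT

models = [
    dict(
        abbr='360gpt-s2-v9',  # illustrative label, not from the diff
        type=AI360GPT,
        path='360GPT_S2_V9',
        key='xxxxxxxxxxxx',  # placeholder credential
        # sampling controls now live in the config and override
        # the wrapper's built-in defaults at request time
        generation_kwargs={
            'temperature': 0.9,
            'top_p': 0.5,
        },
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
        batch_size=8,
    ),
]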
README.md
@@ -50,6 +50,7 @@ Just like a compass guides us on our journey, OpenCompass will guide you through

## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>

- **\[2023.12.10\]** We have supported Mistral AI's MoE LLM **Mixtral-8x7B-32K**. Welcome to [MixtralKit](https://github.com/open-compass/MixtralKit) for more details about inference and evaluation. 🔥🔥🔥.
- **\[2023.11.22\]** We have supported many API-based models, including **Baidu, ByteDance, Huawei, and 360**. Welcome to the [Models](https://opencompass.readthedocs.io/en/latest/user_guides/models.html) section for more details. 🔥🔥🔥.
- **\[2023.11.20\]** Thanks to [helloyongyang](https://github.com/helloyongyang) for supporting evaluation with [LightLLM](https://github.com/ModelTC/lightllm) as the backend. Welcome to [Evaluation With LightLLM](https://opencompass.readthedocs.io/en/latest/advanced_guides/evaluation_lightllm.html) for more details. 🔥🔥🔥.
- **\[2023.11.13\]** We are delighted to announce the release of OpenCompass v0.1.8. This version enables local loading of evaluation benchmarks, thereby eliminating the need for an internet connection. Please note that with this update, **you must re-download all evaluation datasets** to ensure accurate and up-to-date results. 🔥🔥🔥.
README_zh-CN.md
@@ -50,6 +50,7 @@

## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>

- **\[2023.12.10\]** We have supported Mistral AI's MoE model **Mixtral-8x7B-32K**. Welcome to [MixtralKit](https://github.com/open-compass/MixtralKit) for more details about inference and evaluation. 🔥🔥🔥.
- **\[2023.11.22\]** We have supported many API-based models, including **Baidu, ByteDance, Huawei, and 360**. Welcome to the [Models](https://opencompass.readthedocs.io/en/latest/user_guides/models.html) section for more details. 🔥🔥🔥.
- **\[2023.11.20\]** Thanks to [helloyongyang](https://github.com/helloyongyang) for supporting evaluation with [LightLLM](https://github.com/ModelTC/lightllm) as the backend. Welcome to [Evaluation With LightLLM](https://opencompass.readthedocs.io/en/latest/advanced_guides/evaluation_lightllm.html) for more details. 🔥🔥🔥.
- **\[2023.11.13\]** We are delighted to announce the release of OpenCompass v0.1.8. This version enables local loading of evaluation benchmarks, eliminating the need for an internet connection. Please note that with this update, **you must re-download all evaluation datasets** to ensure accurate and up-to-date results. 🔥🔥🔥.
@@ -18,6 +18,13 @@ models = [
        type=AI360GPT,
        path='360GPT_S2_V9',
        key="xxxxxxxxxxxx",
        generation_kwargs={
            'temperature': 0.9,
            'max_tokens': 2048,
            'top_p': 0.5,
            'top_k': 0,
            'repetition_penalty': 1.05,
        },
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
@@ -20,6 +20,12 @@ models = [
        api_key='xxxxxx',
        secret_key="xxxxx",
        url="xxxxx",
        generation_kwargs={
            'temperature': 0.3,
            'top_p': 0.85,
            'top_k': 5,
            'with_search_enhance': False,
        },
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
@@ -20,10 +20,14 @@ models = [
        key='xxxxxx',  # please give your key
        secretkey='xxxxxxxxx',  # please give your group_id
        url='xxxxxxxxx',
        generation_kwargs={
            'temperature': 0.8,
        },
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
        batch_size=8),
        batch_size=8
    ),
]

infer = dict(
@@ -21,6 +21,11 @@ models = [
        accesskey="xxxxxxx",
        secretkey="xxxxxxx",
        url='xxxxxx',
        generation_kwargs={
            'temperature': 0.7,
            'top_p': 0.9,
            'top_k': 0,
        },
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
@@ -19,6 +19,9 @@ models = [
        path='moonshot-v1-32k',
        key='xxxxxxx',
        url='xxxxxxxx',
        system_prompt='你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。'
                      '你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一些涉及恐怖主义,种族歧视,'
                      '黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。',
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
configs/eval_mixtral_8x7b.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from mmengine.config import read_base

with read_base():
    from .datasets.collections.base_medium_llama import piqa_datasets, siqa_datasets
    from .models.mixtral.mixtral_8x7b_32k import models


datasets = [*piqa_datasets, *siqa_datasets]
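With both new files in place, the evaluation can presumably be launched in the usual OpenCompass way, e.g. `python run.py configs/eval_mixtral_8x7b.py`; the entry point is an assumption here, not part of this commit.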
configs/models/mixtral/mixtral_8x7b_32k.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from opencompass.models import Mixtral

# Please follow the instructions at https://github.com/open-compass/MixtralKit
# to download the model weights and install the requirements


models = [
    dict(
        abbr="mixtral-8x7b-32k",
        type=Mixtral,
        path="./models/mixtral/mixtral-8x7b-32kseqlen",
        tokenizer_path="./models/mixtral/mixtral-8x7b-32kseqlen/tokenizer.model",
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        num_gpus=2,
        run_cfg=dict(num_gpus=2, num_procs=1),
    ),
]
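Note that the GPU count appears twice: `num_gpus=2` is consumed by the `Mixtral` wrapper when building the model, while `run_cfg=dict(num_gpus=2, ...)` tells the runner how to schedule the task; the two values presumably need to stay in sync.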
opencompass/models/__init__.py
@@ -14,6 +14,7 @@ from .intern_model import InternLM  # noqa: F401, F403
from .lightllm_api import LightllmAPI  # noqa: F401
from .llama2 import Llama2, Llama2Chat  # noqa: F401, F403
from .minimax_api import MiniMax  # noqa: F401
from .mixtral import Mixtral  # noqa: F401
from .modelscope import ModelScope, ModelScopeCausalLM  # noqa: F401, F403
from .moonshot_api import MoonShot  # noqa: F401
from .openai_api import OpenAI  # noqa: F401
@@ -38,12 +38,19 @@ class AI360GPT(BaseAPIModel):
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 retry: int = 2,
    ):
                 generation_kwargs: Dict = {
                     'temperature': 0.9,
                     'max_tokens': 2048,
                     'top_p': 0.5,
                     'top_k': 0,
                     'repetition_penalty': 1.05,
                 }):  # noqa E125
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
                         retry=retry,
                         generation_kwargs=generation_kwargs)
        self.headers = {
            'Authorization': f'Bearer {key}',
            'Content-Type': 'application/json',
@@ -110,15 +117,11 @@ class AI360GPT(BaseAPIModel):
            'model': self.model,
            'messages': messages,
            'stream': False,
            'temperature': 0.9,
            'max_tokens': 2048,
            'top_p': 0.5,
            'tok_k': 0,
            'repetition_penalty': 1.05,
            # "num_beams": 1,
            # "user": "OpenCompass"
        }

        data.update(self.generation_kwargs)

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.acquire()
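The pattern above recurs in each wrapper: the request payload keeps only structural fields (`model`, `messages`, `stream`), and `data.update(self.generation_kwargs)` layers the instance-level sampling parameters on top, so values supplied in the config override the hardcoded defaults being deleted here.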
@@ -43,12 +43,18 @@ class BaiChuan(BaseAPIModel):
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 retry: int = 2,
    ):
                 generation_kwargs: Dict = {
                     'temperature': 0.3,
                     'top_p': 0.85,
                     'top_k': 5,
                     'with_search_enhance': False,
                 }):  # noqa E125
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
                         retry=retry,
                         generation_kwargs=generation_kwargs)

        self.api_key = api_key
        self.secret_key = secret_key
@@ -111,6 +117,7 @@ class BaiChuan(BaseAPIModel):
        messages.append(msg)

        data = {'model': self.model, 'messages': messages}
        data.update(self.generation_kwargs)

        def calculate_md5(input_string):
            md5 = hashlib.md5()
@@ -32,27 +32,28 @@ class ERNIEBot(BaseAPIModel):
        retry (int): Number of retries if the API call fails. Defaults to 2.
    """

    def __init__(
        self,
        path: str,
        key: str,
        secretkey: str,
        url: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):
    def __init__(self,
                 path: str,
                 key: str,
                 secretkey: str,
                 url: str,
                 query_per_second: int = 2,
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 retry: int = 2,
                 generation_kwargs: Dict = {
                     'temperature': 0.8,
                 }):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
                         retry=retry,
                         generation_kwargs=generation_kwargs)
        self.headers = {'Content-Type': 'application/json'}
        self.secretkey = secretkey
        self.key = key
        self.url = url
        self.model = path

    def _generate_access_token(self):
        try:
@@ -148,6 +149,7 @@ class ERNIEBot(BaseAPIModel):

        messages.append(msg)
        data = {'messages': messages}
        data.update(self.generation_kwargs)

        max_num_retries = 0
        while max_num_retries < self.retry:
@@ -33,22 +33,26 @@ class ByteDance(BaseAPIModel):
        retry (int): Number of retries if the API call fails. Defaults to 2.
    """

    def __init__(
        self,
        path: str,
        accesskey: str,
        secretkey: str,
        url: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):
    def __init__(self,
                 path: str,
                 accesskey: str,
                 secretkey: str,
                 url: str,
                 query_per_second: int = 2,
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 retry: int = 2,
                 generation_kwargs: Dict = {
                     'temperature': 0.7,
                     'top_p': 0.9,
                     'top_k': 0,
                 }):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
                         retry=retry,
                         generation_kwargs=generation_kwargs)
        if not ChatRole:
            print('Please install related packages via'
                  ' `pip install volcengine`')
@@ -134,7 +138,8 @@ class ByteDance(BaseAPIModel):
            'model': {
                'name': 'skylark-pro-public',
            },
            'messages': messages
            'messages': messages,
            'parameters': self.generation_kwargs
        }

        def _chat(maas, req):
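Unlike the other wrappers, which merge `generation_kwargs` into the top level of the request payload, the volcengine backend takes them as a nested `parameters` object.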
opencompass/models/mixtral.py (new file, 109 lines)
@@ -0,0 +1,109 @@
from typing import Dict, List, Optional, Union

import torch

from opencompass.models.base import BaseModel
from opencompass.models.base_api import APITemplateParser
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


class Mixtral(BaseModel):
    """Mixtral model wrapper https://github.com/open-compass/MixtralKit.

    Args:
        path (str): path to the model directory
        max_seq_len (int): max sequence length
        max_batch_size (int): max batch size
        tokenizer_only (bool): whether to load tokenizer only
        tokenizer_path (str): path to the tokenizer directory
        meta_template (dict): meta template for the model
    """

    def __init__(
        self,
        path: str,
        max_seq_len: int = 2048,
        max_batch_size: int = 8,
        tokenizer_only: bool = False,
        tokenizer_path: Optional[str] = None,
        meta_template: Optional[Dict] = None,
        num_gpus: int = 2,
    ):  # noqa
        if tokenizer_only:
            self._load_tokenizer(tokenizer_path=tokenizer_path)
        else:
            self._load_model(path=path,
                             max_seq_len=max_seq_len,
                             max_batch_size=max_batch_size,
                             tokenizer_path=tokenizer_path,
                             num_gpus=num_gpus)
        self.max_seq_len = max_seq_len
        self.template_parser = APITemplateParser(meta_template)
        self.logger = get_logger()

    def _load_model(self,
                    path: str,
                    max_seq_len: int,
                    max_batch_size: int,
                    tokenizer_path: Optional[str] = None,
                    num_gpus: int = 2):
        from mixtralkit.mixtral import Mixtral
        self.generator = Mixtral.build(ckpt_dir=path,
                                       tokenizer_path=tokenizer_path,
                                       max_seq_len=max_seq_len,
                                       max_batch_size=max_batch_size,
                                       num_gpus=num_gpus)
        self.tokenizer = self.generator.tokenizer
        self.model = self.generator.model

    def _load_tokenizer(self, tokenizer_path: str):
        from mixtralkit.layers import Tokenizer
        self.tokenizer = Tokenizer(tokenizer_path)

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        prompt_tokens = []
        for input in inputs:
            tokens = self.tokenizer.encode(input, True, False)
            # keep only the trailing tokens if the prompt is too long
            num_token = min(self.model.params.max_seq_len, len(tokens))
            prompt_tokens.append(tokens[-num_token:])
        # greedy decoding (temperature=0)
        generation_tokens, _ = self.generator.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_out_len,
            temperature=0,
        )
        results = [self.tokenizer.decode(t) for t in generation_tokens]
        return results

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        assert mask_length is None, 'mask_length is not supported'
        bsz = len(inputs)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
        # tokenize
        prompt_tokens = [self.tokenizer.encode(x, True, False) for x in inputs]
        max_prompt_size = max([len(t) for t in prompt_tokens])
        total_len = min(params.max_seq_len, max_prompt_size)
        tokens = torch.zeros((bsz, total_len)).cuda().long()
        for k, t in enumerate(prompt_tokens):
            num_token = min(total_len, len(t))
            tokens[k, :num_token] = torch.tensor(t[-num_token:]).long()
        # forward
        outputs = self.model.forward(tokens, 0)
        # compute ppl: shift logits/labels by one position and take the
        # per-sequence mean cross-entropy, ignoring the zero padding
        shift_logits = outputs[..., :-1, :].contiguous().float()
        shift_labels = tokens[..., 1:].contiguous()
        shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        shift_labels = shift_labels.view(-1)
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=0)
        loss = loss_fct(shift_logits, shift_labels).view(bsz, -1)
        lens = (tokens != 0).sum(-1).cpu().numpy()
        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
        return ce_loss

    def get_token_len(self, prompt: str) -> int:
        return len(self.tokenizer.encode(prompt, True, True))
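For reference, `get_ppl` returns each sequence's average per-token cross-entropy, with padding excluded via `ignore_index=0` and the `lens` divisor; a caller wanting a literal perplexity would exponentiate this value.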
@@ -38,6 +38,7 @@ class MoonShot(BaseAPIModel):
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
        system_prompt: str = '',
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
@@ -50,6 +51,7 @@ class MoonShot(BaseAPIModel):
        }
        self.url = url
        self.model = path
        self.system_prompt = system_prompt

    def generate(
        self,
@@ -106,12 +108,11 @@ class MoonShot(BaseAPIModel):
            messages.append(msg)

        system = {
            'role':
            'system',
            'content':
            '你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。'
            '你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一些涉及恐怖主义,种族歧视,'
            '黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。'
            'role': 'system',
            'content': self.system_prompt
            # '你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。'
            # '你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一些涉及恐怖主义,种族歧视,'
            # '黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。'
        }

        messages.insert(0, system)
@@ -150,10 +151,12 @@ class MoonShot(BaseAPIModel):
            print('请求被拒绝 api_key错误')
            continue
        elif raw_response.status_code == 400:
            print(messages, response)
            print('请求失败,状态码:', raw_response)
            time.sleep(1)
            continue
        elif raw_response.status_code == 429:
            print(messages, response)
            print('请求失败,状态码:', raw_response)
            time.sleep(3)
            continue
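The MoonShot change follows the same spirit as the `generation_kwargs` work: the previously hardcoded Kimi system prompt becomes a `system_prompt` constructor argument (defaulting to an empty string), with the original Chinese text now supplied from the config example earlier in this diff.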