diff --git a/README.md b/README.md
index 441e6d09..80901b7c 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,7 @@ Just like a compass guides us on our journey, OpenCompass will guide you through
## 🚀 What's New
+- **\[2023.11.06\]** We have supported several API-based models, including ChatGLM Pro@ZhipuAI, ABAB-Chat@MiniMax and Spark@XunFei. Welcome to [Models](https://opencompass.readthedocs.io/en/latest/user_guides/models.html) for more details. 🔥🔥🔥.
- **\[2023.10.24\]** We release a new benchmark for evaluating LLMs’ capabilities of having multi-turn dialogues. Welcome to [BotChat](https://github.com/open-compass/BotChat) for more details. 🔥🔥🔥.
- **\[2023.09.26\]** We update the leaderboard with [Qwen](https://github.com/QwenLM/Qwen), one of the best-performing open-source models currently available, welcome to our [homepage](https://opencompass.org.cn) for more details. 🔥🔥🔥.
- **\[2023.09.20\]** We update the leaderboard with [InternLM-20B](https://github.com/InternLM/InternLM), welcome to our [homepage](https://opencompass.org.cn) for more details. 🔥🔥🔥.
@@ -46,7 +47,6 @@ Just like a compass guides us on our journey, OpenCompass will guide you through
- **\[2023.09.08\]** We update the leaderboard with Baichuan-2/Tigerbot-2/Vicuna-v1.5, welcome to our [homepage](https://opencompass.org.cn) for more details.
- **\[2023.09.06\]** [**Baichuan2**](https://github.com/baichuan-inc/Baichuan2) team adpots OpenCompass to evaluate their models systematically. We deeply appreciate the community's dedication to transparency and reproducibility in LLM evaluation.
- **\[2023.09.02\]** We have supported the evaluation of [Qwen-VL](https://github.com/QwenLM/Qwen-VL) in OpenCompass.
-- **\[2023.08.25\]** [**TigerBot**](https://github.com/TigerResearch/TigerBot) team adpots OpenCompass to evaluate their models systematically. We deeply appreciate the community's dedication to transparency and reproducibility in LLM evaluation.
> [More](docs/en/notes/news.md)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 2d9d05da..3f308fa1 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -38,6 +38,7 @@
## 🚀 最新进展
+- **\[2023.11.06\]** 我们已经支持了多个基于 API 的模型,包括ChatGLM Pro@智谱清言、ABAB-Chat@MiniMax 和讯飞。欢迎查看 [模型](https://opencompass.readthedocs.io/en/latest/user_guides/models.html) 部分以获取更多详细信息。🔥🔥🔥。
- **\[2023.10.24\]** 我们发布了一个全新的评测集,BotChat,用于评估大语言模型的多轮对话能力,欢迎查看 [BotChat](https://github.com/open-compass/BotChat) 获取更多信息. 🔥🔥🔥.
- **\[2023.09.26\]** 我们在评测榜单上更新了[Qwen](https://github.com/QwenLM/Qwen), 这是目前表现最好的开源模型之一, 欢迎访问[官方网站](https://opencompass.org.cn)获取详情.🔥🔥🔥.
- **\[2023.09.20\]** 我们在评测榜单上更新了[InternLM-20B](https://github.com/InternLM/InternLM), 欢迎访问[官方网站](https://opencompass.org.cn)获取详情.🔥🔥🔥.
@@ -46,7 +47,6 @@
- **\[2023.09.08\]** 我们在评测榜单上更新了Baichuan-2/Tigerbot-2/Vicuna-v1.5, 欢迎访问[官方网站](https://opencompass.org.cn)获取详情。
- **\[2023.09.06\]** 欢迎 [**Baichuan2**](https://github.com/baichuan-inc/Baichuan2) 团队采用OpenCompass对模型进行系统评估。我们非常感谢社区在提升LLM评估的透明度和可复现性上所做的努力。
- **\[2023.09.02\]** 我们加入了[Qwen-VL](https://github.com/QwenLM/Qwen-VL)的评测支持。
-- **\[2023.08.25\]** 欢迎 [**TigerBot**](https://github.com/TigerResearch/TigerBot) 团队采用OpenCompass对模型进行系统评估。我们非常感谢社区在提升LLM评估的透明度和可复现性上所做的努力。
> [更多](docs/zh_cn/notes/news.md)
diff --git a/configs/eval_minimax.py b/configs/eval_minimax.py
new file mode 100644
index 00000000..6f88e6e3
--- /dev/null
+++ b/configs/eval_minimax.py
@@ -0,0 +1,37 @@
+from mmengine.config import read_base
+from opencompass.models.minimax import MiniMax
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners import LocalRunner
+from opencompass.runners.local_api import LocalAPIRunner
+from opencompass.tasks import OpenICLInferTask
+
+with read_base():
+ # from .datasets.collections.chat_medium import datasets
+ from .summarizers.medium import summarizer
+ from .datasets.ceval.ceval_gen import ceval_datasets
+
+datasets = [
+ *ceval_datasets,
+]
+
+models = [
+ dict(
+ abbr='minimax_abab5.5-chat',
+ type=MiniMax,
+ path='abab5.5-chat',
+        key='xxxxxxx',  # please provide your API key
+        group_id='xxxxxxxx',  # please provide your group_id
+ query_per_second=1,
+ max_out_len=2048,
+ max_seq_len=2048,
+ batch_size=8),
+]
+
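+# LocalAPIRunner is used for API-based models: `concurrent_users` caps how
+# many requests are in flight at once, on top of `query_per_second` above.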
+infer = dict(
+ partitioner=dict(type=NaivePartitioner),
+ runner=dict(
+ type=LocalAPIRunner,
+ max_num_workers=4,
+ concurrent_users=4,
+ task=dict(type=OpenICLInferTask)),
+)
\ No newline at end of file
diff --git a/configs/eval_xunfei.py b/configs/eval_xunfei.py
new file mode 100644
index 00000000..c852c684
--- /dev/null
+++ b/configs/eval_xunfei.py
@@ -0,0 +1,50 @@
+from mmengine.config import read_base
+from opencompass.models.xunfei_api import XunFei
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners import LocalRunner
+from opencompass.runners.local_api import LocalAPIRunner
+from opencompass.tasks import OpenICLInferTask
+
+with read_base():
+ # from .datasets.collections.chat_medium import datasets
+ from .summarizers.medium import summarizer
+ from .datasets.ceval.ceval_gen import ceval_datasets
+
+datasets = [
+ *ceval_datasets,
+]
+
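+# Two Spark endpoints are configured below: the v1.1 and v3.1 chat APIs;
+# the v3.1 entry additionally sets a `domain`.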
+models = [
+ dict(
+ abbr='Spark-v1-1',
+ type=XunFei,
+ appid="xxxx",
+ path='ws://spark-api.xf-yun.com/v1.1/chat',
+ api_secret = "xxxxxxx",
+ api_key = "xxxxxxx",
+ query_per_second=1,
+ max_out_len=2048,
+ max_seq_len=2048,
+ batch_size=8),
+ dict(
+ abbr='Spark-v3-1',
+ type=XunFei,
+ appid="xxxx",
+ domain='generalv3',
+ path='ws://spark-api.xf-yun.com/v3.1/chat',
+ api_secret = "xxxxxxxx",
+ api_key = "xxxxxxxxx",
+ query_per_second=1,
+ max_out_len=2048,
+ max_seq_len=2048,
+ batch_size=8),
+]
+
+infer = dict(
+ partitioner=dict(type=NaivePartitioner),
+ runner=dict(
+ type=LocalAPIRunner,
+ max_num_workers=2,
+ concurrent_users=2,
+ task=dict(type=OpenICLInferTask)),
+)
\ No newline at end of file
diff --git a/configs/eval_zhihu.py b/configs/eval_zhihu.py
new file mode 100644
index 00000000..2dc2dbc6
--- /dev/null
+++ b/configs/eval_zhihu.py
@@ -0,0 +1,36 @@
+from mmengine.config import read_base
+from opencompass.models import ZhiPuAI
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners import LocalRunner
+from opencompass.runners.local_api import LocalAPIRunner
+from opencompass.tasks import OpenICLInferTask
+
+with read_base():
+ # from .datasets.collections.chat_medium import datasets
+ from .summarizers.medium import summarizer
+ from .datasets.ceval.ceval_gen import ceval_datasets
+
+datasets = [
+ *ceval_datasets,
+]
+
+models = [
+ dict(
+ abbr='chatglm_pro',
+ type=ZhiPuAI,
+ path='chatglm_pro',
+ key='xxxxxxxxxxxx',
+ query_per_second=1,
+ max_out_len=2048,
+ max_seq_len=2048,
+ batch_size=8),
+]
+
+infer = dict(
+ partitioner=dict(type=NaivePartitioner),
+ runner=dict(
+ type=LocalAPIRunner,
+ max_num_workers=2,
+ concurrent_users=2,
+ task=dict(type=OpenICLInferTask)),
+)
\ No newline at end of file
diff --git a/docs/en/index.rst b/docs/en/index.rst
index 97e09e63..d87b5036 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -69,7 +69,7 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
.. _Tools:
.. toctree::
:maxdepth: 1
- :caption: tools
+ :caption: Tools
tools.md
diff --git a/docs/en/notes/news.md b/docs/en/notes/news.md
index 2da16bfb..a871c71e 100644
--- a/docs/en/notes/news.md
+++ b/docs/en/notes/news.md
@@ -1,5 +1,6 @@
# News
+- **\[2023.08.25\]** [**TigerBot**](https://github.com/TigerResearch/TigerBot) team adopts OpenCompass to evaluate their models systematically. We deeply appreciate the community's dedication to transparency and reproducibility in LLM evaluation.
- **\[2023.08.21\]** [**Lagent**](https://github.com/InternLM/lagent) has been released, which is a lightweight framework for building LLM-based agents. We are working with Lagent team to support the evaluation of general tool-use capability, stay tuned!
- **\[2023.08.18\]** We have supported evaluation for **multi-modality learning**, include **MMBench, SEED-Bench, COCO-Caption, Flickr-30K, OCR-VQA, ScienceQA** and so on. Leaderboard is on the road. Feel free to try multi-modality evaluation with OpenCompass !
- **\[2023.08.18\]** [Dataset card](https://opencompass.org.cn/dataset-detail/MMLU) is now online. Welcome new evaluation benchmark OpenCompass !
diff --git a/docs/en/user_guides/models.md b/docs/en/user_guides/models.md
index a4a35e0f..eb96ae76 100644
--- a/docs/en/user_guides/models.md
+++ b/docs/en/user_guides/models.md
@@ -70,7 +70,9 @@ model = HuggingFaceCausalLM(
Currently, OpenCompass supports API-based model inference for the following:
- OpenAI (`opencompass.models.OpenAI`)
-- More coming soon
+- ChatGLM Pro from ZhipuAI (`opencompass.models.ZhiPuAI`)
+- ABAB-Chat from MiniMax (`opencompass.models.MiniMax`)
+- Spark from XunFei (`opencompass.models.XunFei`)
Let's take the OpenAI configuration file as an example to see how API-based models are used in the
configuration file.
@@ -94,6 +96,15 @@ models = [
]
```
+We also provide example configurations for API-based models. Please refer to:
+
+```bash
+configs
+├── eval_zhihu.py
+├── eval_xunfei.py
+└── eval_minimax.py
+```
+
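+For example, a ZhiPuAI model entry (taken from `configs/eval_zhihu.py`) looks
+like this:
+
+```python
+from opencompass.models import ZhiPuAI
+
+models = [
+    dict(
+        abbr='chatglm_pro',
+        type=ZhiPuAI,
+        path='chatglm_pro',
+        key='xxxxxxxxxxxx',  # your ZhiPu API key
+        query_per_second=1,
+        max_out_len=2048,
+        max_seq_len=2048,
+        batch_size=8),
+]
+```
+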
## Custom Models
If the above methods do not support your model evaluation requirements, you can refer to
diff --git a/docs/zh_cn/notes/news.md b/docs/zh_cn/notes/news.md
index 23488b8f..39b5011b 100644
--- a/docs/zh_cn/notes/news.md
+++ b/docs/zh_cn/notes/news.md
@@ -1,5 +1,6 @@
# 新闻
+- **\[2023.08.25\]** 欢迎 [**TigerBot**](https://github.com/TigerResearch/TigerBot) 团队采用OpenCompass对模型进行系统评估。我们非常感谢社区在提升LLM评估的透明度和可复现性上所做的努力。
- **\[2023.08.21\]** [**Lagent**](https://github.com/InternLM/lagent) 正式发布,它是一个轻量级、开源的基于大语言模型的智能体(agent)框架。我们正与Lagent团队紧密合作,推进支持基于Lagent的大模型工具能力评测 !
- **\[2023.08.18\]** OpenCompass现已支持**多模态评测**,支持10+多模态评测数据集,包括 **MMBench, SEED-Bench, COCO-Caption, Flickr-30K, OCR-VQA, ScienceQA** 等. 多模态评测榜单即将上线,敬请期待!
- **\[2023.08.18\]** [数据集页面](https://opencompass.org.cn/dataset-detail/MMLU) 现已在OpenCompass官网上线,欢迎更多社区评测数据集加入OpenCompass !
diff --git a/docs/zh_cn/user_guides/models.md b/docs/zh_cn/user_guides/models.md
index 7f51b63b..a587bf5d 100644
--- a/docs/zh_cn/user_guides/models.md
+++ b/docs/zh_cn/user_guides/models.md
@@ -63,7 +63,9 @@ model = HuggingFaceCausalLM(
OpenCompass 目前支持以下基于 API 的模型推理:
- OpenAI(`opencompass.models.OpenAI`)
-- Coming soon
+- ChatGLM@智谱清言 (`opencompass.models.ZhiPuAI`)
+- ABAB-Chat@MiniMax (`opencompass.models.MiniMax`)
+- XunFei@科大讯飞 (`opencompass.models.XunFei`)
以下,我们以 OpenAI 的配置文件为例,模型如何在配置文件中使用基于 API 的模型。
@@ -86,6 +88,15 @@ models = [
]
```
+我们也提供了API模型的评测示例,请参考
+
+```bash
+configs
+├── eval_zhihu.py
+├── eval_xunfei.py
+└── eval_minimax.py
+```
+
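+例如,一个 ZhiPuAI 模型的配置(摘自 `configs/eval_zhihu.py`)如下:
+
+```python
+from opencompass.models import ZhiPuAI
+
+models = [
+    dict(
+        abbr='chatglm_pro',
+        type=ZhiPuAI,
+        path='chatglm_pro',
+        key='xxxxxxxxxxxx',  # 你的智谱 API key
+        query_per_second=1,
+        max_out_len=2048,
+        max_seq_len=2048,
+        batch_size=8),
+]
+```
+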
## 自定义模型
如果以上方式无法支持你的模型评测需求,请参考 [支持新模型](../advanced_guides/new_model.md) 在 OpenCompass 中增添新的模型支持。
diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index 1fa6ed56..bf653c14 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -6,6 +6,7 @@ from .huggingface import HuggingFace # noqa: F401, F403
from .huggingface import HuggingFaceCausalLM # noqa: F401, F403
from .intern_model import InternLM # noqa: F401, F403
from .llama2 import Llama2, Llama2Chat # noqa: F401, F403
+from .minimax_api import MiniMax # noqa: F401
from .openai_api import OpenAI # noqa: F401
from .xunfei_api import XunFei # noqa: F401
-from .zhipuai import ZhiPuAI # noqa: F401
+from .zhipuai_api import ZhiPuAI # noqa: F401
diff --git a/opencompass/models/minimax_api.py b/opencompass/models/minimax_api.py
new file mode 100644
index 00000000..e6bc7bb3
--- /dev/null
+++ b/opencompass/models/minimax_api.py
@@ -0,0 +1,209 @@
+import sys
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+import requests
+
+from opencompass.registry import MODELS
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+@MODELS.register_module(name=['MiniMax'])
+class MiniMax(BaseAPIModel):
+ """Model wrapper around MiniMax.
+
+ Documentation: https://api.minimax.chat/document/guides/chat-pro
+
+ Args:
+ path (str): The name of MiniMax model.
+ e.g. `abab5.5-chat`
+ model_type (str): The type of the model
+ e.g. `chat`
+        group_id (str): The group ID (similar to an organization ID).
+ key (str): Authorization key.
+ query_per_second (int): The maximum queries allowed per second
+ between two consecutive calls of the API. Defaults to 1.
+ max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, used to inject or wrap any meta
+            instructions.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+ """
+
+ def __init__(
+ self,
+ path: str,
+ key: str,
+ group_id: str,
+ model_type: str = 'chat',
+ url:
+ str = 'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId=',
+ query_per_second: int = 2,
+ max_seq_len: int = 2048,
+ meta_template: Optional[Dict] = None,
+ retry: int = 2,
+ ):
+ super().__init__(path=path,
+ max_seq_len=max_seq_len,
+ query_per_second=query_per_second,
+ meta_template=meta_template,
+ retry=retry)
+ self.headers = {
+ 'Authorization': f'Bearer {key}',
+ 'Content-Type': 'application/json',
+ }
+ self.type = model_type
+ self.url = url + group_id
+ self.model = path
+
+ def generate(
+ self,
+        inputs: List[PromptType],
+ max_out_len: int = 512,
+ ) -> List[str]:
+ """Generate results given a list of inputs.
+
+ Args:
+ inputs (List[str or PromptList]): A list of strings or PromptDicts.
+ The PromptDict should be organized in OpenCompass'
+ API format.
+ max_out_len (int): The maximum length of the output.
+
+ Returns:
+ List[str]: A list of generated strings.
+ """
+ with ThreadPoolExecutor() as executor:
+ results = list(
+ executor.map(self._generate, inputs,
+ [max_out_len] * len(inputs)))
+ self.flush()
+ return results
+
+ def flush(self):
+ """Flush stdout and stderr when concurrent resources exists.
+
+ When use multiproessing with standard io rediected to files, need to
+ flush internal information for examination or log loss when system
+ breaks.
+ """
+ if hasattr(self, 'tokens'):
+ sys.stdout.flush()
+ sys.stderr.flush()
+
+ def acquire(self):
+ """Acquire concurrent resources if exists.
+
+ This behavior will fall back to wait with query_per_second if there are
+ no concurrent resources.
+ """
+ if hasattr(self, 'tokens'):
+ self.tokens.acquire()
+ else:
+ self.wait()
+
+ def release(self):
+ """Release concurrent resources if acquired.
+
+        Does nothing if there are no concurrent resources.
+ """
+ if hasattr(self, 'tokens'):
+ self.tokens.release()
+
+ def _generate(
+ self,
+        input: PromptType,
+ max_out_len: int = 512,
+ ) -> str:
+ """Generate results given an input.
+
+ Args:
+            input (str or PromptList): A string or PromptDict.
+ The PromptDict should be organized in OpenCompass'
+ API format.
+ max_out_len (int): The maximum length of the output.
+
+ Returns:
+ str: The generated string.
+ """
+ assert isinstance(input, (str, PromptList))
+
+ if isinstance(input, str):
+ messages = [{
+ 'sender_type': 'USER',
+ 'sender_name': 'OpenCompass',
+ 'text': input
+ }]
+ else:
+ messages = []
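+            # Map OpenCompass roles onto MiniMax sender types.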
+ for item in input:
+ msg = {'text': item['prompt']}
+ if item['role'] == 'HUMAN':
+ msg['sender_type'] = 'USER'
+ msg['sender_name'] = 'OpenCompass'
+ elif item['role'] == 'BOT':
+ msg['sender_type'] = 'BOT'
+ msg['sender_name'] = 'MM智能助理'
+
+ messages.append(msg)
+
+ data = {
+ 'bot_setting': [{
+ 'bot_name':
+ 'MM智能助理',
+ 'content':
+ 'MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。' +
+ 'MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。'
+ }],
+ 'reply_constraints': {
+ 'sender_type': 'BOT',
+ 'sender_name': 'MM智能助理'
+ },
+ 'model':
+ self.model,
+ 'messages':
+ messages
+ }
+
+ max_num_retries = 0
+ while max_num_retries < self.retry:
+ self.acquire()
+ raw_response = requests.request('POST',
+ url=self.url,
+ headers=self.headers,
+ json=data)
+            try:
+                response = raw_response.json()
+            except ValueError:
+                # an invalid or empty body is treated as a connection error
+                response = None
+ self.release()
+
+ if response is None:
+ print('Connection error, reconnect.')
+                # On connection errors, back off before retrying: rapid
+                # retries only keep an unstable network unstable.
+ self.wait()
+ continue
+ if raw_response.status_code == 200:
+ msg = response['reply']
+ return msg
+            # Business errors (sensitive content, prompt overlength, network
+            # error, illegal prompt, ...) are reported through status codes
+            # in the JSON body (`base_resp.status_code`), not the HTTP code.
+            err_code = response.get('base_resp', {}).get('status_code')
+            if err_code in (1000, 1001, 1002, 1004, 1008, 1013, 1027, 1039,
+                            2013):
+                print(raw_response.text)
+                return ''
+            print(response)
+            max_num_retries += 1
+
+        raise RuntimeError(raw_response.text)
diff --git a/opencompass/models/zhipuai.py b/opencompass/models/zhipuai_api.py
similarity index 100%
rename from opencompass/models/zhipuai.py
rename to opencompass/models/zhipuai_api.py
diff --git a/opencompass/openicl/icl_retriever/icl_topk_retriever.py b/opencompass/openicl/icl_retriever/icl_topk_retriever.py
index 15743d71..c9ac8f81 100644
--- a/opencompass/openicl/icl_retriever/icl_topk_retriever.py
+++ b/opencompass/openicl/icl_retriever/icl_topk_retriever.py
@@ -4,7 +4,6 @@ import copy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
-import faiss
import numpy as np
import torch
import tqdm
@@ -84,6 +83,8 @@ class TopkRetriever(BaseRetriever):
self.index = self.create_index()
def create_index(self):
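+        # faiss is imported lazily so that it remains an optional dependency
+        # (it moved from requirements/runtime.txt to requirements/extra.txt).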
+ import faiss
+
self.select_datalist = self.dataset_reader.generate_input_field_corpus(
self.index_ds)
encode_datalist = DatasetEncoder(self.select_datalist,
diff --git a/requirements/api.txt b/requirements/api.txt
new file mode 100644
index 00000000..33ef78a6
--- /dev/null
+++ b/requirements/api.txt
@@ -0,0 +1,2 @@
+websocket-client
+zhipuai
diff --git a/requirements/extra.txt b/requirements/extra.txt
new file mode 100644
index 00000000..f5f709ce
--- /dev/null
+++ b/requirements/extra.txt
@@ -0,0 +1 @@
+faiss_gpu==1.7.2
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 1e0054d1..3ea57737 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -7,7 +7,6 @@ cpm_kernels
datasets>=2.12.0
evaluate>=0.3.0
fairscale
-faiss_gpu==1.7.2
fuzzywuzzy
jieba
ltp