mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Support the reasoning from BaiLing LLM (#1541)
* [Feature] Support the reasoning from BaiLing LLM: add access to the BaiLing LLM and retrieve its reasoning output.
* Add the API example: an example config for evaluating the BaiLing API.
* Revise the generation arguments: based on current experiments, update some generation arguments for better reasoning.
* [fix] Set the batch size.
* Retry under server-side flow control.
* Add the dependent package to requirements.txt: add the retrying package to clean up the pre-commit check.
* Correct the file names and copy the files under configs to opencompass.
* Fix the lint issue.

---------

Co-authored-by: christopher.dy <christopher.dy@antgroup.com>
This commit is contained in:
parent 80cda1980e
commit 3f833186dc
configs/api_examples/eval_api_bailing.py  (new file, 38 lines)
@@ -0,0 +1,38 @@
from mmengine.config import read_base

from opencompass.models import BailingAPI
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from opencompass.configs.datasets.ceval.ceval_gen import ceval_datasets
    from opencompass.configs.summarizers.medium import summarizer

datasets = [
    *ceval_datasets,
]

models = [
    dict(
        path="Bailing-Lite-0830",
        token="xxxxxx",  # set your key here or in environment variable BAILING_API_KEY
        url="https://bailingchat.alipay.com/chat/completions",
        type=BailingAPI,
        generation_kwargs={},
        query_per_second=1,
        max_seq_len=4096,
    ),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalAPIRunner,
        max_num_workers=2,
        concurrent_users=2,
        task=dict(type=OpenICLInferTask),
    ),
)

work_dir = "outputs/api_bailing/"
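A quick way to exercise this example config (a usage sketch, not part of the commit; it assumes the standard OpenCompass launcher at the repository root and a valid key exported as BAILING_API_KEY — the key value below is a placeholder):

    export BAILING_API_KEY=your_key_here   # placeholder
    python run.py configs/api_examples/eval_api_bailing.py

Results are written under the configured work_dir, i.e. outputs/api_bailing/.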
configs/models/bailing_api/bailing-lite-0830.py  (new file, 31 lines)
@@ -0,0 +1,31 @@
from opencompass.models import BailingAPI

api_meta_template = dict(
    round=[
        dict(role="HUMAN", api_role="HUMAN"),
        dict(role="BOT", api_role="BOT", generate=False),
    ],
    reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
)

models = [
    dict(
        path="Bailing-Lite-0830",
        token="",  # set your key here or in environment variable BAILING_API_KEY
        url="https://bailingchat.alipay.com/chat/completions",
        type=BailingAPI,
        meta_template=api_meta_template,
        query_per_second=1,
        max_seq_len=4096,
        batch_size=1,
        generation_kwargs={
            "temperature": 0.4,
            "top_p": 1.0,
            "top_k": -1,
            "n": 1,
            "logprobs": 1,
            "use_beam_search": False,
        },
    ),
]
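This per-model config (like the bailing-pro-0920 variant and the copies under opencompass/configs/ that follow) can also be combined with a dataset directly from the command line instead of a dedicated eval config. A sketch, assuming the usual OpenCompass CLI conventions for matching config file stems:

    python run.py --models bailing-lite-0830 --datasets ceval_gen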
configs/models/bailing_api/bailing-pro-0920.py  (new file, 31 lines)
@@ -0,0 +1,31 @@
from opencompass.models import BailingAPI

api_meta_template = dict(
    round=[
        dict(role="HUMAN", api_role="HUMAN"),
        dict(role="BOT", api_role="BOT", generate=False),
    ],
    reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
)

models = [
    dict(
        path="Bailing-Pro-0920",
        token="",  # set your key here or in environment variable BAILING_API_KEY
        url="https://bailingchat.alipay.com/chat/completions",
        type=BailingAPI,
        meta_template=api_meta_template,
        query_per_second=1,
        max_seq_len=4096,
        batch_size=1,
        generation_kwargs={
            "temperature": 0.4,
            "top_p": 1.0,
            "top_k": -1,
            "n": 1,
            "logprobs": 1,
            "use_beam_search": False,
        },
    ),
]
opencompass/configs/models/bailing_api/bailing-lite-0830.py  (new file, 31 lines)
@@ -0,0 +1,31 @@
from opencompass.models import BailingAPI

api_meta_template = dict(
    round=[
        dict(role="HUMAN", api_role="HUMAN"),
        dict(role="BOT", api_role="BOT", generate=False),
    ],
    reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
)

models = [
    dict(
        path="Bailing-Lite-0830",
        token="",  # set your key here or in environment variable BAILING_API_KEY
        url="https://bailingchat.alipay.com/chat/completions",
        type=BailingAPI,
        meta_template=api_meta_template,
        query_per_second=1,
        max_seq_len=4096,
        batch_size=1,
        generation_kwargs={
            "temperature": 0.4,
            "top_p": 1.0,
            "top_k": -1,
            "n": 1,
            "logprobs": 1,
            "use_beam_search": False,
        },
    ),
]
opencompass/configs/models/bailing_api/bailing-pro-0920.py  (new file, 31 lines)
@@ -0,0 +1,31 @@
from opencompass.models import BailingAPI

api_meta_template = dict(
    round=[
        dict(role="HUMAN", api_role="HUMAN"),
        dict(role="BOT", api_role="BOT", generate=False),
    ],
    reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
)

models = [
    dict(
        path="Bailing-Pro-0920",
        token="",  # set your key here or in environment variable BAILING_API_KEY
        url="https://bailingchat.alipay.com/chat/completions",
        type=BailingAPI,
        meta_template=api_meta_template,
        query_per_second=1,
        max_seq_len=4096,
        batch_size=1,
        generation_kwargs={
            "temperature": 0.4,
            "top_p": 1.0,
            "top_k": -1,
            "n": 1,
            "logprobs": 1,
            "use_beam_search": False,
        },
    ),
]
opencompass/models/__init__.py
@@ -3,6 +3,7 @@ from .ai360_api import AI360GPT  # noqa: F401
 from .alaya import AlayaLM  # noqa: F401
 from .baichuan_api import BaiChuan  # noqa: F401
 from .baidu_api import ERNIEBot  # noqa: F401
+from .bailing_api_oc import BailingAPI  # noqa: F401
 from .base import BaseModel, LMTemplateParser  # noqa: F401
 from .base_api import APITemplateParser, BaseAPIModel  # noqa: F401
 from .bytedance_api import ByteDance  # noqa: F401
@@ -41,8 +42,7 @@ from .sensetime_api import SenseTime  # noqa: F401
 from .stepfun_api import StepFun  # noqa: F401
 from .turbomind import TurboMindModel  # noqa: F401
 from .turbomind_tis import TurboMindTisModel  # noqa: F401
-from .turbomind_with_tf_above_v4_33 import \
-    TurboMindModelwithChatTemplate  # noqa: F401
+from .turbomind_with_tf_above_v4_33 import TurboMindModelwithChatTemplate  # noqa: F401
 from .unigpt_api import UniGPT  # noqa: F401
 from .vllm import VLLM  # noqa: F401
 from .vllm_with_tf_above_v4_33 import VLLMwithChatTemplate  # noqa: F401
opencompass/models/bailing_api_oc.py  (new file, 215 lines)
@@ -0,0 +1,215 @@
import concurrent
import concurrent.futures
import os
import socket
import traceback
from typing import Dict, List, Optional, Union

import requests
from requests.adapters import HTTPAdapter
from retrying import retry
from urllib3.connection import HTTPConnection

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


class HTTPAdapterWithSocketOptions(HTTPAdapter):

    def __init__(self, *args, **kwargs):
        # Enable TCP keepalive so long-lived connections are not dropped while idle.
        self._socket_options = HTTPConnection.default_socket_options + [
            (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
            (socket.SOL_TCP, socket.TCP_KEEPIDLE, 75),
            (socket.SOL_TCP, socket.TCP_KEEPINTVL, 30),
            (socket.SOL_TCP, socket.TCP_KEEPCNT, 120),
        ]
        super(HTTPAdapterWithSocketOptions, self).__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        if self._socket_options is not None:
            kwargs["socket_options"] = self._socket_options
        super(HTTPAdapterWithSocketOptions, self).init_poolmanager(*args, **kwargs)


class BailingAPI(BaseAPIModel):
    """Model wrapper around Bailing Service.

    Args:
        output_key (str): key for prediction
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        generation_kwargs (Dict): other generation parameters passed through
            to the API request.
        retry (int): Number of retries if the API call fails. Defaults to 3.
    """

    def __init__(
        self,
        path: str,
        token: str,
        url: str,
        meta_template: Optional[Dict] = None,
        query_per_second: int = 1,
        retry: int = 3,
        generation_kwargs: Dict = {},
        max_seq_len=4096,
    ):
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            query_per_second=query_per_second,
            meta_template=meta_template,
            retry=retry,
            generation_kwargs=generation_kwargs,
        )

        self.logger.info(f"Bailing API Model Init path: {path} url={url}")
        if not token:
            token = os.environ.get("BAILING_API_KEY")
        if token:
            self._headers = {"Authorization": f"Bearer {token}"}
        else:
            raise RuntimeError("There is no valid token.")
        self._headers["Content-Type"] = "application/json"
        self._url = url if url else "https://bailingchat.alipay.com/chat/completions"
        self._model = path
        self._sessions = []
        # Number of parallel HTTP sessions, taken from BAILING_API_PARALLEL_NUM (defaults to 1).
        self._num = (
            int(os.environ.get("BAILING_API_PARALLEL_NUM"))
            if os.environ.get("BAILING_API_PARALLEL_NUM")
            else 1
        )
        try:
            for _ in range(self._num):
                adapter = HTTPAdapterWithSocketOptions()
                sess = requests.Session()
                sess.mount("http://", adapter)
                sess.mount("https://", adapter)
                self._sessions.append(sess)
        except Exception as e:
            self.logger.error(f"Failed to set up the session. {e}")
            raise e

    def generate(
        self,
        inputs: Union[List[str], PromptList],
        max_out_len: int = 4096,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (Union[List[str], PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self._num,
        ) as executor:
            future_to_m = {
                executor.submit(
                    self._generate,
                    self._sessions[i % self._num],
                    input,
                    max_out_len,
                ): i
                for i, input in enumerate(inputs)
            }
            results = []
            for future in concurrent.futures.as_completed(future_to_m):
                m = future_to_m[future]
                resp = future.result()
                if resp and resp.status_code == 200:
                    try:
                        result = resp.json()
                    except Exception:
                        results.append("")
                    else:
                        if (
                            result.get("choices")
                            and result["choices"][0].get("message")
                            and result["choices"][0]["message"].get("content")
                        ):
                            results.append(result["choices"][0]["message"]["content"])
                        else:
                            results.append("")
        self.flush()
        return results

    def _generate(
        self,
        sess,
        input: Union[str, PromptList],
        max_out_len: int,
    ) -> str:
        """Generate results given an input.

        Args:
            input (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        if isinstance(input, str):
            messages = [{"role": "user", "content": input}]
        else:
            messages = []
            for item in input:
                content = item["prompt"]
                if not content:
                    continue
                message = {"content": content}
                if item["role"] == "HUMAN":
                    message["role"] = "user"
                elif item["role"] == "BOT":
                    message["role"] = "assistant"
                elif item["role"] == "SYSTEM":
                    message["role"] = "system"
                else:
                    message["role"] = item["role"]
                messages.append(message)
        request = {
            "model": self._model,
            "messages": messages,
            "max_seq_len": max(
                max_out_len if max_out_len else 4096,
                self.max_seq_len if self.max_seq_len else 4096,
            ),
        }
        request.update(self.generation_kwargs)
        try:
            retry_num = 0
            while retry_num < self.retry:
                response = self._infer_result(request, sess)
                if response.status_code == 200:
                    break  # success
                elif response.status_code == 426:
                    retry_num += 1  # server-side flow control; retry
                else:
                    raise ValueError(f"Status code = {response.status_code}")
            else:
                raise ValueError(
                    f"Exceed the maximal retry times. Last status code = {response.status_code}"
                )
        except Exception as e:
            self.logger.error(
                f"Failed to infer request={request}; model_name={self.path}; "
                f"error={e}, stack:{traceback.format_exc()}"
            )
            raise e
        return response

    @retry(stop_max_attempt_number=3, wait_fixed=16000)  # ms
    def _infer_result(self, request, sess):
        response = sess.request(
            "POST",
            self._url,
            json=request,
            headers=self._headers,
            timeout=500,
        )
        return response
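For orientation, a minimal sketch of driving the wrapper directly (not part of the commit; it assumes BAILING_API_KEY is exported and the endpoint is reachable, and the model name and prompt are only examples):

    from opencompass.models import BailingAPI

    model = BailingAPI(
        path="Bailing-Lite-0830",
        token="",  # an empty token falls back to the BAILING_API_KEY environment variable
        url="https://bailingchat.alipay.com/chat/completions",
        generation_kwargs={"temperature": 0.4},
    )
    # generate() fans the prompts out over the pooled sessions and collects the responses.
    print(model.generate(["Briefly introduce yourself."], max_out_len=256))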
@@ -23,6 +23,7 @@ python-Levenshtein
 rank_bm25==0.2.2
 rapidfuzz
 requests>=2.31.0
+retrying
 rich
 rouge
 -e git+https://github.com/Isaac-JL-Chen/rouge_chinese.git@master#egg=rouge_chinese