[Feature] Support the reasoning from BaiLing LLM (#1541)

* [Feature] Support the reasoning from BaiLing LLM

This commit adds access to the BaiLing LLM API and retrieves its reasoning output.

* Add the api example

An example config for evaluating the BaiLing API.

* Revise the generation arguments

Based on current experiments, we update some generation arguments for better reasoning.

* [fix] set the batch size

* Retry under server-side flow control

* Add dependent package to requirements.txt

Add the dependent package retrying so the pre-commit check passes.

* Correct the file names and copy the files

Correct the file names.
Copy the files under configs to opencompass.

* fix the lint issue

---------

Co-authored-by: christopher.dy <christopher.dy@antgroup.com>
Yi Ding 2024-09-26 16:49:52 +08:00 committed by GitHub
parent 80cda1980e
commit 3f833186dc
8 changed files with 380 additions and 2 deletions

View File

@@ -0,0 +1,38 @@
from mmengine.config import read_base

from opencompass.models import BailingAPI
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from opencompass.configs.datasets.ceval.ceval_gen import ceval_datasets
    from opencompass.configs.summarizers.medium import summarizer

datasets = [
    *ceval_datasets,
]

models = [
    dict(
        path="Bailing-Lite-0830",
        token="xxxxxx",  # set your key here or in environment variable BAILING_API_KEY
        url="https://bailingchat.alipay.com/chat/completions",
        type=BailingAPI,
        generation_kwargs={},
        query_per_second=1,
        max_seq_len=4096,
    ),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalAPIRunner,
        max_num_workers=2,
        concurrent_users=2,
        task=dict(type=OpenICLInferTask),
    ),
)

work_dir = "outputs/api_bailing/"
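To try this example, export your key as BAILING_API_KEY (or set token directly in the dict above) and launch it through OpenCompass' standard entry point, e.g. python run.py <path-to-this-config> --debug; the exact file path of this config is not shown in this view. Results are written under the work_dir above (outputs/api_bailing/).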

View File

@@ -0,0 +1,31 @@
from opencompass.models import BailingAPI
api_meta_template = dict(
round=[
dict(role="HUMAN", api_role="HUMAN"),
dict(role="BOT", api_role="BOT", generate=False),
],
reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
)
models = [
dict(
path="Bailing-Lite-0830",
token="", # set your key here or in environment variable BAILING_API_KEY
url="https://bailingchat.alipay.com/chat/completions",
type=BailingAPI,
meta_template=api_meta_template,
query_per_second=1,
max_seq_len=4096,
batch_size=1,
generation_kwargs={
"temperature": 0.4,
"top_p": 1.0,
"top_k": -1,
"n": 1,
"logprobs": 1,
"use_beam_search": False,
},
),
]
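token is deliberately left empty in these model configs. As bailing_api_oc.py further below shows, BailingAPI falls back to the BAILING_API_KEY environment variable when no token is passed, and it also reads BAILING_API_PARALLEL_NUM to size its pool of HTTP sessions. A minimal sketch of supplying both from Python instead of the shell (values are illustrative):

import os

os.environ["BAILING_API_KEY"] = "<your key>"  # used when token="" in the config
os.environ["BAILING_API_PARALLEL_NUM"] = "2"  # optional; number of concurrent sessions (default 1)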

View File

@@ -0,0 +1,31 @@
from opencompass.models import BailingAPI
api_meta_template = dict(
round=[
dict(role="HUMAN", api_role="HUMAN"),
dict(role="BOT", api_role="BOT", generate=False),
],
reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
)
models = [
dict(
path="Bailing-Pro-0920",
token="", # set your key here or in environment variable BAILING_API_KEY
url="https://bailingchat.alipay.com/chat/completions",
type=BailingAPI,
meta_template=api_meta_template,
query_per_second=1,
max_seq_len=4096,
batch_size=1,
generation_kwargs={
"temperature": 0.4,
"top_p": 1.0,
"top_k": -1,
"n": 1,
"logprobs": 1,
"use_beam_search": False,
},
),
]

View File

@@ -0,0 +1,31 @@
from opencompass.models import BailingAPI
api_meta_template = dict(
round=[
dict(role="HUMAN", api_role="HUMAN"),
dict(role="BOT", api_role="BOT", generate=False),
],
reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
)
models = [
dict(
path="Bailing-Lite-0830",
token="", # set your key here or in environment variable BAILING_API_KEY
url="https://bailingchat.alipay.com/chat/completions",
type=BailingAPI,
meta_template=api_meta_template,
query_per_second=1,
max_seq_len=4096,
batch_size=1,
generation_kwargs={
"temperature": 0.4,
"top_p": 1.0,
"top_k": -1,
"n": 1,
"logprobs": 1,
"use_beam_search": False,
},
),
]

View File

@@ -0,0 +1,31 @@
from opencompass.models import BailingAPI
api_meta_template = dict(
round=[
dict(role="HUMAN", api_role="HUMAN"),
dict(role="BOT", api_role="BOT", generate=False),
],
reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
)
models = [
dict(
path="Bailing-Pro-0920",
token="", # set your key here or in environment variable BAILING_API_KEY
url="https://bailingchat.alipay.com/chat/completions",
type=BailingAPI,
meta_template=api_meta_template,
query_per_second=1,
max_seq_len=4096,
batch_size=1,
generation_kwargs={
"temperature": 0.4,
"top_p": 1.0,
"top_k": -1,
"n": 1,
"logprobs": 1,
"use_beam_search": False,
},
),
]

View File

@@ -3,6 +3,7 @@ from .ai360_api import AI360GPT  # noqa: F401
from .alaya import AlayaLM  # noqa: F401
from .baichuan_api import BaiChuan  # noqa: F401
from .baidu_api import ERNIEBot  # noqa: F401
from .bailing_api_oc import BailingAPI  # noqa: F401
from .base import BaseModel, LMTemplateParser  # noqa: F401
from .base_api import APITemplateParser, BaseAPIModel  # noqa: F401
from .bytedance_api import ByteDance  # noqa: F401
@@ -41,8 +42,7 @@ from .sensetime_api import SenseTime  # noqa: F401
from .stepfun_api import StepFun  # noqa: F401
from .turbomind import TurboMindModel  # noqa: F401
from .turbomind_tis import TurboMindTisModel  # noqa: F401
from .turbomind_with_tf_above_v4_33 import TurboMindModelwithChatTemplate  # noqa: F401
from .unigpt_api import UniGPT  # noqa: F401
from .vllm import VLLM  # noqa: F401
from .vllm_with_tf_above_v4_33 import VLLMwithChatTemplate  # noqa: F401

View File

@@ -0,0 +1,215 @@
import concurrent
import concurrent.futures
import os
import socket
import traceback
from typing import Dict, List, Optional, Union
import requests
from requests.adapters import HTTPAdapter
from retrying import retry
from urllib3.connection import HTTPConnection
from opencompass.utils.prompt import PromptList
from .base_api import BaseAPIModel
PromptType = Union[PromptList, str]
class HTTPAdapterWithSocketOptions(HTTPAdapter):
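    # Enable TCP keepalive (with the timings below) on pooled connections so
    # long-running evaluation runs are not dropped by idle-connection timeouts.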
def __init__(self, *args, **kwargs):
self._socket_options = HTTPConnection.default_socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
(socket.SOL_TCP, socket.TCP_KEEPIDLE, 75),
(socket.SOL_TCP, socket.TCP_KEEPINTVL, 30),
(socket.SOL_TCP, socket.TCP_KEEPCNT, 120),
]
super(HTTPAdapterWithSocketOptions, self).__init__(*args, **kwargs)
def init_poolmanager(self, *args, **kwargs):
if self._socket_options is not None:
kwargs["socket_options"] = self._socket_options
super(HTTPAdapterWithSocketOptions, self).init_poolmanager(*args, **kwargs)
class BailingAPI(BaseAPIModel):
"""Model wrapper around Bailing Service.
Args:
        path (str): The model name, e.g. ``Bailing-Lite-0830``.
        token (str): API token. If empty, the ``BAILING_API_KEY``
            environment variable is used instead.
        url (str): Endpoint of the Bailing chat-completions service.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        generation_kwargs: Other generation parameters.
        retry (int): Number of retries if the API call fails. Defaults to 3.
"""
def __init__(
self,
path: str,
token: str,
url: str,
meta_template: Optional[Dict] = None,
query_per_second: int = 1,
retry: int = 3,
generation_kwargs: Dict = {},
max_seq_len=4096,
):
super().__init__(
path=path,
max_seq_len=max_seq_len,
query_per_second=query_per_second,
meta_template=meta_template,
retry=retry,
generation_kwargs=generation_kwargs,
)
self.logger.info(f"Bailing API Model Init path: {path} url={url}")
if not token:
token = os.environ.get("BAILING_API_KEY")
if token:
self._headers = {"Authorization": f"Bearer {token}"}
else:
            raise RuntimeError("There is no valid token.")
self._headers["Content-Type"] = "application/json"
self._url = url if url else "https://bailingchat.alipay.com/chat/completions"
self._model = path
self._sessions = []
self._num = (
int(os.environ.get("BAILING_API_PARALLEL_NUM"))
if os.environ.get("BAILING_API_PARALLEL_NUM")
else 1
)
try:
for _ in range(self._num):
adapter = HTTPAdapterWithSocketOptions()
sess = requests.Session()
sess.mount("http://", adapter)
sess.mount("https://", adapter)
self._sessions.append(sess)
except Exception as e:
            self.logger.error(f"Failed to set up the session: {e}")
raise e
def generate(
self,
inputs: Union[List[str], PromptList],
max_out_len: int = 4096,
) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (Union[List[str], PromptList]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass' API format.
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
with concurrent.futures.ThreadPoolExecutor(
max_workers=self._num,
) as executor:
future_to_m = {
executor.submit(
self._generate,
self._sessions[i % self._num],
input,
max_out_len,
): i
for i, input in enumerate(inputs)
}
            results = [""] * len(inputs)
            for future in concurrent.futures.as_completed(future_to_m):
                # Index futures back to their position so the outputs
                # stay aligned with the order of ``inputs``.
                m = future_to_m[future]
                resp = future.result()
                if resp and resp.status_code == 200:
                    try:
                        result = resp.json()
                    except Exception:
                        continue  # unparsable body; keep the default ""
                    if (
                        result.get("choices")
                        and result["choices"][0].get("message")
                        and result["choices"][0]["message"].get("content")
                    ):
                        results[m] = result["choices"][0]["message"]["content"]
        self.flush()
        return results
def _generate(
self,
sess,
input: Union[str, PromptList],
max_out_len: int,
) -> str:
"""Generate results given an input.
Args:
            input (str or PromptList): A string or PromptDict.
The PromptDict should be organized in OpenCompass' API format.
max_out_len (int): The maximum length of the output.
Returns:
str: The generated string.
"""
if isinstance(input, str):
messages = [{"role": "user", "content": input}]
else:
messages = []
for item in input:
content = item["prompt"]
if not content:
continue
message = {"content": content}
if item["role"] == "HUMAN":
message["role"] = "user"
elif item["role"] == "BOT":
message["role"] = "assistant"
elif item["role"] == "SYSTEM":
message["role"] = "system"
else:
message["role"] = item["role"]
messages.append(message)
request = {
"model": self._model,
"messages": messages,
"max_seq_len": max(
max_out_len if max_out_len else 4096,
self.max_seq_len if self.max_seq_len else 4096,
),
}
request.update(self.generation_kwargs)
try:
retry_num = 0
while retry_num < self.retry:
response = self._infer_result(request, sess)
if response.status_code == 200:
break # success
elif response.status_code == 426:
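                    # Per the commit message, HTTP 426 is how the service
                    # signals server-side flow control; count it and retry.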
retry_num += 1 # retry
else:
raise ValueError(f"Status code = {response.status_code}")
else:
raise ValueError(
                    f"Exceeded the maximum number of retries. Last status code = {response.status_code}"
)
except Exception as e:
self.logger.error(
                f"Failed to infer; request={request}; model_name={self.path}; error={e}; stack: {traceback.format_exc()}"
)
raise e
return response
@retry(stop_max_attempt_number=3, wait_fixed=16000) # ms
def _infer_result(self, request, sess):
response = sess.request(
"POST",
self._url,
json=request,
headers=self._headers,
timeout=500,
)
return response
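For reference, a minimal sketch of driving the wrapper directly, outside the OpenCompass runner; the model name, endpoint, and prompt below are illustrative and simply mirror the configs above:

from opencompass.models import BailingAPI

model = BailingAPI(
    path="Bailing-Lite-0830",  # model name, as in the configs above
    token="",  # empty -> falls back to the BAILING_API_KEY environment variable
    url="https://bailingchat.alipay.com/chat/completions",
    generation_kwargs={"temperature": 0.4, "top_p": 1.0},
)

# generate() fans the prompts out over the pooled sessions and returns
# one string per input ("" when a request fails or returns no content).
outputs = model.generate(["Briefly explain what C-Eval measures."], max_out_len=512)
print(outputs[0])

In normal use the model dicts above are consumed by the runner rather than instantiated by hand.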

View File

@@ -23,6 +23,7 @@ python-Levenshtein
rank_bm25==0.2.2
rapidfuzz
requests>=2.31.0
retrying
rich
rouge
-e git+https://github.com/Isaac-JL-Chen/rouge_chinese.git@master#egg=rouge_chinese