mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Support 360API and FixKRetriever for CSQA dataset (#601)
* [Feature] Support 360API and FixKRetriever for CSQA dataset
* Update API
* Update API
* [Feature] Support 360API and FixKRetriever for CSQA dataset
* Update API
* Update API
* rm mathbench
* fix_lint
* Update opencompass/models/bytedance_api.py

  Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com>

* update
* update
* update

---------

Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com>
This commit is contained in:
parent d3b0d5c4ce
commit d925748266
55 configs/datasets/commonsenseqa/commonsenseqa_gen_1da2d0.py Normal file
@@ -0,0 +1,55 @@
# Use FixKRetriever to avoid hang caused by the Huggingface
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import commonsenseqaDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess

commonsenseqa_reader_cfg = dict(
    input_columns=["question", "A", "B", "C", "D", "E"],
    output_column="answerKey",
    test_split="validation")

_ice_template = dict(
    type=PromptTemplate,
    template=dict(
        begin="</E>",
        round=[
            dict(
                role="HUMAN",
                prompt=
                "{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer:",
            ),
            dict(
                role="BOT",
                prompt="{answerKey}",
            ),
        ],
    ),
    ice_token="</E>",
)

commonsenseqa_infer_cfg = dict(
    ice_template=_ice_template,
    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4, 5, 6, 7]),
    inferencer=dict(type=GenInferencer),
)

commonsenseqa_eval_cfg = dict(
    evaluator=dict(type=AccEvaluator),
    pred_postprocessor=dict(type=first_capital_postprocess),
)

commonsenseqa_datasets = [
    dict(
        abbr='commonsense_qa',
        type=commonsenseqaDataset,
        path='./data/commonsenseqa',
        reader_cfg=commonsenseqa_reader_cfg,
        infer_cfg=commonsenseqa_infer_cfg,
        eval_cfg=commonsenseqa_eval_cfg,
    )
]

del _ice_template
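For orientation, a minimal sketch (not part of this commit) of how the new dataset config can be pulled into a top-level eval config via read_base(), mirroring the pattern used by configs/eval_api_360.py further down; the fragment below is illustrative only.

# Illustrative eval-config fragment: import the new CSQA gen config the same
# way eval_api_360.py imports ceval_gen.
from mmengine.config import read_base

with read_base():
    from .datasets.commonsenseqa.commonsenseqa_gen_1da2d0 import \
        commonsenseqa_datasets

datasets = [*commonsenseqa_datasets]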
42 configs/datasets/commonsenseqa/commonsenseqa_ppl_e51e32.py Normal file

@@ -0,0 +1,42 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import FixKRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import commonsenseqaDataset

commonsenseqa_reader_cfg = dict(
    input_columns=['question', 'A', 'B', 'C', 'D', 'E'],
    output_column='answerKey',
    test_split='validation')

_ice_template = dict(
    type=PromptTemplate,
    template={
        ans: dict(
            begin='</E>',
            round=[
                dict(role="HUMAN", prompt="Question: {question}\nAnswer: "),
                dict(role="BOT", prompt=ans_token),
            ])
        for ans, ans_token in [["A", "{A}"], ["B", "{B}"],
                               ["C", "{C}"], ["D", "{D}"],
                               ["E", "{E}"]]
    },
    ice_token='</E>')

commonsenseqa_infer_cfg = dict(
    ice_template=_ice_template,
    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4, 5, 6, 7]),
    inferencer=dict(type=PPLInferencer))

commonsenseqa_eval_cfg = dict(evaluator=dict(type=AccEvaluator))

commonsenseqa_datasets = [
    dict(
        abbr='commonsense_qa',
        type=commonsenseqaDataset,
        path='./data/commonsenseqa',
        reader_cfg=commonsenseqa_reader_cfg,
        infer_cfg=commonsenseqa_infer_cfg,
        eval_cfg=commonsenseqa_eval_cfg)
]
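For readability, the dict comprehension in _ice_template above builds one prompt template per answer option, keyed by its letter; the PPL inferencer then scores each candidate completion and takes the lowest-perplexity option as the prediction. A sketch of the expanded entry for option "A" (the variable name is illustrative):

# What the comprehension expands to for option "A"; analogous entries exist
# for B through E.
expanded_entry_A = {
    "A": dict(
        begin='</E>',
        round=[
            dict(role="HUMAN", prompt="Question: {question}\nAnswer: "),
            dict(role="BOT", prompt="{A}"),
        ])
}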
36 configs/eval_api_360.py Normal file

@@ -0,0 +1,36 @@
from mmengine.config import read_base
from opencompass.models import AI360GPT
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .summarizers.medium import summarizer
    from .datasets.ceval.ceval_gen import ceval_datasets

datasets = [
    *ceval_datasets,
]

models = [
    dict(
        abbr='360GPT_S2_V9',
        type=AI360GPT,
        path='360GPT_S2_V9',
        key="xxxxxxxxxxxx",
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
        batch_size=8),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalAPIRunner,
        max_num_workers=2,
        concurrent_users=2,
        task=dict(type=OpenICLInferTask)),
)

work_dir = "./output/360GPT_S2_V9"
39 configs/eval_api_baichuan.py Normal file

@@ -0,0 +1,39 @@
from mmengine.config import read_base
from opencompass.models import BaiChuan

from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .summarizers.medium import summarizer
    from .datasets.ceval.ceval_gen import ceval_datasets

datasets = [
    *ceval_datasets,
]

models = [
    dict(
        abbr='Baichuan2-53B',
        type=BaiChuan,
        path='Baichuan2-53B',
        api_key='xxxxxx',
        secret_key="xxxxx",
        url="xxxxx",
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
        batch_size=8),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalAPIRunner,
        max_num_workers=2,
        concurrent_users=2,
        task=dict(type=OpenICLInferTask)),
)

work_dir = "outputs/api_baichuan53b/"
43 configs/eval_api_pangu.py Normal file

@@ -0,0 +1,43 @@
from mmengine.config import read_base
from opencompass.models import PanGu

from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .summarizers.medium import summarizer
    from .datasets.ceval.ceval_gen import ceval_datasets

datasets = [
    *ceval_datasets,
]

models = [
    dict(
        abbr='pangu',
        type=PanGu,
        path='pangu',
        access_key="xxxxxx",
        secret_key="xxxxxx",
        url="xxxxxx",
        # url of the token server, used to generate the token, like
        # "https://xxxxxx.myhuaweicloud.com/v3/auth/tokens",
        token_url="xxxxxx",
        # scope-project-name, used to generate the token
        project_name="xxxxxx",
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
        batch_size=8),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalAPIRunner,
        max_num_workers=2,
        concurrent_users=2,
        task=dict(type=OpenICLInferTask)),
)

work_dir = "outputs/api_pangu/"
35 configs/eval_api_sensetime.py Normal file

@@ -0,0 +1,35 @@
from mmengine.config import read_base
from opencompass.models import SenseTime
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .summarizers.medium import summarizer
    from .datasets.ceval.ceval_gen import ceval_datasets

datasets = [
    *ceval_datasets,
]

models = [
    dict(
        abbr='nova-ptc-xl-v1',
        type=SenseTime,
        path='nova-ptc-xl-v1',
        key='xxxxxxxxxxxxxx',
        url='xxxxxxxxxxx',
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
        batch_size=8),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalAPIRunner,
        max_num_workers=2,
        concurrent_users=2,
        task=dict(type=OpenICLInferTask)),
)
opencompass/models/__init__.py

@@ -1,6 +1,10 @@
from .ai360_api import AI360GPT  # noqa: F401
from .alaya import AlayaLM  # noqa: F401
from .baichuan_api import BaiChuan  # noqa: F401
from .baidu_api import ERNIEBot  # noqa: F401
from .base import BaseModel, LMTemplateParser  # noqa
from .base_api import APITemplateParser, BaseAPIModel  # noqa
from .bytedance_api import ByteDance  # noqa: F401
from .claude_api import Claude  # noqa: F401
from .glm import GLM130B  # noqa: F401, F403
from .huggingface import HuggingFace  # noqa: F401, F403
@@ -11,5 +15,7 @@ from .lightllm_api import LightllmAPI  # noqa: F401
from .llama2 import Llama2, Llama2Chat  # noqa: F401, F403
from .minimax_api import MiniMax  # noqa: F401
from .openai_api import OpenAI  # noqa: F401
from .pangu_api import PanGu  # noqa: F401
from .sensetime_api import SenseTime  # noqa: F401
from .xunfei_api import XunFei  # noqa: F401
from .zhipuai_api import ZhiPuAI  # noqa: F401
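With these registrations, the new wrappers are importable from the package root, which is what the eval configs in this commit rely on; for example:

# The eval configs above import the wrappers from the package root:
from opencompass.models import (AI360GPT, BaiChuan, ByteDance, ERNIEBot,
                                PanGu, SenseTime)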
168 opencompass/models/ai360_api.py Normal file

@@ -0,0 +1,168 @@
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

import requests

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


class AI360GPT(BaseAPIModel):
    """Model wrapper around 360 GPT.

    Documentation: https://ai.360.com/platform/docs/overview

    Args:
        path (str): Model name
        key (str): Provided API key
        url (str): Provided URL
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 2.
        max_seq_len (int): Unused here.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        retry (int): Number of retries if the API call fails. Defaults to 2.
    """

    def __init__(
        self,
        path: str,  # model name, e.g.: 360GPT_S2_V9
        key: str,
        url: str = 'https://api.360.cn/v1/chat/completions',
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
        self.headers = {
            'Authorization': f'Bearer {key}',
            'Content-Type': 'application/json',
        }
        self.model = path
        self.url = url

    def generate(
        self,
        inputs: List[str or PromptList],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _generate(
        self,
        input: str or PromptList,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
            inputs (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)

        data = {
            'model': self.model,
            'messages': messages,
            'stream': False,
            'temperature': 0.9,
            'max_tokens': 2048,
            'top_p': 0.5,
            'top_k': 0,
            'repetition_penalty': 1.05,
            # "num_beams": 1,
            # "user": "OpenCompass"
        }

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.acquire()
            # payload = json.dumps(data)
            raw_response = requests.request('POST',
                                            url=self.url,
                                            headers=self.headers,
                                            json=data)
            response = raw_response.json()
            self.release()

            if response is None:
                print('Connection error, reconnect.')
                # if connect error, frequent requests will cause
                # continuous unstable network, therefore wait here
                # to slow down the request
                self.wait()
                continue
            if raw_response.status_code == 200:
                try:
                    msg = response['choices'][0]['message']['content'].strip()
                    return msg

                except KeyError:
                    if 'error' in response:
                        # tpm (tokens per minute) limit
                        if response['error']['code'] == '1005':
                            time.sleep(1)
                            continue

                        self.logger.error('Find error message in response: %s',
                                          str(response['error']))

            # sensitive content, prompt overlength, network error
            # or illegal prompt
            if (raw_response.status_code == 400
                    or raw_response.status_code == 401
                    or raw_response.status_code == 402
                    or raw_response.status_code == 429
                    or raw_response.status_code == 500):
                print(raw_response.text)
                # return ''
                continue
            print(raw_response)
            max_num_retries += 1

        raise RuntimeError(raw_response.text)
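A minimal usage sketch of the wrapper outside a config file; the key is a placeholder and the call issues a real request against the 360 endpoint, so treat this as illustrative only.

# Hedged sketch: arguments mirror the model dict in configs/eval_api_360.py.
from opencompass.models import AI360GPT

model = AI360GPT(path='360GPT_S2_V9', key='xxxxxxxxxxxx')
outputs = model.generate(['Briefly introduce yourself.'], max_out_len=128)
print(outputs[0])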
164 opencompass/models/baichuan_api.py Normal file

@@ -0,0 +1,164 @@
import hashlib
import json
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

import requests

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


class BaiChuan(BaseAPIModel):
    """Model wrapper around Baichuan.

    Documentation: https://platform.baichuan-ai.com/docs/api

    Args:
        path (str): The name of Baichuan model.
            e.g. `Baichuan2-53B`
        api_key (str): Provided api key
        secret_key (str): Provided secret key, used to sign the request body
        url (str): Provided url
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        max_seq_len (int): Unused here.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        retry (int): Number of retries if the API call fails. Defaults to 2.
    """

    def __init__(
        self,
        path: str,
        api_key: str,
        secret_key: str,
        url: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)

        self.api_key = api_key
        self.secret_key = secret_key
        self.url = url
        self.model = path

    def generate(
        self,
        inputs: List[str or PromptList],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _generate(
        self,
        input: str or PromptList,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
            inputs (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """

        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'

                messages.append(msg)

        data = {'model': self.model, 'messages': messages}

        def calculate_md5(input_string):
            md5 = hashlib.md5()
            md5.update(input_string.encode('utf-8'))
            encrypted = md5.hexdigest()
            return encrypted

        json_data = json.dumps(data)
        time_stamp = int(time.time())
        signature = calculate_md5(self.secret_key + json_data +
                                  str(time_stamp))

        headers = {
            'Content-Type': 'application/json',
            'Authorization': 'Bearer ' + self.api_key,
            'X-BC-Request-Id': 'your requestId',
            'X-BC-Timestamp': str(time_stamp),
            'X-BC-Signature': signature,
            'X-BC-Sign-Algo': 'MD5',
        }

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.acquire()
            raw_response = requests.request('POST',
                                            url=self.url,
                                            headers=headers,
                                            json=data)
            response = raw_response.json()
            self.release()

            if response is None:
                print('Connection error, reconnect.')
                # if connect error, frequent requests will cause
                # continuous unstable network, therefore wait here
                # to slow down the request
                self.wait()
                continue
            if raw_response.status_code == 200 and response['code'] == 0:
                # msg = json.load(response.text)
                # response
                msg = response['data']['messages'][0]['content']
                return msg

            if response['code'] != 0:
                print(response)
                return ''
            print(response)
            max_num_retries += 1

        raise RuntimeError(response)
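For reference, the request signature built in _generate above is an MD5 digest over secret_key + JSON body + timestamp; a standalone sketch with placeholder values:

# Standalone illustration of the X-BC-Signature computation used in _generate;
# all values are placeholders.
import hashlib
import json
import time

secret_key = 'xxxxx'
body = json.dumps({'model': 'Baichuan2-53B',
                   'messages': [{'role': 'user', 'content': 'hello'}]})
ts = str(int(time.time()))
signature = hashlib.md5((secret_key + body + ts).encode('utf-8')).hexdigest()
headers = {'X-BC-Timestamp': ts, 'X-BC-Signature': signature,
           'X-BC-Sign-Algo': 'MD5'}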
196 opencompass/models/baidu_api.py Normal file

@@ -0,0 +1,196 @@
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

import requests

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


class ERNIEBot(BaseAPIModel):
    """Model wrapper around ERNIE-Bot.

    Documentation: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11

    Args:
        path (str): The name of ERNIE-Bot model.
            e.g. `erniebot`
        key (str): Authorization key (API key).
        secretkey (str): Secret key used to obtain the access_token.
        url (str): Request url; the access_token is appended to it.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        max_seq_len (int): Unused here.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        retry (int): Number of retries if the API call fails. Defaults to 2.
    """

    def __init__(
        self,
        path: str,
        key: str,
        secretkey: str,
        url: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
        self.headers = {'Content-Type': 'application/json'}
        self.secretkey = secretkey
        self.key = key
        self.url = url
        self.model = path

    def _generate_access_token(self):
        try:
            BAIDU_APIKEY = self.key
            BAIDU_SECRETKEY = self.secretkey
            url = f'https://aip.baidubce.com/oauth/2.0/token?' \
                  f'client_id={BAIDU_APIKEY}&client_secret={BAIDU_SECRETKEY}' \
                  f'&grant_type=client_credentials'
            headers = {
                'Content-Type': 'application/json',
                'Accept': 'application/json'
            }
            response = requests.request('POST', url, headers=headers)
            resp_dict = response.json()
            if response.status_code == 200:
                access_token = resp_dict.get('access_token')
                refresh_token = resp_dict.get('refresh_token')
                if 'error' in resp_dict:
                    raise ValueError(f'Failed to obtain certificate.'
                                     f'{resp_dict.get("error")}')
                else:
                    return access_token, refresh_token
            else:
                error = resp_dict.get('error')
                raise ValueError(
                    f'Failed to obtain certificate: {error}.')
        except Exception as ex:
            raise ex

    def generate(
        self,
        inputs: List[str or PromptList],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _generate(
        self,
        input: str or PromptList,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
            inputs (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))
        """
        Example request body (translated):
        {
            "messages": [
                {"role": "user", "content": "Please introduce yourself."},
                {"role": "assistant",
                 "content": "I am an AI language model developed by Baidu."},
                {"role": "user",
                 "content": "I am in Shanghai; where can I go at the weekend?"},
                {"role": "assistant",
                 "content": "Shanghai is a vibrant city with a rich cultural scene."},
                {"role": "user",
                 "content": "What will the weather be like there at the weekend?"}
            ]
        }
        """

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'

                messages.append(msg)
        data = {'messages': messages}

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.acquire()
            access_token, _ = self._generate_access_token()
            raw_response = requests.request('POST',
                                            url=self.url + access_token,
                                            headers=self.headers,
                                            json=data)
            response = raw_response.json()
            self.release()

            if response is None:
                print('Connection error, reconnect.')
                # if connect error, frequent requests will cause
                # continuous unstable network, therefore wait here
                # to slow down the request
                self.wait()
                continue
            if raw_response.status_code == 200:
                try:
                    msg = response['result']
                    return msg
                except KeyError:
                    print(response)
                    self.logger.error(str(response['error_code']))
                    time.sleep(1)
                    continue

            if (response['error_code'] == 110 or response['error_code'] == 100
                    or response['error_code'] == 111
                    or response['error_code'] == 200
                    or response['error_code'] == 1000
                    or response['error_code'] == 1001
                    or response['error_code'] == 1002
                    or response['error_code'] == 21002
                    or response['error_code'] == 216100
                    or response['error_code'] == 336001
                    or response['error_code'] == 336003
                    or response['error_code'] == 336000):
                print(response['error_msg'])
                return ''
            print(response)
            max_num_retries += 1

        raise RuntimeError(response['error_msg'])
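One non-obvious detail in _generate above: the access token is string-concatenated onto self.url, so the configured url should already end with the ?access_token= query prefix. A sketch with an illustrative endpoint and a placeholder token:

# Illustrative only: how the request URL is assembled in _generate.
base_url = ('https://aip.baidubce.com/rpc/2.0/ai_custom/v1/'
            'wenxinworkshop/chat/completions?access_token=')
access_token = '<token returned by _generate_access_token>'
request_url = base_url + access_token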
opencompass/models/base_api.py

@@ -1,4 +1,5 @@
import re
import sys
import threading
import warnings
from abc import abstractmethod
@@ -64,6 +65,38 @@ class BaseAPIModel(BaseModel):
                                  ' gen-based evaluation yet, try ppl-based '
                                  'instead.')

    def flush(self):
        """Ensure simultaneous emptying of stdout and stderr when concurrent
        resources are available.

        When employing multiprocessing with standard I/O redirected to files,
        it is crucial to clear internal data for examination or prevent log
        loss in case of system failures.
        """
        if hasattr(self, 'tokens'):
            sys.stdout.flush()
            sys.stderr.flush()

    def acquire(self):
        """Acquire concurrent resources if they exist.

        This behavior will fall back to wait with query_per_second if there are
        no concurrent resources.
        """
        if hasattr(self, 'tokens'):
            self.tokens.acquire()
        else:
            self.wait()

    def release(self):
        """Release concurrent resources if acquired.

        This behavior will fall back to do nothing if there are no concurrent
        resources.
        """
        if hasattr(self, 'tokens'):
            self.tokens.release()

    @abstractmethod
    def get_ppl(self,
                inputs: List[PromptType],
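The new acquire()/release() helpers assume that a runner which supports concurrency (e.g. LocalAPIRunner with concurrent_users in the configs above) has attached a semaphore-like object to the model as `tokens`; otherwise they fall back to the per-second throttle. A minimal sketch of that contract, with purely illustrative names:

# Hedged sketch of the acquire()/release() contract; not the actual runner code.
from threading import Semaphore


class FakeAPIModel:

    def wait(self):
        pass  # stand-in for the query_per_second throttle

    def acquire(self):
        if hasattr(self, 'tokens'):
            self.tokens.acquire()  # block until a concurrency slot is free
        else:
            self.wait()  # no semaphore attached: fall back to throttling

    def release(self):
        if hasattr(self, 'tokens'):
            self.tokens.release()


model = FakeAPIModel()
model.tokens = Semaphore(2)  # roughly what concurrent_users=2 amounts to
model.acquire()
model.release()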
172 opencompass/models/bytedance_api.py Normal file

@@ -0,0 +1,172 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

try:
    from volcengine.maas import ChatRole, MaasException, MaasService
except ImportError:
    ChatRole, MaasException, MaasService = None, None, None

PromptType = Union[PromptList, str]


class ByteDance(BaseAPIModel):
    """Model wrapper around ByteDance.

    Args:
        path (str): The name of ByteDance model.
            e.g. `skylark`
        accesskey (str): Provided access key.
        secretkey (str): Provided secret key.
        url (str): Provided url.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        max_seq_len (int): Unused here.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        retry (int): Number of retries if the API call fails. Defaults to 2.
    """

    def __init__(
        self,
        path: str,
        accesskey: str,
        secretkey: str,
        url: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
        if not ChatRole:
            print('Please install related packages via'
                  ' `pip install volcengine`')

        self.accesskey = accesskey
        self.secretkey = secretkey
        self.url = url
        self.model = path

    def generate(
        self,
        inputs: List[str or PromptList],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _generate(
        self,
        input: str or PromptList,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
            inputs (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.

        Example of the `messages` payload (translated):
        [
            {
                "role": ChatRole.USER,
                "content": "Why is the sky so blue?"
            }, {
                "role": ChatRole.ASSISTANT,
                "content": "Because of you"
            }, {
                "role": ChatRole.USER,
                "content": "Why do flowers smell so sweet?"
            },
        ]
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': ChatRole.USER, 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = ChatRole.USER
                elif item['role'] == 'BOT':
                    msg['role'] = ChatRole.ASSISTANT

                messages.append(msg)

        maas = MaasService(self.url, 'cn-beijing')
        maas.set_ak(self.accesskey)
        maas.set_sk(self.secretkey)

        req = {
            'model': {
                'name': 'skylark-pro-public',
            },
            'messages': messages
        }

        def _chat(maas, req):
            try:
                resp = maas.chat(req)
                return resp
            except MaasException as e:
                print(e)
                return e

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.acquire()
            response = _chat(maas, req)

            self.release()

            if response is None:
                print('Connection error, reconnect.')
                # if connect error, frequent requests will cause
                # continuous unstable network, therefore wait here
                # to slow down the request
                self.wait()
                continue
            if not isinstance(response, MaasException):
                # response
                msg = response.choice.message.content
                return msg

            if isinstance(response, MaasException):
                print(response)
                return ''
            print(response)
            max_num_retries += 1

        raise RuntimeError(response)
opencompass/models/minimax_api.py

@@ -1,4 +1,3 @@
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

@@ -81,37 +80,6 @@ class MiniMax(BaseAPIModel):
        self.flush()
        return results

    def flush(self):
        """Flush stdout and stderr when concurrent resources exists.

        When use multiproessing with standard io rediected to files, need to
        flush internal information for examination or log loss when system
        breaks.
        """
        if hasattr(self, 'tokens'):
            sys.stdout.flush()
            sys.stderr.flush()

    def acquire(self):
        """Acquire concurrent resources if exists.

        This behavior will fall back to wait with query_per_second if there are
        no concurrent resources.
        """
        if hasattr(self, 'tokens'):
            self.tokens.acquire()
        else:
            self.wait()

    def release(self):
        """Release concurrent resources if acquired.

        This behavior will fall back to do nothing if there are no concurrent
        resources.
        """
        if hasattr(self, 'tokens'):
            self.tokens.release()

    def _generate(
        self,
        input: str or PromptList,
182 opencompass/models/pangu_api.py Normal file

@@ -0,0 +1,182 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

import requests

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


class PanGu(BaseAPIModel):
    """Model wrapper around PanGu.

    Args:
        path (str): The name of Pangu model.
            e.g. `pangu`
        access_key (str): Provided access key.
        secret_key (str): Secret key used to obtain the auth token.
        url (str): Provided url for requests.
        token_url (str): URL of the token server.
        project_name (str): Project name used to generate the token.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        max_seq_len (int): Unused here.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        retry (int): Number of retries if the API call fails. Defaults to 2.
    """

    def __init__(
        self,
        path: str,
        access_key: str,
        secret_key: str,
        url: str,
        token_url: str,
        project_name: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)

        self.access_key = access_key
        self.secret_key = secret_key
        self.url = url
        self.token_url = token_url
        self.project_name = project_name
        self.model = path

    def generate(
        self,
        inputs: List[str or PromptList],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _get_token(self):
        url = self.token_url
        payload = {
            'auth': {
                'identity': {
                    'methods': ['hw_ak_sk'],
                    'hw_ak_sk': {
                        'access': {
                            'key': self.access_key
                        },
                        'secret': {
                            'key': self.secret_key
                        }
                    }
                },
                'scope': {
                    'project': {
                        'name': self.project_name
                    }
                }
            }
        }
        headers = {'Content-Type': 'application/json'}

        response = requests.request('POST', url, headers=headers, json=payload)
        return response

    def _generate(
        self,
        input: str or PromptList,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
            inputs (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'system'

                messages.append(msg)

        data = {'messages': messages, 'stream': False}

        token_response = self._get_token()
        if token_response.status_code == 201:
            token = token_response.headers['X-Subject-Token']
            print('Token request succeeded.')
        else:
            msg = 'Failed to generate the token.'
            print(msg)
            return ''

        headers = {'Content-Type': 'application/json', 'X-Auth-Token': token}

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.acquire()
            raw_response = requests.request('POST',
                                            url=self.url,
                                            headers=headers,
                                            json=data)
            response = raw_response.json()
            self.release()

            if response is None:
                print('Connection error, reconnect.')
                # if connect error, frequent requests will cause
                # continuous unstable network, therefore wait here
                # to slow down the request
                self.wait()
                continue
            if raw_response.status_code == 200:
                # msg = json.load(response.text)
                # response
                msg = response['choices'][0]['message']['content']
                return msg

            if (raw_response.status_code != 200):
                print(response['error_msg'])
                return ''
            print(response)
            max_num_retries += 1

        raise RuntimeError(response['error_msg'])
136 opencompass/models/sensetime_api.py Normal file

@@ -0,0 +1,136 @@
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

import requests

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


class SenseTime(BaseAPIModel):
    """Model wrapper around SenseTime.

    Args:
        path (str): The name of SenseTime model.
            e.g. `nova-ptc-xl-v1`
        key (str): Authorization key.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        max_seq_len (int): Unused here.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        retry (int): Number of retries if the API call fails. Defaults to 2.
    """

    def __init__(
        self,
        path: str,
        key: str,
        url: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
        self.headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {key}'
        }
        self.url = url
        self.model = path

    def generate(
        self,
        inputs: List[str or PromptList],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _generate(
        self,
        input: str or PromptList,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
            inputs (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'

                messages.append(msg)

        data = {'messages': messages, 'model': self.model}

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.acquire()
            raw_response = requests.request('POST',
                                            url=self.url,
                                            headers=self.headers,
                                            json=data)
            response = raw_response.json()
            self.release()

            if response is None:
                print('Connection error, reconnect.')
                # if connect error, frequent requests will cause
                # continuous unstable network, therefore wait here
                # to slow down the request
                self.wait()
                continue
            if raw_response.status_code == 200:
                msg = response['data']['choices'][0]['message']
                return msg

            if (raw_response.status_code != 200):
                print(raw_response.text)
                time.sleep(1)
                continue
            print(response)
            max_num_retries += 1

        raise RuntimeError(raw_response.text)
opencompass/models/xunfei_api.py

@@ -1,5 +1,4 @@
import json
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

@@ -120,37 +119,6 @@ class XunFei(BaseAPIModel):
        self.flush()
        return results

    def flush(self):
        """Flush stdout and stderr when concurrent resources exists.

        When use multiproessing with standard io rediected to files, need to
        flush internal information for examination or log loss when system
        breaks.
        """
        if hasattr(self, 'tokens'):
            sys.stdout.flush()
            sys.stderr.flush()

    def acquire(self):
        """Acquire concurrent resources if exists.

        This behavior will fall back to wait with query_per_second if there are
        no concurrent resources.
        """
        if hasattr(self, 'tokens'):
            self.tokens.acquire()
        else:
            self.wait()

    def release(self):
        """Release concurrent resources if acquired.

        This behavior will fall back to do nothing if there are no concurrent
        resources.
        """
        if hasattr(self, 'tokens'):
            self.tokens.release()

    def _generate(
        self,
        input: str or PromptList,
opencompass/models/zhipuai_api.py

@@ -1,4 +1,3 @@
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

@@ -66,37 +65,6 @@ class ZhiPuAI(BaseAPIModel):
        self.flush()
        return results

    def flush(self):
        """Flush stdout and stderr when concurrent resources exists.

        When use multiproessing with standard io rediected to files, need to
        flush internal information for examination or log loss when system
        breaks.
        """
        if hasattr(self, 'tokens'):
            sys.stdout.flush()
            sys.stderr.flush()

    def acquire(self):
        """Acquire concurrent resources if exists.

        This behavior will fall back to wait with query_per_second if there are
        no concurrent resources.
        """
        if hasattr(self, 'tokens'):
            self.tokens.acquire()
        else:
            self.wait()

    def release(self):
        """Release concurrent resources if acquired.

        This behavior will fall back to do nothing if there are no concurrent
        resources.
        """
        if hasattr(self, 'tokens'):
            self.tokens.release()

    def _generate(
        self,
        input: str or PromptList,
@@ -1,2 +1,4 @@
sseclient-py==1.7.2
volcengine # bytedance
websocket-client
zhipuai
zhipuai # zhipu