[Feature] Update CoreBench 2.0 (#1566)

* [Feature] 1. Update CoreBench Base
             2. Fix lint issue in BailingAPI

* Update

* Update
Songyang Zhang 2024-09-26 18:44:00 +08:00 committed by GitHub
parent 3f833186dc
commit a7bacfdf7e
13 changed files with 379 additions and 114 deletions

View File

@@ -15,9 +15,9 @@ datasets = [
 models = [
     dict(
-        path="Bailing-Lite-0830",
-        token="xxxxxx",  # set your key here or in environment variable BAILING_API_KEY
-        url="https://bailingchat.alipay.com/chat/completions",
+        path='Bailing-Lite-0830',
+        token='xxxxxx',  # set your key here or in environment variable BAILING_API_KEY
+        url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         generation_kwargs={},
         query_per_second=1,
@@ -35,4 +35,4 @@ infer = dict(
     ),
 )
-work_dir = "outputs/api_bailing/"
+work_dir = 'outputs/api_bailing/'

View File

@@ -0,0 +1,188 @@
import os.path as osp

from mmengine.config import read_base
from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask
#######################################################################
# PART 0 Essential Configs #
#######################################################################
with read_base():
    # Datasets Part
    ## Core Set
    # ## Examination
    from opencompass.configs.datasets.mmlu.mmlu_ppl_ac766d import mmlu_datasets
    from opencompass.configs.datasets.mmlu_pro.mmlu_pro_few_shot_gen_bfaf90 import \
        mmlu_pro_datasets
    from opencompass.configs.datasets.cmmlu.cmmlu_ppl_041cbf import \
        cmmlu_datasets
    # ## Reasoning
    from opencompass.configs.datasets.bbh.bbh_gen_98fba6 import bbh_datasets
    from opencompass.configs.datasets.hellaswag.hellaswag_10shot_ppl_59c85e import hellaswag_datasets
    from opencompass.configs.datasets.drop.drop_gen_a2697c import drop_datasets
    # ## Math
    from opencompass.configs.datasets.math.math_4shot_base_gen_43d5b6 import math_datasets
    from opencompass.configs.datasets.gsm8k.gsm8k_gen_17d0dc import gsm8k_datasets
    from opencompass.configs.datasets.MathBench.mathbench_2024_few_shot_mixed_4a3fd4 import \
        mathbench_datasets
    # ## Scientific
    from opencompass.configs.datasets.gpqa.gpqa_few_shot_ppl_2c9cd6 import \
        gpqa_datasets
    # ## Coding
    from opencompass.configs.datasets.humaneval.deprecated_humaneval_gen_d2537e import humaneval_datasets
    from opencompass.configs.datasets.mbpp.sanitized_mbpp_gen_742f0c import sanitized_mbpp_datasets
    # TODO: Add LiveCodeBench
    # ## Instruction Following
    # from opencompass.configs.datasets.IFEval.IFEval_gen_3321a3 import ifeval_datasets

    # Summarizer
    from opencompass.configs.summarizers.groups.mmlu import mmlu_summary_groups
    from opencompass.configs.summarizers.groups.mmlu_pro import mmlu_pro_summary_groups
    from opencompass.configs.summarizers.groups.cmmlu import cmmlu_summary_groups
    from opencompass.configs.summarizers.groups.bbh import bbh_summary_groups
    from opencompass.configs.summarizers.groups.mathbench_v1_2024 import \
        mathbench_2024_summary_groups

    # Model List
    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_1_5b import models as lmdeploy_qwen2_5_1_5b_model
    # from opencompass.configs.models.qwen.lmdeploy_qwen2_1_5b_instruct import models as lmdeploy_qwen2_1_5b_instruct_model
    # from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_7b_chat import models as hf_internlm2_5_7b_chat_model
    # from opencompass.configs.models.openbmb.hf_minicpm_2b_sft_bf16 import models as hf_minicpm_2b_sft_bf16_model
    # from opencompass.configs.models.yi.hf_yi_1_5_6b_chat import models as hf_yi_1_5_6b_chat_model
    # from opencompass.configs.models.gemma.hf_gemma_2b_it import models as hf_gemma_2b_it_model
    # from opencompass.configs.models.yi.hf_yi_1_5_34b_chat import models as hf_yi_1_5_34b_chat_model
#######################################################################
# PART 1 Datasets List #
#######################################################################
# datasets list for evaluation
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
#######################################################################
# PART 2 Dataset Summarizer #
#######################################################################
# with read_base():
core_summary_groups = [
    {
        'name': 'core_average',
        'subsets': [
            ['mmlu', 'accuracy'],
            ['mmlu_pro', 'accuracy'],
            ['cmmlu', 'accuracy'],
            ['bbh', 'naive_average'],
            ['hellaswag', 'accuracy'],
            ['drop', 'accuracy'],
            ['math', 'accuracy'],
            ['gsm8k', 'accuracy'],
            ['GPQA_diamond', 'accuracy'],
            ['openai_humaneval', 'humaneval_pass@1'],
            ['IFEval', 'Prompt-level-strict-accuracy'],
            ['sanitized_mbpp', 'score'],
            ['mathbench-t (average)', 'naive_average'],
        ],
    },
]
summarizer = dict(
    dataset_abbrs=[
        ['mmlu', 'accuracy'],
        ['mmlu_pro', 'accuracy'],
        ['cmmlu', 'accuracy'],
        ['bbh', 'naive_average'],
        ['hellaswag', 'accuracy'],
        ['drop', 'accuracy'],
        ['math', 'accuracy'],
        ['gsm8k', 'accuracy'],
        ['GPQA_diamond', 'accuracy'],
        ['openai_humaneval', 'humaneval_pass@1'],
        ['IFEval', 'Prompt-level-strict-accuracy'],
        ['sanitized_mbpp', 'score'],
        'mathbench-a (average)',
        'mathbench-t (average)',
        '',
        ['mmlu', 'accuracy'],
        ['mmlu-stem', 'accuracy'],
        ['mmlu-social-science', 'accuracy'],
        ['mmlu-humanities', 'accuracy'],
        ['mmlu-other', 'accuracy'],
        '',
        ['mmlu_pro', 'accuracy'],
        ['mmlu_pro_math', 'accuracy'],
        ['mmlu_pro_physics', 'accuracy'],
        ['mmlu_pro_chemistry', 'accuracy'],
        ['mmlu_pro_law', 'accuracy'],
        ['mmlu_pro_engineering', 'accuracy'],
        ['mmlu_pro_other', 'accuracy'],
        ['mmlu_pro_economics', 'accuracy'],
        ['mmlu_pro_health', 'accuracy'],
        ['mmlu_pro_psychology', 'accuracy'],
        ['mmlu_pro_business', 'accuracy'],
        ['mmlu_pro_biology', 'accuracy'],
        ['mmlu_pro_philosophy', 'accuracy'],
        ['mmlu_pro_computer_science', 'accuracy'],
        ['mmlu_pro_history', 'accuracy'],
        '',
        ['cmmlu', 'accuracy'],
        ['cmmlu-stem', 'accuracy'],
        ['cmmlu-social-science', 'accuracy'],
        ['cmmlu-humanities', 'accuracy'],
        ['cmmlu-other', 'accuracy'],
        ['cmmlu-china-specific', 'accuracy'],
    ],
    summary_groups=sum(
        [v for k, v in locals().items() if k.endswith('_summary_groups')], []),
)
#######################################################################
# PART 3 Models List #
#######################################################################
models = sum([v for k, v in locals().items() if k.endswith('_model')], [])
#######################################################################
# PART 4 Inference/Evaluation Configuration #
#######################################################################
# Local Runner
infer = dict(
    partitioner=dict(
        type=NumWorkerPartitioner,
        num_worker=8
    ),
    runner=dict(
        type=LocalRunner,
        max_num_workers=16,
        retry=0,  # Modify if needed
        task=dict(type=OpenICLInferTask)
    ),
)
# eval with local runner
eval = dict(
    partitioner=dict(type=NaivePartitioner, n=10),
    runner=dict(
        type=LocalRunner,
        max_num_workers=16,
        task=dict(type=OpenICLEvalTask)),
)
#######################################################################
# PART 5 Utils Configuration #
#######################################################################
base_exp_dir = 'outputs/corebench_2409_objective/'
work_dir = osp.join(base_exp_dir, 'chat_objective')
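
The two sum(...) collectors above are the glue of this config: any variable brought in under with read_base() whose name ends in _datasets, _summary_groups, or _model is swept up automatically, so commenting an import in or out is all it takes to change a run. A minimal, self-contained sketch of that pattern (the demo_* names are illustrative, not part of the config):

# Illustrative sketch of the locals()-scan used by the config above.
# Variables are collected purely by naming convention, so importing a
# dataset config inside `with read_base():` is enough to register it.
demo_alpha_datasets = [dict(abbr='demo-alpha')]
demo_beta_datasets = [dict(abbr='demo-beta')]

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
assert [d['abbr'] for d in datasets] == ['demo-alpha', 'demo-beta']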

View File

@@ -18,20 +18,22 @@ with read_base():
     # ## Reasoning
     from opencompass.configs.datasets.bbh.bbh_gen_4a31fa import bbh_datasets
-    # TODO: Add HellaSwag
-    # TODO: Add DROP
+    from opencompass.configs.datasets.hellaswag.hellaswag_10shot_gen_e42710 import \
+        hellaswag_datasets
+    from opencompass.configs.datasets.drop.drop_openai_simple_evals_gen_3857b0 import drop_datasets
     # ## Math
     from opencompass.configs.datasets.math.math_0shot_gen_393424 import math_datasets
-    # TODO: Add GSM8K
-    # TODO: Add MathBench
+    from opencompass.configs.datasets.gsm8k.gsm8k_0shot_v2_gen_a58960 import \
+        gsm8k_datasets
+    from opencompass.configs.datasets.MathBench.mathbench_2024_gen_50a320 import mathbench_datasets
     # ## Scientific
     from opencompass.configs.datasets.gpqa.gpqa_openai_simple_evals_gen_5aeece import gpqa_datasets
     # ## Coding
     from opencompass.configs.datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
-    # TODO: Add MBPP
+    from opencompass.configs.datasets.mbpp.sanitized_mbpp_mdblock_gen_a447ff import sanitized_mbpp_datasets
     # TODO: Add LiveCodeBench
     # ## Instruction Following
@@ -70,13 +72,17 @@ core_summary_groups = [
         'subsets': [
             ['mmlu', 'accuracy'],
             ['mmlu_pro', 'accuracy'],
-            # ['cmmlu', 'naive_average'],
             ['cmmlu', 'accuracy'],
             ['bbh', 'score'],
             ['math', 'accuracy'],
             ['openai_humaneval', 'humaneval_pass@1'],
             ['GPQA_diamond', 'accuracy'],
             ['IFEval', 'Prompt-level-strict-accuracy'],
+            ['drop', 'accuracy'],
+            ['sanitized_mbpp', 'score'],
+            ['gsm8k', 'accuracy'],
+            ['hellaswag', 'accuracy'],
+            ['mathbench-t (average)', 'naive_average']
         ],
     },
 ]
@@ -92,6 +98,12 @@ summarizer = dict(
         ['openai_humaneval', 'humaneval_pass@1'],
         ['GPQA_diamond', 'accuracy'],
         ['IFEval', 'Prompt-level-strict-accuracy'],
+        ['drop', 'accuracy'],
+        ['sanitized_mbpp', 'score'],
+        ['gsm8k', 'accuracy'],
+        ['hellaswag', 'accuracy'],
+        'mathbench-a (average)',
+        'mathbench-t (average)',
         '',
         ['mmlu', 'accuracy'],
@@ -204,5 +216,5 @@ eval = dict(
 #######################################################################
 # PART 5 Utils Configuration #
 #######################################################################
-base_exp_dir = 'outputs/corebench/'
+base_exp_dir = 'outputs/corebench_2409_objective/'
 work_dir = osp.join(base_exp_dir, 'chat_objective')

View File

@@ -2,30 +2,29 @@ from opencompass.models import BailingAPI

 api_meta_template = dict(
     round=[
-        dict(role="HUMAN", api_role="HUMAN"),
-        dict(role="BOT", api_role="BOT", generate=False),
+        dict(role='HUMAN', api_role='HUMAN'),
+        dict(role='BOT', api_role='BOT', generate=False),
     ],
-    reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
+    reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],
 )

 models = [
     dict(
-        path="Bailing-Lite-0830",
-        token="",  # set your key here or in environment variable BAILING_API_KEY
-        url="https://bailingchat.alipay.com/chat/completions",
+        path='Bailing-Lite-0830',
+        token='',  # set your key here or in environment variable BAILING_API_KEY
+        url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
         query_per_second=1,
         max_seq_len=4096,
         batch_size=1,
         generation_kwargs={
-            "temperature": 0.4,
-            "top_p": 1.0,
-            "top_k": -1,
-            "n": 1,
-            "logprobs": 1,
-            "use_beam_search": False,
+            'temperature': 0.4,
+            'top_p': 1.0,
+            'top_k': -1,
+            'n': 1,
+            'logprobs': 1,
+            'use_beam_search': False,
         },
     ),
 ]
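
Since token='' is committed, the key is expected to come from the environment: BailingAPI falls back to BAILING_API_KEY when the config value is empty (see the bailing_api.py diff below). A hedged sketch, with a placeholder value, of supplying it from Python before a run:

# Sketch: export the API key instead of hard-coding it in the config.
# '<your-bailing-key>' is a placeholder; BailingAPI reads BAILING_API_KEY
# only when the config's `token` field is empty.
import os

os.environ.setdefault('BAILING_API_KEY', '<your-bailing-key>')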

View File

@@ -2,30 +2,29 @@ from opencompass.models import BailingAPI

 api_meta_template = dict(
     round=[
-        dict(role="HUMAN", api_role="HUMAN"),
-        dict(role="BOT", api_role="BOT", generate=False),
+        dict(role='HUMAN', api_role='HUMAN'),
+        dict(role='BOT', api_role='BOT', generate=False),
     ],
-    reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
+    reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],
 )

 models = [
     dict(
-        path="Bailing-Pro-0920",
-        token="",  # set your key here or in environment variable BAILING_API_KEY
-        url="https://bailingchat.alipay.com/chat/completions",
+        path='Bailing-Pro-0920',
+        token='',  # set your key here or in environment variable BAILING_API_KEY
+        url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
         query_per_second=1,
         max_seq_len=4096,
         batch_size=1,
         generation_kwargs={
-            "temperature": 0.4,
-            "top_p": 1.0,
-            "top_k": -1,
-            "n": 1,
-            "logprobs": 1,
-            "use_beam_search": False,
+            'temperature': 0.4,
+            'top_p': 1.0,
+            'top_k': -1,
+            'n': 1,
+            'logprobs': 1,
+            'use_beam_search': False,
         },
     ),
 ]

View File

@@ -0,0 +1,15 @@
from opencompass.models import TurboMindModel

models = [
    dict(
        type=TurboMindModel,
        abbr='qwen2.5-1.5b-turbomind',
        path='Qwen/Qwen2.5-1.5B',
        engine_config=dict(session_len=7168, max_batch_size=16, tp=1),
        gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=1024),
        max_seq_len=7168,
        max_out_len=1024,
        batch_size=16,
        run_cfg=dict(num_gpus=1),
    )
]

View File

@@ -0,0 +1,15 @@
from opencompass.models import TurboMindModel

models = [
    dict(
        type=TurboMindModel,
        abbr='qwen2.5-7b-turbomind',
        path='Qwen/Qwen2.5-7B',
        engine_config=dict(session_len=7168, max_batch_size=16, tp=1),
        gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=1024),
        max_seq_len=7168,
        max_out_len=1024,
        batch_size=16,
        run_cfg=dict(num_gpus=1),
    )
]

View File

@@ -2,30 +2,29 @@ from opencompass.models import BailingAPI

 api_meta_template = dict(
     round=[
-        dict(role="HUMAN", api_role="HUMAN"),
-        dict(role="BOT", api_role="BOT", generate=False),
+        dict(role='HUMAN', api_role='HUMAN'),
+        dict(role='BOT', api_role='BOT', generate=False),
     ],
-    reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
+    reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],
 )

 models = [
     dict(
-        path="Bailing-Lite-0830",
-        token="",  # set your key here or in environment variable BAILING_API_KEY
-        url="https://bailingchat.alipay.com/chat/completions",
+        path='Bailing-Lite-0830',
+        token='',  # set your key here or in environment variable BAILING_API_KEY
+        url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
         query_per_second=1,
         max_seq_len=4096,
         batch_size=1,
         generation_kwargs={
-            "temperature": 0.4,
-            "top_p": 1.0,
-            "top_k": -1,
-            "n": 1,
-            "logprobs": 1,
-            "use_beam_search": False,
+            'temperature': 0.4,
+            'top_p': 1.0,
+            'top_k': -1,
+            'n': 1,
+            'logprobs': 1,
+            'use_beam_search': False,
         },
     ),
 ]

View File

@@ -2,30 +2,29 @@ from opencompass.models import BailingAPI

 api_meta_template = dict(
     round=[
-        dict(role="HUMAN", api_role="HUMAN"),
-        dict(role="BOT", api_role="BOT", generate=False),
+        dict(role='HUMAN', api_role='HUMAN'),
+        dict(role='BOT', api_role='BOT', generate=False),
     ],
-    reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
+    reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],
 )

 models = [
     dict(
-        path="Bailing-Pro-0920",
-        token="",  # set your key here or in environment variable BAILING_API_KEY
-        url="https://bailingchat.alipay.com/chat/completions",
+        path='Bailing-Pro-0920',
+        token='',  # set your key here or in environment variable BAILING_API_KEY
+        url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
         query_per_second=1,
         max_seq_len=4096,
         batch_size=1,
         generation_kwargs={
-            "temperature": 0.4,
-            "top_p": 1.0,
-            "top_k": -1,
-            "n": 1,
-            "logprobs": 1,
-            "use_beam_search": False,
+            'temperature': 0.4,
+            'top_p': 1.0,
+            'top_k': -1,
+            'n': 1,
+            'logprobs': 1,
+            'use_beam_search': False,
         },
     ),
 ]

View File

@@ -0,0 +1,15 @@
from opencompass.models import TurboMindModel

models = [
    dict(
        type=TurboMindModel,
        abbr='qwen2.5-1.5b-turbomind',
        path='Qwen/Qwen2.5-1.5B',
        engine_config=dict(session_len=7168, max_batch_size=16, tp=1),
        gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=1024),
        max_seq_len=7168,
        max_out_len=1024,
        batch_size=16,
        run_cfg=dict(num_gpus=1),
    )
]

View File

@@ -0,0 +1,15 @@
from opencompass.models import TurboMindModel

models = [
    dict(
        type=TurboMindModel,
        abbr='qwen2.5-7b-turbomind',
        path='Qwen/Qwen2.5-7B',
        engine_config=dict(session_len=7168, max_batch_size=16, tp=1),
        gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=1024),
        max_seq_len=7168,
        max_out_len=1024,
        batch_size=16,
        run_cfg=dict(num_gpus=1),
    )
]

View File

@@ -42,7 +42,8 @@ from .sensetime_api import SenseTime  # noqa: F401
 from .stepfun_api import StepFun  # noqa: F401
 from .turbomind import TurboMindModel  # noqa: F401
 from .turbomind_tis import TurboMindTisModel  # noqa: F401
-from .turbomind_with_tf_above_v4_33 import TurboMindModelwithChatTemplate  # noqa: F401
+from .turbomind_with_tf_above_v4_33 import \
+    TurboMindModelwithChatTemplate  # noqa: F401
 from .unigpt_api import UniGPT  # noqa: F401
 from .vllm import VLLM  # noqa: F401
 from .vllm_with_tf_above_v4_33 import VLLMwithChatTemplate  # noqa: F401

View File

@@ -7,9 +7,14 @@ from typing import Dict, List, Optional, Union

 import requests
 from requests.adapters import HTTPAdapter
-from retrying import retry
 from urllib3.connection import HTTPConnection
+
+try:
+    from retrying import retry
+except ImportError:
+    retry = None
+    print('please install retrying by `pip install retrying`')

 from opencompass.utils.prompt import PromptList

 from .base_api import BaseAPIModel

@@ -18,6 +23,7 @@ PromptType = Union[PromptList, str]

 class HTTPAdapterWithSocketOptions(HTTPAdapter):
+
     def __init__(self, *args, **kwargs):
         self._socket_options = HTTPConnection.default_socket_options + [
             (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),

@@ -29,8 +35,9 @@ class HTTPAdapterWithSocketOptions(HTTPAdapter):

     def init_poolmanager(self, *args, **kwargs):
         if self._socket_options is not None:
-            kwargs["socket_options"] = self._socket_options
-        super(HTTPAdapterWithSocketOptions, self).init_poolmanager(*args, **kwargs)
+            kwargs['socket_options'] = self._socket_options
+        super(HTTPAdapterWithSocketOptions,
+              self).init_poolmanager(*args, **kwargs)


 class BailingAPI(BaseAPIModel):

@@ -64,31 +71,29 @@ class BailingAPI(BaseAPIModel):
             generation_kwargs=generation_kwargs,
         )

-        self.logger.info(f"Bailing API Model Init path: {path} url={url}")
+        self.logger.info(f'Bailing API Model Init path: {path} url={url}')
         if not token:
-            token = os.environ.get("BAILING_API_KEY")
+            token = os.environ.get('BAILING_API_KEY')
         if token:
-            self._headers = {"Authorization": f"Bearer {token}"}
+            self._headers = {'Authorization': f'Bearer {token}'}
         else:
-            raise RuntimeError(f"There is not valid token.")
+            raise RuntimeError('There is not valid token.')

-        self._headers["Content-Type"] = "application/json"
-        self._url = url if url else "https://bailingchat.alipay.com/chat/completions"
+        self._headers['Content-Type'] = 'application/json'
+        self._url = url if url else \
+            'https://bailingchat.alipay.com/chat/completions'
         self._model = path
         self._sessions = []
-        self._num = (
-            int(os.environ.get("BAILING_API_PARALLEL_NUM"))
-            if os.environ.get("BAILING_API_PARALLEL_NUM")
-            else 1
-        )
+        self._num = (int(os.environ.get('BAILING_API_PARALLEL_NUM'))
+                     if os.environ.get('BAILING_API_PARALLEL_NUM') else 1)
         try:
             for _ in range(self._num):
                 adapter = HTTPAdapterWithSocketOptions()
                 sess = requests.Session()
-                sess.mount("http://", adapter)
-                sess.mount("https://", adapter)
+                sess.mount('http://', adapter)
+                sess.mount('https://', adapter)
                 self._sessions.append(sess)
         except Exception as e:
-            self.logger.error(f"Fail to setup the session. {e}")
+            self.logger.error(f'Fail to setup the session. {e}')
             raise e

     def generate(
@@ -99,7 +104,8 @@ class BailingAPI(BaseAPIModel):
         """Generate results given a list of inputs.

         Args:
-            inputs (Union[List[str], PromptList]): A list of strings or PromptDicts.
+            inputs (Union[List[str], PromptList]):
+                A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass' API format.
             max_out_len (int): The maximum length of the output.

@@ -107,8 +113,7 @@ class BailingAPI(BaseAPIModel):
             List[str]: A list of generated strings.
         """
         with concurrent.futures.ThreadPoolExecutor(
-            max_workers=self._num,
-        ) as executor:
+                max_workers=self._num, ) as executor:
             future_to_m = {
                 executor.submit(
                     self._generate,
@@ -120,22 +125,22 @@ class BailingAPI(BaseAPIModel):
             }
             results = []
             for future in concurrent.futures.as_completed(future_to_m):
-                m = future_to_m[future]
+                m = future_to_m[future]  # noqa F841
                 resp = future.result()
                 if resp and resp.status_code == 200:
                     try:
                         result = resp.json()
-                    except:
-                        results.append("")
+                    except Exception as e:  # noqa F841
+                        results.append('')
                     else:
-                        if (
-                            result.get("choices")
-                            and result["choices"][0].get("message")
-                            and result["choices"][0]["message"].get("content")
-                        ):
-                            results.append(result["choices"][0]["message"]["content"])
+                        if (result.get('choices')
+                                and result['choices'][0].get('message')
+                                and result['choices'][0]['message'].get(
+                                    'content')):
+                            results.append(
+                                result['choices'][0]['message']['content'])
                         else:
-                            results.append("")
+                            results.append('')
         self.flush()
         return results

@@ -156,27 +161,30 @@ class BailingAPI(BaseAPIModel):
             str: The generated string.
         """
         if isinstance(input, str):
-            messages = [{"role": "user", "content": input}]
+            messages = [{'role': 'user', 'content': input}]
         else:
             messages = []
             for item in input:
-                content = item["prompt"]
+                content = item['prompt']
                 if not content:
                     continue
-                message = {"content": content}
-                if item["role"] == "HUMAN":
-                    message["role"] = "user"
-                elif item["role"] == "BOT":
-                    message["role"] = "assistant"
-                elif item["role"] == "SYSTEM":
-                    message["role"] = "system"
+                message = {'content': content}
+                if item['role'] == 'HUMAN':
+                    message['role'] = 'user'
+                elif item['role'] == 'BOT':
+                    message['role'] = 'assistant'
+                elif item['role'] == 'SYSTEM':
+                    message['role'] = 'system'
                 else:
-                    message["role"] = item["role"]
+                    message['role'] = item['role']
                 messages.append(message)
         request = {
-            "model": self._model,
-            "messages": messages,
-            "max_seq_len": max(
+            'model':
+            self._model,
+            'messages':
+            messages,
+            'max_seq_len':
+            max(
                 max_out_len if max_out_len else 4096,
                 self.max_seq_len if self.max_seq_len else 4096,
             ),
@@ -191,22 +199,22 @@ class BailingAPI(BaseAPIModel):
                 elif response.status_code == 426:
                     retry_num += 1  # retry
                 else:
-                    raise ValueError(f"Status code = {response.status_code}")
+                    raise ValueError(f'Status code = {response.status_code}')
             else:
                 raise ValueError(
-                    f"Exceed the maximal retry times. Last status code = {response.status_code}"
-                )
+                    f'Exceed the maximal retry times. Last status code '
+                    f'= {response.status_code}')
         except Exception as e:
-            self.logger.error(
-                f"Fail to inference request={request}; model_name={self.path}; error={e}, stack:{traceback.format_exc()}"
-            )
+            self.logger.error(f'Fail to inference request={request}; '
+                              f'model_name={self.path}; error={e}, '
+                              f'stack:{traceback.format_exc()}')
            raise e
         return response

     @retry(stop_max_attempt_number=3, wait_fixed=16000)  # ms
     def _infer_result(self, request, sess):
         response = sess.request(
-            "POST",
+            'POST',
             self._url,
             json=request,
             headers=self._headers,
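
One caveat with the guarded import above: on ImportError the module binds retry = None, yet @retry(...) is still evaluated at class-definition time, so the module would fail to import without the package. A no-op fallback, sketched here as an assumption rather than the committed code, would keep the module importable:

# Assumed alternative, not the committed code: a decorator factory that
# degrades to a pass-through when `retrying` is unavailable.
try:
    from retrying import retry
except ImportError:

    def retry(*dargs, **dkwargs):
        """No-op stand-in so `@retry(...)` still works without retrying."""

        def wrapper(func):
            return func

        return wrapper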