diff --git a/configs/api_examples/eval_api_sensetime.py b/configs/api_examples/eval_api_sensetime.py index ae392378..b2f25bbb 100644 --- a/configs/api_examples/eval_api_sensetime.py +++ b/configs/api_examples/eval_api_sensetime.py @@ -22,7 +22,22 @@ models = [ query_per_second=1, max_out_len=2048, max_seq_len=2048, - batch_size=8), + batch_size=8, + parameters={ + "temperature": 0.8, + "top_p": 0.7, + "max_new_tokens": 1024, + "repetition_penalty": 1.05, + "know_ids": [], + "stream": True, + "user": "#*#***TestUser***#*#", + "knowledge_config": { + "control_level": "normal", + "knowledge_base_result": False, + "online_search_result": False + } + } + ) ] infer = dict( diff --git a/opencompass/models/sensetime_api.py b/opencompass/models/sensetime_api.py index 0d155855..0d74f854 100644 --- a/opencompass/models/sensetime_api.py +++ b/opencompass/models/sensetime_api.py @@ -1,3 +1,5 @@ +import json +import os import time from concurrent.futures import ThreadPoolExecutor from typing import Dict, List, Optional, Union @@ -30,24 +32,32 @@ class SenseTime(BaseAPIModel): def __init__( self, path: str, - key: str, url: str, + key: str = 'ENV', query_per_second: int = 2, max_seq_len: int = 2048, meta_template: Optional[Dict] = None, retry: int = 2, + parameters: Optional[Dict] = None, ): super().__init__(path=path, max_seq_len=max_seq_len, query_per_second=query_per_second, meta_template=meta_template, retry=retry) + + if isinstance(key, str): + self.keys = os.getenv('SENSENOVA_API_KEY') if key == 'ENV' else key + else: + self.keys = key + self.headers = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer {key}' + 'Authorization': f'Bearer {self.keys}' } self.url = url self.model = path + self.params = parameters def generate( self, @@ -104,38 +114,85 @@ class SenseTime(BaseAPIModel): messages.append(msg) data = {'messages': messages, 'model': self.model} + data.update(self.params) + + stream = data['stream'] max_num_retries = 0 while max_num_retries < self.retry: self.acquire() + + max_num_retries += 1 raw_response = requests.request('POST', url=self.url, headers=self.headers, json=data) - response = raw_response.json() + requests_id = raw_response.headers['X-Request-Id'] # noqa self.release() - if response is None: - print('Connection error, reconnect.') - # if connect error, frequent requests will casuse - # continuous unstable network, therefore wait here - # to slow down the request - self.wait() - continue - if raw_response.status_code == 200: - msg = response['data']['choices'][0]['message'] - return msg + if not stream: + response = raw_response.json() + + if response is None: + print('Connection error, reconnect.') + # if connect error, frequent requests will casuse + # continuous unstable network, therefore wait here + # to slow down the request + self.wait() + continue + if raw_response.status_code == 200: + msg = response['data']['choices'][0]['message'] + return msg + + if (raw_response.status_code != 200): + if response['error']['code'] == 18: + # security issue + return 'error:unsafe' + if response['error']['code'] == 17: + return 'error:too long' + else: + print(raw_response.text) + time.sleep(1) + continue + else: + # stream data to msg + raw_response.encoding = 'utf-8' + + if raw_response.status_code == 200: + response_text = raw_response.text + data_blocks = response_text.split('data:') + data_blocks = data_blocks[1:] + + first_block = json.loads(data_blocks[0]) + if first_block['status']['code'] != 0: + msg = f"error:{first_block['status']['code']}," + f" {first_block['status']['message']}" + self.logger.error(msg) + return msg + + msg = '' + for i, part in enumerate(data_blocks): + # print(f'process {i}: {part}') + try: + if part.startswith('[DONE]'): + break + + json_data = json.loads(part) + choices = json_data['data']['choices'] + for c in choices: + delta = c.get('delta') + msg += delta + except json.decoder.JSONDecodeError as err: + print(err) + self.logger.error(f'Error decoding JSON: {part}') + return msg - if (raw_response.status_code != 200): - if response['error']['code'] == 18: - # security issue - return 'error:unsafe' else: - print(raw_response.text) + print(raw_response.text, + raw_response.headers.get('X-Request-Id')) time.sleep(1) continue - print(response) - max_num_retries += 1 - - raise RuntimeError(raw_response.text) + raise RuntimeError( + f'request id: ' + f'{raw_response.headers.get("X-Request-Id")}, {raw_response.text}')