88 lines
2.8 KiB
Python
88 lines
2.8 KiB
Python
import sys
|
|
|
|
sys.path.append("..")
|
|
from utils.utils import get_logger
|
|
import time
|
|
import requests
|
|
import json
|
|
|
|
# Module-level logger used by the request functions below; obtained from the
# project's get_logger helper under the name "hairuo_general_vllm".
logger = get_logger("hairuo_general_vllm")
|
|
|
|
|
|
def ask_question(query_url, messages, stream):
    """Send a chat request to the model endpoint, streaming or not.

    Args:
        query_url: URL the chat request is POSTed to.
        messages: Conversation messages forwarded to the selected backend,
            which prepends its own system prompt.
        stream: When truthy, delegate to ``vllm_chat_stream`` (a generator
            of cumulative answer text); otherwise delegate to
            ``vllm_chat_non_flow``.

    Returns:
        Whatever the selected backend function returns.
    """
    handler = vllm_chat_stream if stream else vllm_chat_non_flow
    return handler(query_url, messages)
|
|
|
|
|
|
def vllm_chat_non_flow(query_url, messages):
    """Send a non-streaming chat-completion request and return the reply.

    A fixed system prompt ("You are a helpful assistant.") is prepended to
    ``messages``. The request is POSTed as JSON with a 60 second timeout.

    Args:
        query_url: URL of the chat-completions endpoint.
        messages: List of message dicts appended after the system prompt
            (presumably ``{"role": ..., "content": ...}`` like the system
            entry; verify against callers).

    Returns:
        On HTTP 200, the ``content`` of the first choice's message.
        On any other status, the decoded JSON body of the response.
        If anything fails (network error, timeout, non-JSON body, unexpected
        response shape), the string "Model running abnormally!".
    """
    try:
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump if the wall clock is adjusted.
        started = time.perf_counter()
        post_json = {
            "model": "general_model",
            "stream": False,
            "stop": "<|im_end|>",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant.",
                }
            ] + messages,
        }
        headers = {'Content-Type': 'application/json'}
        response = requests.post(query_url, headers=headers, json=post_json, timeout=60)
        elapsed_ms = (time.perf_counter() - started) * 1000
        # Log message: "program run time: <n> ms".
        logger.info('程序运行时间:%s毫秒', elapsed_ms)

        # Decode the body exactly once; a non-JSON body raises here and is
        # reported by the handler below.
        data = response.json()
        logger.info('Ask result:%s', data)

        if response.status_code != 200:
            # Callers receive the server's error payload unchanged.
            return data
        return data['choices'][0]['message']['content']
    except Exception as e:
        # Top-level boundary: record the traceback and return a sentinel string.
        logger.exception("Ask occur an error: %s", e)
        return "Model running abnormally!"
|
|
|
|
|
|
def _decode_vllm_sse_line(raw_line):
    """Decode one non-empty raw line of the streaming response into JSON.

    Only a leading ``data:`` SSE field name is removed. Text inside the
    payload is never touched.

    Returns:
        The parsed JSON value, or None when the line is not valid JSON
        (e.g. a ``data: [DONE]`` terminator).
    """
    text = raw_line.decode('utf-8')
    if text.startswith("data:"):
        text = text[len("data:"):].lstrip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Log message: "----- last message -----"; non-JSON lines are
        # expected at the end of the stream.
        logger.info("-----最后一条------------------------------------")
        return None


def vllm_chat_stream(query_url, messages):
    """Stream a chat completion, yielding the answer text as it grows.

    A fixed system prompt ("You are a helpful assistant.") is prepended to
    ``messages``. Each JSON chunk is read as ``choices[0].delta.content`` /
    ``choices[0].finish_reason`` (presumably OpenAI-compatible
    chat.completion.chunk objects; verify against the server).

    Args:
        query_url: URL of the chat-completions endpoint.
        messages: List of message dicts appended after the system prompt.

    Yields:
        The cumulative answer (not just the new delta) after every chunk
        whose finish_reason is not "stop", and after the "stop" chunk too
        when it carries content.

    Returns:
        The complete answer as the generator's return value (visible via
        ``yield from``). Malformed chunks are logged and skipped. A non-200
        response is logged and produces no output.
    """
    req_body = {
        "model": 'general_model',
        "stream": True,
        "stop": "<|im_end|>",
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant.",
            }
        ] + messages,
    }
    headers = {'Content-Type': 'application/json;charset=utf-8'}
    answer = ""
    # The context manager releases the connection even if the consumer stops
    # iterating early. The timeout bounds connect and per-read waits so a
    # stalled server cannot block forever.
    with requests.post(query_url, json=req_body, headers=headers,
                       stream=True, timeout=60) as response:
        if response.status_code != 200:
            logger.error("流式输出error: HTTP %s %s",
                         response.status_code, response.text)
            return answer
        for raw_line in response.iter_lines():
            if not raw_line:
                continue  # SSE events are separated by blank lines
            event = _decode_vllm_sse_line(raw_line)
            if not isinstance(event, dict):
                continue
            try:
                choice = event['choices'][0]
                # Content may be absent or null (e.g. role-only first chunk).
                content = choice['delta'].get('content') or ''
                finish_reason = choice.get('finish_reason')
                answer += content
            except (KeyError, IndexError, TypeError, AttributeError) as e:
                # Log message: "streaming output error: <e>"; best-effort skip.
                logger.info("流式输出error: %s", e)
                continue
            if finish_reason != "stop" or content:
                yield answer
    return answer
|