scene-digit-human/common/use_model.py
2024-12-09 08:51:09 +08:00

88 lines
2.8 KiB
Python

import sys
sys.path.append("..")
from utils.utils import get_logger
import time
import requests
import json
logger = get_logger("hairuo_general_vllm")
def ask_question(query_url, messages, stream):
    """Send ``messages`` to the chat endpoint at ``query_url``.

    Dispatches to the streaming or the non-streaming client depending on
    ``stream`` and returns whatever that client returns: the result of
    ``vllm_chat_stream`` when ``stream`` is truthy, otherwise the result of
    ``vllm_chat_non_flow``.
    """
    chat = vllm_chat_stream if stream else vllm_chat_non_flow
    return chat(query_url, messages)
def vllm_chat_non_flow(query_url, messages):
    """Run one blocking (non-streaming) chat completion request.

    A fixed system prompt is prepended to ``messages`` and the conversation
    is POSTed as JSON to ``query_url`` with a 60 second timeout.

    Args:
        query_url: URL of the chat-completions endpoint.
        messages: List of chat message dicts; it is not modified.

    Returns:
        * On HTTP 200: the ``content`` string of the first choice's message.
        * On any other status: the decoded JSON body of the response.
        * If anything raises (network error, timeout, non-JSON body,
          unexpected payload shape): the string
          ``"Model running abnormally!"``.
    """
    try:
        # perf_counter is monotonic, so the measured interval cannot be
        # skewed by wall-clock adjustments.
        started = time.perf_counter()
        inter_q = [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            }
        ] + messages
        post_json = {
            "model": "general_model",
            "stream": False,
            "stop": "<|im_end|>",
            "messages": inter_q
        }
        headers = {
            'Content-Type': 'application/json'
        }
        response = requests.post(query_url, headers=headers, json=post_json, timeout=60)
        elapsed_ms = (time.perf_counter() - started) * 1000
        logger.info('程序运行时间:%s毫秒', elapsed_ms)

        # Decode the body exactly once and reuse it for logging and for the
        # return value instead of re-parsing it on every access.
        body = response.json()
        logger.info('Ask result:%s', body)

        if response.status_code != 200:
            return body
        return body['choices'][0]['message']['content']
    except Exception as e:
        # Boundary handler: callers get a fixed fallback string, so keep the
        # traceback in the log to make the failure diagnosable.
        logger.exception("Ask occur an error: %s", e)
        return "Model running abnormally!"
def vllm_chat_stream(query_url, messages):
    """Stream a chat completion and yield the answer accumulated so far.

    A fixed system prompt is prepended to ``messages`` and the conversation
    is POSTed as JSON to ``query_url`` with ``stream`` enabled. The response
    is consumed line by line; each line is expected to be a ``data: ``
    prefixed JSON chunk carrying ``choices[0].delta.content`` and
    ``choices[0].finish_reason``.

    Args:
        query_url: URL of the chat-completions endpoint.
        messages: List of chat message dicts; it is not modified.

    Yields:
        The cumulative answer text (not just the newest fragment) each time
        a chunk contributes new content and its ``finish_reason`` is not
        ``"stop"``.

    Returns:
        The complete answer text, available as ``StopIteration.value`` once
        the generator is exhausted.

    Raises:
        Whatever ``requests.post`` / ``iter_lines`` raise (connection errors,
        timeouts); these surface when the generator is iterated. Lines that
        are not valid JSON or chunks with an unexpected shape are logged and
        skipped instead of raising.
    """
    inter_q = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        }
    ] + messages
    req_body = {
        "model": 'general_model',
        "stream": True,
        "stop": "<|im_end|>",
        "messages": inter_q
    }
    headers = {'Content-Type': 'application/json;charset=utf-8'}
    sse_prefix = "data: "
    answer = ""

    # The timeout mirrors the non-streaming client so a stalled server cannot
    # block the caller forever; the context manager releases the connection
    # even if the consumer stops iterating early.
    with requests.post(query_url, json=req_body, headers=headers,
                       stream=True, timeout=60) as response:
        for raw_line in response.iter_lines():
            if not raw_line:
                continue

            line = raw_line.decode('utf-8')
            # Strip only the leading event prefix; a blanket replace would
            # also delete "data: " occurring inside the generated text.
            if line.startswith(sse_prefix):
                line = line[len(sse_prefix):]

            try:
                payload = json.loads(line)
            except ValueError:
                # Non-JSON line, e.g. the terminating sentinel of the stream.
                logger.info("-----最后一条------------------------------------")
                continue

            if isinstance(payload, str):
                continue

            try:
                choice = payload['choices'][0]
                fragment = choice['delta'].get('content')
                if not fragment:
                    # Missing, null or empty content: nothing new to emit.
                    continue
                answer += fragment
                is_final = choice["finish_reason"] == "stop"
            except (KeyError, IndexError, TypeError, AttributeError) as e:
                logger.info("流式输出error: %s", e)
                continue

            # The final chunk is not yielded; its text is only part of the
            # generator's return value.
            if not is_final:
                yield answer
    return answer