Compare commits
4 Commits: 8520a8ff96, 8ff3069eac, 0c2fb4793c, 5c82076d4f
29 auto_deploy_dight_human.sh Normal file
@@ -0,0 +1,29 @@
#!/bin/bash
# Name of the logs directory
LOG_DIR="logs"

cd /data/redserver/red-agent-service/scene-digit-human
export HAIRUO_ENV=prod

# Check whether the logs directory exists
if [ ! -d "$LOG_DIR" ]; then
    # Create it if it does not exist
    echo "Creating logs directory..."
    mkdir "$LOG_DIR"
fi
# Name of the conda environment
conda_env_name="agent-common"
echo "starting upload_api..."

# Use command substitution in an if statement to check whether the conda environment exists
if conda env list | grep -q "$conda_env_name"; then
    # If it exists, kill any running process first
    ps -ef | grep upload_api | grep -v grep | awk '{print $2}' | xargs kill -9 2>/dev/null || true
    python /data/config-manager/generate_service_configs.py --service_config_info_path configs/config-vars.yml --config_path configs/cfg.yml
    conda run -n "$conda_env_name" gunicorn main:app -n digithuman_api -c digit_human.conf.py --daemon
    echo "Conda environment name is set to: $conda_env_name"
    echo "Gunicorn started in the background in $conda_env_name environment."
else
    # If it does not exist, report that it was not found
    echo "can not find: $conda_env_name"
fi
26 common/__init__.py Normal file
@@ -0,0 +1,26 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 INSPUR Inc. (inspur.com)
#
# @author: J.G. Chen <chenjianguo@inspur.com>
# @date: 2024/06/14
#

import os
from enum import Enum

__version__ = "v1.4.2"


class HairuoEnv(str, Enum):
    UNK = "unk"
    TEST = "test"
    DEV = "dev"
    PROD = "prod"


# By default every machine must set the "HAIRUO_ENV" env var: test/dev/prod
HAIRUO_ENV = HairuoEnv(os.getenv("HAIRUO_ENV", "unk"))

assert HAIRUO_ENV != HairuoEnv.UNK, "env var `HAIRUO_ENV` is required."
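For reference, a minimal sketch (not part of the diff) of how this import-time gate behaves; the in-process assignment stands in for `export HAIRUO_ENV=prod` in the deploy scripts:

import os

os.environ["HAIRUO_ENV"] = "prod"  # must be set before `common` is first imported

from common import HAIRUO_ENV, HairuoEnv

assert HAIRUO_ENV is HairuoEnv.PROD  # the assert in common/__init__.py has already passed
print(HAIRUO_ENV.value)  # -> "prod"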
218 common/config.py Normal file
@@ -0,0 +1,218 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 INSPUR Inc. (inspur.com)
#
# @author: J.G. Chen <chenjianguo@inspur.com>
# @date: 2024/02/17
#
"""
Helpers for working with config files.

Notes
    - pyyaml:
        * scientific notation is only recognized in a strict form: for a[eE][+-]b,
          `a` must contain a decimal point and the exponent must carry an explicit sign, see
          `pyyaml issues#173 <https://github.com/yaml/pyyaml/issues/173#issuecomment-507918276>`_;
        * or use ruamel.yaml instead;
    - conf(hocon):
        * variable substitution: do not quote the reference, `{ b = "hello"\n a = ${b} world }`;
        * environment variables: '{ a = ${HOME} }';
"""

import json
import os
from pathlib import Path
from platform import python_version
from pprint import pformat
from typing import Dict
from typing import Union

import yaml
from omegaconf import DictConfig
from omegaconf import ListConfig
from omegaconf import OmegaConf
from packaging import version
from pyhocon import ConfigFactory
from pyhocon import HOCONConverter

from common import HAIRUO_ENV
from common import HairuoEnv

try:
    import torch
except ImportError:
    torch = None

try:
    from common.utils.logger import init_logger

    logger = init_logger(__name__).info
except ImportError:
    init_logger = None
    logger = print


class ConfigBase:
    config: Union[DictConfig, ListConfig]

    def __init__(self, config_file: str = ""):
        """load config from `config_file`

        Args:
            config_file(str): config file name; may be `.yaml`, `.conf`, `.json`, or `.bin`
        Notes:
            `config_file` is resolved relative to the current work dir (cwd); pass an absolute path otherwise
        """
        self._file_path = self._resolve(config_file)
        self.config = self.load(self._file_path)

    @classmethod
    def _resolve(cls, config_file: str = ""):
        cwd = Path.cwd()
        valid_extension = (".yaml", ".yml", ".conf", ".json", ".bin")

        # if no file is given, pick the most recently modified config file in the current work dir
        if not config_file:
            files = filter(lambda f: f.is_file(), cwd.iterdir())
            logger(f"loading config file from dir: {cwd.resolve()}")
            for filename in sorted(files, key=lambda x: os.path.getmtime(x), reverse=True):
                if filename.suffix in valid_extension:
                    config_file = Path(cwd, filename)
                    break
        else:
            # parse the config file specified.
            config_file = Path(config_file)

        # validate the config file
        if not (config_file and config_file.is_file() and config_file.suffix in valid_extension):
            raise FileNotFoundError(f"Config '{config_file}' not found!")

        logger(f"loading config from {config_file.resolve()}")
        return config_file

    @classmethod
    def show(cls, config: Union[Dict, DictConfig, ListConfig] = None):
        """show the given config, or the parsed cls.config."""
        kwargs = {} if version.parse(python_version()) < version.parse("3.8") else {"sort_dicts": False}
        if not config:
            config = cls.config

        if isinstance(config, (DictConfig, ListConfig)):
            config = OmegaConf.to_object(config)

        logger(f"======= resolved config ======>\n{pformat(config, **kwargs).encode('utf-8')}")

    @classmethod
    def load(cls, path: Union[str, Path]) -> Union[DictConfig, ListConfig]:
        """load a config file."""
        path = Path(path)
        logger(f"loading config from: {path}")

        if path.suffix in [".yaml", ".yml"]:
            with path.open(encoding="utf-8") as f:
                config = OmegaConf.create(yaml.load(f, Loader=yaml.FullLoader), flags={"allow_objects": True})
        elif path.suffix == ".conf":
            config = ConfigFactory.parse_file(path.as_posix())
            config = OmegaConf.create(HOCONConverter.to_json(config))
        elif path.suffix == ".json":
            with path.open() as f:
                config = OmegaConf.create(json.load(f))
        elif path.suffix == ".bin":
            if torch is not None:
                config = OmegaConf.create(torch.load(path))
            else:
                raise RuntimeError("`torch` required to load .bin config file.")
        else:
            raise RuntimeError("unsupported file format to load.")

        logger(f"parse configs for ENV: {HAIRUO_ENV}")
        # keep configs for the given env only
        # TODO: update fields recursively
        config.update(config.get(HAIRUO_ENV, dict()))
        for env in HairuoEnv:
            config.pop(env, default=None)

        ConfigBase.show(config)
        return config

    @classmethod
    def save_config(cls, config: Union[dict, DictConfig, ListConfig], path: Union[str, Path]):
        """save `config` to the given `path`"""
        path = Path(path)

        def convert_to_ct(_config):
            # convert dict to ConfigTree recursively.
            if isinstance(_config, (dict, DictConfig)):
                tmp = {}
                for k, v in _config.items():
                    tmp[k] = convert_to_ct(v)
                return ConfigFactory.from_dict(tmp)
            elif isinstance(_config, ListConfig):
                return [convert_to_ct(i) for i in _config]
            else:
                return _config

        if path.suffix in [".yaml", ".yml"]:
            OmegaConf.save(config, path)
        elif path.suffix == ".conf":
            with open(path, "w") as writer:
                writer.write(HOCONConverter.to_hocon(convert_to_ct(config)))
                writer.write("\n")
        elif path.suffix == ".json":
            with open(path, "w") as writer:
                json.dump(config, writer, ensure_ascii=False, indent=2)
        elif path.suffix == ".bin":
            if torch is not None:
                torch.save(config, path)
            else:
                raise RuntimeError("`torch` required to save .bin config file.")
        else:
            raise RuntimeError("unsupported file format.")
        logger(f"saving config to: {path}")

    def save(self, path: Union[str, Path] = "", key: str = ""):
        """save the resolved config obj to `path`, or to the path provided by `key` in the config file.

        Args:
            path: target file or directory to save to.
            key: the `path` is resolved from `key` in the resolved config.
        """

        if path:
            pass
        elif key and self.config.get(key, ""):
            path = self.config.get(key)
        else:
            raise RuntimeError("target dir not found")

        if Path(path).is_dir():
            path = Path(path) / self._file_path.name
        else:
            path = Path(path)

        self.save_config(self.config, path)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="tool to convert config file format",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    args = parser.add_argument
    args("path", metavar="config_file", help="path to config file.")
    args("-f", "--format", type=str, default="conf", choices=["conf", "yaml", "json"], help="save conf in format.")
    args("-o", "--output", type=str, help="output config to file.")
    args("-s", "--show", action="store_true", help="just show the resolved config.")
    params = parser.parse_args()

    conf = ConfigBase.load(params.path)
    output = params.output
    if not output:
        output = Path(params.path).with_suffix(f".{params.format.strip('.')}")

    if params.show:
        ConfigBase.show(conf)
    else:
        ConfigBase.save_config(conf, output)
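For orientation, a minimal usage sketch of `ConfigBase.load` (not part of the diff; `demo.yml` and its contents are hypothetical, and the env-merging step assumes the omegaconf version this service pins): with `HAIRUO_ENV=prod`, the `prod` section is merged into the top level and the other env sections are dropped.

import os

os.environ["HAIRUO_ENV"] = "prod"  # must be set before `common` is imported

from common.config import ConfigBase

# Hypothetical config file with per-env sections.
with open("demo.yml", "w", encoding="utf-8") as f:
    f.write(
        "endpoint: http://default\n"
        "prod:\n"
        "  endpoint: http://prod-host\n"
        "dev:\n"
        "  endpoint: http://dev-host\n"
    )

cfg = ConfigBase.load("demo.yml")
print(cfg.endpoint)   # -> http://prod-host (prod section merged up)
print("dev" in cfg)   # -> False (env sections removed)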
40 common/security_check.py Normal file
@@ -0,0 +1,40 @@
import requests


def security_check(SECURITY_URL: str, auth_code: str, question: str, isRejection: bool, isRefusal: bool):
    '''
    SECURITY_URL: rejection/refusal check url
    auth_code: token
    question: the string to check
    isRejection: whether to run the rejection check
    isRefusal: whether to run the refusal check
    threshold: minimum knowledge-base retrieval threshold
    '''
    headers = {
        'content-type': "application/json",
        'authorization': auth_code,
    }
    security_json = {
        "query": question,
        "isRejection": isRejection,
        "isRefusal": isRefusal,
    }
    try:
        security_res = requests.post(SECURITY_URL, json=security_json, headers=headers)
        # e.g. {'code': 2, 'message': 'rejection check failed!', 'result': False}
        security_res_json = security_res.json()
        if not security_res_json['result']:
            return {
                "result": False,
                "msg": security_res_json["message"]
            }
        elif security_res_json['result']:
            return {
                "result": True,
                "msg": security_res_json["message"]
            }
    except Exception as e:
        return {
            "result": False,
            "msg": f"error: {e}"
        }
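A hedged usage sketch; the endpoint URL and token below are placeholders, not values from this repo:

from common.security_check import security_check

# Placeholder endpoint and token, for illustration only.
check = security_check(
    SECURITY_URL="http://127.0.0.1:9000/security/check",
    auth_code="<token>",
    question="hello",
    isRejection=True,
    isRefusal=True,
)
if not check["result"]:
    print("blocked:", check["msg"])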
87 common/use_model.py Normal file
@@ -0,0 +1,87 @@
import sys

sys.path.append("..")
from utils.utils import get_logger
import time
import requests
import json

logger = get_logger("hairuo_general_vllm")


def ask_question(query_url, messages, stream):
    if stream:
        return vllm_chat_stream(query_url, messages)
    else:
        return vllm_chat_non_flow(query_url, messages)


def vllm_chat_non_flow(query_url, messages):
    try:
        T1 = time.time()
        inter_q = [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            }
        ] + messages
        post_json = {
            "model": "general_model",
            "stream": False,
            "stop": "<|im_end|>",
            "messages": inter_q
        }
        headers = {
            'Content-Type': 'application/json'
        }
        response = requests.post(query_url, headers=headers, json=post_json, timeout=60)
        T2 = time.time()
        logger.info('request took %s ms' % ((T2 - T1) * 1000))
        logger.info(f'Ask result:{response.json()}')
        if response.status_code == 200:
            data_ok = response.json()
            first_choice = data_ok['choices'][0]
            message_content = first_choice['message']['content']
            return message_content
        else:
            return response.json()
    except Exception as e:
        logger.info(f"Ask raised an error: {e}")
        return "Model running abnormally!"


def vllm_chat_stream(query_url, messages):
    inter_q = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        }
    ] + messages
    req_body = {
        "model": 'general_model',
        "stream": True,
        "stop": "<|im_end|>",
        "messages": inter_q
    }

    headers = {'Content-Type': 'application/json;charset=utf-8'}
    response = requests.post(query_url, json=req_body, headers=headers, stream=True)
    answer = ""
    # answer_result = []
    for chunk in response.iter_lines():
        if chunk:
            chunk = chunk.decode('utf-8')
            chunk = chunk.replace("data: ", "")
            try:
                try:
                    chunk = json.loads(chunk)
                except json.JSONDecodeError:
                    # the final "[DONE]" sentinel is not JSON
                    logger.info("----- last chunk -----")
                if not isinstance(chunk, str):
                    if chunk['choices'][0]['delta'].get('content', '') != '':
                        answer = answer + chunk['choices'][0]['delta']['content']
                    if chunk['choices'][0]["finish_reason"] != "stop":
                        yield answer
            except Exception as e:
                logger.info("streaming output error: %s", e)
    return answer
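A short usage sketch, assuming an OpenAI-compatible vLLM endpoint (the URL is a placeholder) and that the sibling `utils` package providing `get_logger` is importable:

from common.use_model import ask_question

# Placeholder endpoint; any OpenAI-compatible /v1/chat/completions server works.
url = "http://127.0.0.1:8000/v1/chat/completions"
messages = [{"role": "user", "content": "Introduce yourself in one sentence."}]

# Non-streaming: returns the complete answer string.
print(ask_question(url, messages, stream=False))

# Streaming: yields progressively longer partial answers.
for partial in ask_question(url, messages, stream=True):
    print(partial)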
10 configs/cfg.yml Normal file
@@ -0,0 +1,10 @@
# digital human
dighthuman:
  dev:
    dh_webui: http://100.200.128.72:14041/agentstore/api/v1/multimodal_models/dh/dighthuman
  test:
    dh_webui: http://100.200.128.72:14041/agentstore/api/v1/multimodal_models/dh/dighthuman
  prod:
    dh_webui: http://100.200.128.72:14041/agentstore/api/v1/multimodal_models/dh/dighthuman
  model_name: HaiRuo-AudioVisual-ST
  static_token: 7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c
9 configs/config-template.j2 Normal file
@@ -0,0 +1,9 @@
dighthuman:
  dev:
    dh_webui: {{ red_server.dighthuman.dev.dh_webui }}
  test:
    dh_webui: {{ red_server.dighthuman.test.dh_webui }}
  prod:
    dh_webui: {{ red_server.dighthuman.prod.dh_webui }}
  model_name: {{ red_server.dighthuman.model_name }}
  static_token: {{ global.authorization.static_token }}
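The config-vars files below feed /data/config-manager/generate_service_configs.py (not part of this diff), which renders the template above into configs/cfg.yml. A hedged sketch of the equivalent rendering step with plain Jinja2; the `global.authorization` fallback is an assumption, since the vars files here only define `red_server`:

import yaml
from jinja2 import Template

# Hypothetical stand-in for generate_service_configs.py: render the template
# with one of the config-vars-*.yml files supplying the template variables.
with open("configs/config-vars-dev.yml", encoding="utf-8") as f:
    variables = yaml.safe_load(f)

# Assumed: the global.authorization section normally comes from the config
# manager's shared variables, not from this repo.
variables.setdefault("global", {"authorization": {"static_token": "<static_token>"}})

with open("configs/config-template.j2", encoding="utf-8") as f:
    rendered = Template(f.read()).render(**variables)

with open("configs/cfg.yml", "w", encoding="utf-8") as f:
    f.write(rendered)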
10 configs/config-vars-dev.yml Normal file
@@ -0,0 +1,10 @@
red_server:
  dighthuman:
    dev:
      dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
    test:
      dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
    prod:
      dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
    model_name: HaiRuo-AudioVisual-ST
    static_token: 7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c
10 configs/config-vars-mindie.yml Normal file
@@ -0,0 +1,10 @@
red_server:
  dighthuman:
    dev:
      dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
    test:
      dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
    prod:
      dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
    model_name: HaiRuo-AudioVisual-ST
    static_token: 7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c
10 configs/config-vars-prod.yml Normal file
@@ -0,0 +1,10 @@
red_server:
  dighthuman:
    dev:
      dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
    test:
      dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
    prod:
      dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
    model_name: HaiRuo-AudioVisual-ST
    static_token: 7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c
10 configs/config-vars-test.yml Normal file
@@ -0,0 +1,10 @@
red_server:
  dighthuman:
    dev:
      dh_webui: http://100.200.128.83:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
    test:
      dh_webui: http://100.200.128.83:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
    prod:
      dh_webui: http://100.200.128.83:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
    model_name: HaiRuo-AudioVisual-ST
    static_token: 7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c
30 digit_human.conf.py Normal file
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-

import os

path_of_current_file = os.path.abspath(__file__)
path_of_current_dir = os.path.split(path_of_current_file)[0]

# worker_class 'sync' raises errors here; use uvicorn's worker instead:
# uvicorn.workers.UvicornWorker
worker_class = 'uvicorn.workers.UvicornWorker'
# workers = multiprocessing.cpu_count() * 2 + 1
workers = 1  # number of worker processes
threads = 1  # threads per worker process

chdir = path_of_current_dir

worker_connections = 0
timeout = 0
max_requests = 0
graceful_timeout = 0

loglevel = 'info'
access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s "%(f)s" "%(a)s"'
reload = True
debug = False
bind = "%s:%s" % ("0.0.0.0", 14040)
pidfile = '%s/digithuman.pid' % (path_of_current_dir)
errorlog = '%s/logs/digithuman.log' % (path_of_current_dir)
accesslog = '%s/logs/digithuman_access.log' % (path_of_current_dir)
proc_name = "digithuman_api"
94 main.py Normal file
@@ -0,0 +1,94 @@
# -*- encoding: utf-8 -*-
'''
@Email : liaoxiju@inspur.com
'''
import edge_tts
import os
import re
import sys
sys.path.append("../")
sys.path.append("../..")
import datetime
from datetime import timedelta
import logging
import yaml
import uvicorn
from common import HAIRUO_ENV
from common import HairuoEnv
from fastapi import FastAPI, Request
from agent_common_utils.logger import get_logger
from agent_common_utils.function_monitor import log_function_call
from model_utils import audio_driven_video
import tempfile

logger = get_logger("digithuman")

dh_audio_driven_video = log_function_call(logger)(audio_driven_video)
#dh_audio_driven_video = audio_driven_video

app = FastAPI()

def upload_to_s3(filepath):
    os.system(f's3cmd put --expiry-days 10 {filepath} -r s3://ihp/tmp/liao/service/edge-tts/result/')

def text_to_wav(text):
    save_path = "result"
    if not os.path.exists(save_path):
        os.system(f"mkdir -p {save_path}")
    if text is None or not isinstance(text, str) or len(text) == 0:
        raise ValueError("input text is invalid")
    with tempfile.TemporaryDirectory(dir=save_path) as tmp_dir:
        output_stream = edge_tts.Communicate(text.strip(), "zh-CN-XiaoyiNeural")
        wav_save_path = f"{tmp_dir}/text2audio.wav"
        output_stream.save_sync(wav_save_path)
        #return save_path, tmp_dir
        upload_to_s3(tmp_dir)
        return f"https://ihp.oss.cn-north-4.inspurcloudoss.com/tmp/liao/service/edge-tts/{wav_save_path}"


@app.post("/hairuo/digithuman")
async def digithuman(request: Request, req: dict):
    '''
    Generate a wav from the input text description, then call dh-webui to produce the video.
    '''

    logger.info("<digithuman> Verification authorization.")
    body_data = await request.body()
    body_data = body_data.decode("utf-8")

    logger.info("<body> {}".format(body_data))
    logger.info("<digithuman> input text.")
    try:
        text = req.get("text")

        ## edge-tts generates the audio, uploads it to OSS, and returns the audio url
        wav_url = text_to_wav(text)

        code, video_url = dh_audio_driven_video(wav_url)
        if code != 0:
            logger.info("<digithuman> dh model error")
            resp = {
                "code": "1",
                "message": "dh model error",
                "result": ''
            }
            return resp
        resp = {
            "code": "0",
            "message": "success",
            "result": {"video_url": video_url}
        }
        return resp

    except:
        logger.info("<digithuman> model error")
        resp = {
            "code": "1",
            "message": "dh model call error",
            "result": ""
        }
        return resp


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=14040)
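For reference, a hedged client-side sketch of calling this endpoint once the service is running (host/port per digit_human.conf.py; the input text is arbitrary):

import requests

# POST a text payload to the digithuman service; on success it returns a JSON
# envelope with code "0" and result.video_url.
resp = requests.post(
    "http://127.0.0.1:14040/hairuo/digithuman",
    json={"text": "你好,欢迎使用数字人服务。"},
    timeout=300,
)
data = resp.json()
if data["code"] == "0":
    print("video:", data["result"]["video_url"])
else:
    print("failed:", data["message"])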
40 model_utils.py Normal file
@@ -0,0 +1,40 @@
# -*- encoding: utf-8 -*-
'''
@File : model_utils.py
@Time : 2024/07/24 10:28:28
@Author : liangzz1991
@Email : zhaoliang03@inspur.com
'''
import json
import time
import yaml
import requests
from common.config import ConfigBase
from common import HAIRUO_ENV
from common import HairuoEnv

def get_cfg(config):
    config = config.dighthuman
    config.update(config.get(HAIRUO_ENV, dict()))
    for env in HairuoEnv:
        config.pop(env, default=None)
    ConfigBase.show(config)
    return config

configs = ConfigBase.load('configs/cfg.yml')
configs = get_cfg(configs)
STATIC_TOKEN = configs['static_token']
DH_URL = configs['dh_webui']
model_name = configs['model_name']

def audio_driven_video(wav_url):
    data = {
        "audio_url": wav_url
    }
    response = requests.post(DH_URL, data=json.dumps(data), timeout=60, headers={'Content-Type': 'application/json', 'Authorization': STATIC_TOKEN})
    if response.status_code != 200:
        return 1, ""
    else:
        video_url = response.json()['result']
        return 0, video_url
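A small usage sketch (run from the repo root so `configs/cfg.yml` resolves; the wav URL is a placeholder):

import os

os.environ.setdefault("HAIRUO_ENV", "dev")  # required by common/__init__.py before import

from model_utils import audio_driven_video

# Placeholder audio URL, for illustration only.
code, video_url = audio_driven_video("https://example.com/sample.wav")
if code == 0:
    print("generated video:", video_url)
else:
    print("dh-webui call failed")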
0 requirements-sadtalker.txt Normal file
0 requirements.txt Normal file
BIN sadtalker-server/nv.jpg Normal file
Binary file not shown. (new image, 185 KiB)
30 sadtalker-server/sadtalker.conf.py Normal file
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-

import os

path_of_current_file = os.path.abspath(__file__)
path_of_current_dir = os.path.split(path_of_current_file)[0]

# worker_class 'sync' raises errors here; use uvicorn's worker instead:
# uvicorn.workers.UvicornWorker
worker_class = 'uvicorn.workers.UvicornWorker'
# workers = multiprocessing.cpu_count() * 2 + 1
workers = 1  # number of worker processes
threads = 1  # threads per worker process

chdir = path_of_current_dir

worker_connections = 0
timeout = 0
max_requests = 0
graceful_timeout = 0

loglevel = 'info'
access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s "%(f)s" "%(a)s"'
reload = True
debug = False
bind = "%s:%s" % ("0.0.0.0", 14041)
pidfile = '%s/sadtalker.pid' % (path_of_current_dir)
errorlog = '%s/sadtalker.log' % (path_of_current_dir)
accesslog = '%s/sadtalker_access.log' % (path_of_current_dir)
proc_name = "sadtalker_api"
107 sadtalker-server/sadtalker_server.py Normal file
@@ -0,0 +1,107 @@
# -*- encoding: utf-8 -*-
'''
@Email : liaoxiju@inspur.com
'''
import os
import re
import sys
import uvicorn
from fastapi import FastAPI, Request
from modelscope.pipelines import pipeline
import requests
import tempfile
import torch
import time

if torch.cuda.is_available():
    torch.cuda.set_device(0)

app = FastAPI()

model_path = os.getenv('MODEL_PATH', './wwd123/sadtalker')

inference = pipeline('talking-head', model=model_path, model_revision='v1.0.0', device='cuda', use_gpu=True, cache="./")

def upload_to_s3(filepath):
    timestamp = int(time.time())
    os.system(f's3cmd put --expiry-days 10 {filepath} -r s3://ihp/tmp/liao/service/sadtalker/results/{timestamp}.mp4')
    return f"results/{timestamp}.mp4"

def download_wav(wav_url):
    try:
        if not os.path.exists("results"):
            os.system("mkdir -p results")
        res = requests.get(wav_url)
        content = res.content
        tmp_dir = tempfile.TemporaryDirectory(dir="results").name
        if not os.path.exists(tmp_dir):
            os.system(f"mkdir -p {tmp_dir}")
        wav_path = os.path.join(tmp_dir, "input_audio.wav")
        fout = open(wav_path, "wb")
        fout.write(content)
        fout.close()
        return wav_path, tmp_dir
    except:
        import traceback
        print("download error", traceback.format_exc())
        return None, None

#@app.post("/hairuo/audiodrivenvido")
@app.post("/agentstore/api/v1/multimodal_models/dh/dighthuman")
async def digithuman(request: Request, req: dict):
    '''
    Fetch the wav audio from the given url, then drive sadtalker with it.
    '''
    body_data = await request.body()
    body_data = body_data.decode("utf-8")
    try:
        timestamp = int(time.time())
        save_path = f"results_{timestamp}"
        if not os.path.exists(save_path):
            os.system(f"mkdir -p {save_path}")

        audio_url = req.get("audio_url")
        wav_path, tmp_dir = download_wav(audio_url)
        source_image = 'nv.jpg'

        kwargs = {
            'preprocess': 'crop',  # 'crop', 'resize', 'full'
            'still_mode': True,
            'use_enhancer': False,
            'batch_size': 1,
            'size': 256,  # 256, 512
            'pose_style': 0,
            'exp_scale': 1,
            'result_dir': save_path
        }

        video_path = inference(source_image, driven_audio=wav_path, **kwargs)

        video_name = upload_to_s3(video_path)

        ## clean up temporary files
        os.system(f"rm -rf {tmp_dir}")
        os.system(f"rm {video_path}")
        os.system(f"rm -rf {save_path}")

        video_url = os.path.join("https://ihp.oss.cn-north-4.inspurcloudoss.com/tmp/liao/service/sadtalker/", video_name)
        resp = {
            "code": "0",
            "message": "success",
            "result": video_url
        }
        return resp

    except:
        import traceback
        print(traceback.format_exc())
        resp = {
            "code": "1",
            "message": "audio driven video model call error",
            "result": ""
        }
        return resp


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=14041)
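And a matching client-side sketch for this endpoint (the audio URL is a placeholder; the Authorization header mirrors what model_utils.py sends, though the handler above does not validate it):

import requests

# Placeholder values: adjust host and audio URL for your deployment.
resp = requests.post(
    "http://127.0.0.1:14041/agentstore/api/v1/multimodal_models/dh/dighthuman",
    json={"audio_url": "https://example.com/input_audio.wav"},
    headers={"Authorization": "<static_token>"},
    timeout=300,
)
print(resp.json())  # {"code": "0", "message": "success", "result": "<video_url>"} on success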
25 sadtalker-server/start_sadtalker.sh Normal file
@@ -0,0 +1,25 @@
#!/bin/bash

sconfig_common

# Check whether the command exists
if npu-smi >/dev/null 2>&1; then
    echo "npu-smi exists; starting via ASCEND_RT_VISIBLE_DEVICES..."
    # run the required steps here
    ASCEND_RT_VISIBLE_DEVICES=1 gunicorn --env MODEL_PATH="./wwd123/sadtalker" sadtalker_server:app -n sadtalker_api -c sadtalker.conf.py --daemon
    exit 0
else
    echo "npu-smi not found; skipping this part of the script."
fi

# Check whether the command exists
if cnmon >/dev/null 2>&1; then
    echo "cnmon exists; starting via MLU_VISIBLE_DEVICES..."
    # run the required steps here
    MLU_VISIBLE_DEVICES=1 gunicorn --env MODEL_PATH="./wwd123/sadtalker" sadtalker_server:app -n sadtalker_api -c sadtalker.conf.py --daemon
    exit 0
else
    echo "cnmon not found; skipping this part of the script."
fi

CUDA_VISIBLE_DEVICES=1 gunicorn --env MODEL_PATH="./wwd123/sadtalker" sadtalker_server:app -n sadtalker_api -c sadtalker.conf.py --daemon
1 sadtalker-server/stop_sadtalker.sh Normal file
@@ -0,0 +1 @@
ps -ef | grep sadtalker_api | grep -v grep | awk '{print $2}' | xargs kill -9
8 start.sh Normal file
@@ -0,0 +1,8 @@
SELF_DIR=$(cd $(dirname "$0"); pwd)
PROJECT_ROOT=${SELF_DIR}/..
export PYTHONPATH=$PROJECT_ROOT:$PYTHONPATH
cd $PROJECT_ROOT/scene-digit-human
# python main.py
mkdir -p logs
python /data/config-manager/generate_service_configs.py --service_config_info_path configs/config-vars.yml --config_path configs/cfg.yml
gunicorn main:app -n digithuman_api -c digit_human.conf.py --daemon
1 stop.sh Normal file
@@ -0,0 +1 @@
ps -ef | grep digithuman_api | grep -v grep | awk '{print $2}' | xargs kill -9
32 test.py Normal file
@@ -0,0 +1,32 @@
"""
@Email: liaoxiju@inspur.com
"""
from modelscope.pipelines import pipeline

# Create a pipeline instance for talking head generation using the specified model and revision.
inference = pipeline('talking-head', model='./wwd123/sadtalker', model_revision='v1.0.0')

# Define the input source image and audio file paths.
source_image = "liao.jpg"
driven_audio = "xx_cn.wav"

# Set the output directory where results will be saved.
out_dir = "./results/"

# Configure various parameters for the inference process:
kwargs = {
    'preprocess': 'full',  # Options are 'crop', 'resize', or 'full'
    'still_mode': True,
    'use_enhancer': False,
    'batch_size': 1,
    'size': 256,  # Image size can be either 256 or 512 pixels
    'pose_style': 0,
    'exp_scale': 1,
    'result_dir': out_dir
}

# Perform inference to generate the video from the source image and audio.
video_path = inference(source_image=source_image, driven_audio=driven_audio, **kwargs)

# Print the path of the generated video file.
print(f"==>> video_path: {video_path}")