Compare commits


4 Commits
main ... test

SHA1 Message Date
8520a8ff96 update digithuman 2024-12-09 07:31:03 +00:00
8ff3069eac add audio-driven service 2024-12-09 11:49:46 +08:00
0c2fb4793c digital human service commit 2024-12-09 08:51:09 +08:00
5c82076d4f test example 2024-12-06 15:33:13 +08:00
24 changed files with 827 additions and 0 deletions


@@ -0,0 +1,29 @@
#!/bin/bash
# name of the logs directory
LOG_DIR="logs"
cd /data/redserver/red-agent-service/scene-digit-human
export HAIRUO_ENV=prod
# check whether the logs directory exists
if [ ! -d "$LOG_DIR" ]; then
    # create the logs directory if it does not exist
    echo "Creating logs directory..."
    mkdir "$LOG_DIR"
fi
# name of the conda environment
conda_env_name="agent-common"
echo "starting digithuman_api..."
# use command substitution and an if statement to check whether the conda environment exists
if conda env list | grep -q "$conda_env_name"; then
    # if it exists, kill any running instance first
    ps -ef | grep digithuman_api | grep -v grep | awk '{print $2}'| xargs kill -9 2>/dev/null || true
    python /data/config-manager/generate_service_configs.py --service_config_info_path configs/config-vars.yml --config_path configs/cfg.yml
    conda run -n "$conda_env_name" gunicorn main:app -n digithuman_api -c digit_human.conf.py --daemon
    echo "Conda environment name is set to: $conda_env_name"
    echo "Gunicorn started in the background in $conda_env_name environment."
else
    # otherwise, report that the environment was not found
    echo "cannot find conda env: $conda_env_name"
fi

common/__init__.py Normal file

@@ -0,0 +1,26 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 INSPUR Inc. (inspur.com)
#
# @author: J.G. Chen <chenjianguo@inspur.com>
# @date: 2024/06/14
#
import os
from enum import Enum
__version__ = "v1.4.2"
class HairuoEnv(str, Enum):
UNK = "unk"
TEST = "test"
DEV = "dev"
PROD = "prod"
# by default, every machine is expected to set the "HAIRUO_ENV" env var: test/dev/prod
HAIRUO_ENV = HairuoEnv(os.getenv("HAIRUO_ENV", "unk"))
assert HAIRUO_ENV != HairuoEnv.UNK, "env var `HAIRUO_ENV` is required."
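
Note: a minimal sketch of how this module gates startup (the launch scripts in this diff export HAIRUO_ENV; the default below exists only for the demo):

import os
os.environ.setdefault("HAIRUO_ENV", "test")  # demo only; real deployments export this
from common import HAIRUO_ENV, HairuoEnv
assert HAIRUO_ENV in (HairuoEnv.TEST, HairuoEnv.DEV, HairuoEnv.PROD)
print(f"running with env: {HAIRUO_ENV.value}")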

common/config.py Normal file

@@ -0,0 +1,218 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 INSPUR Inc. (inspur.com)
#
# @author: J.G. Chen <chenjianguo@inspur.com>
# @date: 2024/02/17
#
"""
对配置文件的一些操作
Notes
- pyyaml:
* 对科学计数法的支持有特殊要求对于 a[eE][+-]ba 必须有小数点指数必须含正负号
`pyyaml issues#173 <https://github.com/yaml/pyyaml/issues/173#issuecomment-507918276>`_;
* 或者使用 ruamel.yaml 代替
- conf(hocon):
* 变量替换不能添加引号`{ b = "hello"\n a = ${b} world }`;
* 环境变量'{ a = ${HOME} }';
"""
import json
import os
from pathlib import Path
from platform import python_version
from pprint import pformat
from typing import Dict
from typing import Union
import yaml
from omegaconf import DictConfig
from omegaconf import ListConfig
from omegaconf import OmegaConf
from packaging import version
from pyhocon import ConfigFactory
from pyhocon import HOCONConverter
from common import HAIRUO_ENV
from common import HairuoEnv
try:
import torch
except ImportError:
torch = None
try:
from common.utils.logger import init_logger
logger = init_logger(__name__).info
except ImportError:
init_logger = None
logger = print
class ConfigBase:
config: Union[DictConfig, ListConfig]
def __init__(self, config_file: str = ""):
"""load config from `config_file`
Args:
config_file(str): config_file name, could be `.yaml`, `.conf`, `.json`, `.bin`
Notes:
            config_file is resolved relative to the current working directory (cwd); pass an absolute path otherwise
"""
self._file_path = self._resolve(config_file)
self.config = self.load(self._file_path)
@classmethod
def _resolve(cls, config_file: str = ""):
cwd = Path.cwd()
valid_extension = (".yaml", ".yml", ".conf", ".json", ".bin")
# check config file in current work dir
if not config_file:
files = filter(lambda f: f.is_file(), cwd.iterdir())
logger(f"loading config file from dir: {cwd.resolve()}")
for filename in sorted(files, key=lambda x: os.path.getmtime(x), reverse=True):
if filename.suffix in valid_extension:
config_file = Path(cwd, filename)
break
else:
# parse config file specified.
config_file = Path(config_file)
# check config file
if not (config_file and config_file.is_file() and config_file.suffix in valid_extension):
            raise FileNotFoundError(f"Config '{config_file}' not found!")
logger(f"loading config from {config_file.resolve()}")
return config_file
@classmethod
def show(cls, config: Union[Dict, DictConfig, ListConfig] = None):
"""show config given or parsed self.config."""
kwargs = {} if version.parse(python_version()) < version.parse("3.8") else {"sort_dicts": False}
if not config:
config = cls.config
if isinstance(config, (DictConfig, ListConfig)):
config = OmegaConf.to_object(config)
logger(f"======= resolved config ======>\n{pformat(config, **kwargs).encode('utf-8')}")
@classmethod
def load(cls, path: Union[str, Path]) -> Union[DictConfig, ListConfig]:
"""load config file."""
path = Path(path)
logger(f"loading config from: {path}")
if path.suffix in [".yaml", ".yml"]:
with path.open(encoding="utf-8") as f:
config = OmegaConf.create(yaml.load(f, Loader=yaml.FullLoader), flags={"allow_objects": True})
elif path.suffix == ".conf":
config = ConfigFactory.parse_file(path.as_posix())
config = OmegaConf.create(HOCONConverter.to_json(config))
elif path.suffix == ".json":
with path.open() as f:
config = OmegaConf.create(json.load(f))
elif path.suffix == ".bin":
if torch is not None:
config = OmegaConf.create(torch.load(path))
else:
raise RuntimeError("`torch` required to load .bin config file.")
else:
raise RuntimeError("unsupported file format to load.")
logger(f"parse configs for ENV: {HAIRUO_ENV}")
# keep configs for given env only
# TODO: update fields recursively
config.update(config.get(HAIRUO_ENV, dict()))
for env in HairuoEnv:
config.pop(env, default=None)
ConfigBase.show(config)
return config
@classmethod
def save_config(cls, config: Union[dict, DictConfig, ListConfig], path: Union[str, Path]):
"""save `config` to given `path`"""
path = Path(path)
def convert_to_ct(_config):
# convert dict to ConfigTree recursively.
if isinstance(_config, (dict, DictConfig)):
tmp = {}
for k, v in _config.items():
tmp[k] = convert_to_ct(v)
return ConfigFactory.from_dict(tmp)
elif isinstance(_config, ListConfig):
return [convert_to_ct(i) for i in _config]
else:
return _config
if path.suffix in [".yaml", ".yml"]:
OmegaConf.save(config, path)
elif path.suffix == ".conf":
with open(path, "w") as writer:
writer.write(HOCONConverter.to_hocon(convert_to_ct(config)))
writer.write("\n")
elif path.suffix == ".json":
with open(path, "w") as writer:
json.dump(config, writer, ensure_ascii=False, indent=2)
elif path.suffix == ".bin":
if torch is not None:
torch.save(config, path)
else:
raise RuntimeError("`torch` required to load .bin config file.")
else:
raise RuntimeError("unsupported file format.")
logger(f"saving config to: {path}")
def save(self, path: Union[str, Path] = "", key: str = ""):
"""save resolved config obj to `path` or path provided by `key` in config file.
Args:
            path: target file or directory to save to.
key: the `path` is resolved from `key` in resolved config.
"""
if path:
pass
elif key and self.config.get(key, ""):
path = self.config.get(key)
else:
raise RuntimeError("target dir not found")
if Path(path).is_dir():
path = Path(path) / self._file_path.name
else:
path = Path(path)
self.save_config(self.config, path)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="tool to convert config file format",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
args = parser.add_argument
args("path", metavar="config_file", help="path to config file.")
args("-f", "--format", type=str, default="conf", choices=["conf", "yaml", "json"], help="save conf in format.")
args("-o", "--output", type=str, help="output config to file.")
args("-s", "--show", action="store_true", help="just show the resolved config.")
params = parser.parse_args()
conf = ConfigBase.load(params.path)
output = params.output
if not output:
output = Path(params.path).with_suffix(f".{params.format.strip('.')}")
if params.show:
ConfigBase.show(conf)
else:
ConfigBase.save_config(conf, output)
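
Note: a standalone sketch of the per-env flattening that `load()` performs, using plain dicts for illustration:

# mirrors `config.update(config.get(HAIRUO_ENV, dict()))` and the per-env pops above
config = {
    "model_name": "HaiRuo-AudioVisual-ST",
    "test": {"dh_webui": "http://test-host"},
    "prod": {"dh_webui": "http://prod-host"},
}
env = "prod"  # stands in for HAIRUO_ENV
config.update(config.get(env, {}))
for e in ("unk", "test", "dev", "prod"):
    config.pop(e, None)
print(config)  # {'model_name': 'HaiRuo-AudioVisual-ST', 'dh_webui': 'http://prod-host'}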

common/security_check.py Normal file

@@ -0,0 +1,40 @@
import requests


def security_check(SECURITY_URL: str, auth_code: str, question: str, isRejection: bool, isRefusal: bool):
    '''
    SECURITY_URL: rejection/refusal detection endpoint
    auth_code: auth token
    question: the string to be checked
    isRejection: whether to run rejection detection
    isRefusal: whether to run refusal detection
    '''
    headers = {
        'content-type': "application/json",
        'authorization': auth_code,
    }
    security_json = {
        "query": question,
        "isRejection": isRejection,
        "isRefusal": isRefusal,
    }
    try:
        # timeout so a hung security service cannot block the caller indefinitely
        security_res = requests.post(SECURITY_URL, json=security_json, headers=headers, timeout=60)
        # e.g. {'code': 2, 'message': '拒识检测未通过!', 'result': False}
        security_res_json = security_res.json()
        return {
            "result": bool(security_res_json['result']),
            "msg": security_res_json["message"]
        }
    except Exception as e:
        return {
            "result": False,
            "msg": f"error: {e}"
        }
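
A hypothetical call for illustration (the endpoint and token below are placeholders, not values from this repo):

res = security_check(
    "http://127.0.0.1:8000/security",  # placeholder endpoint
    auth_code="<token>",               # placeholder token
    question="你好",
    isRejection=True,
    isRefusal=False,
)
print(res)  # {'result': True/False, 'msg': '...'}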

common/use_model.py Normal file

@@ -0,0 +1,87 @@
import sys
sys.path.append("..")
from utils.utils import get_logger
import time
import requests
import json
logger = get_logger("hairuo_general_vllm")
def ask_question(query_url, messages, stream):
if stream:
return vllm_chat_stream(query_url, messages)
else:
return vllm_chat_non_flow(query_url, messages)
def vllm_chat_non_flow(query_url, messages):
try:
T1 = time.time()
inter_q = [
{
"role": "system",
"content": "You are a helpful assistant."
}
] + messages
post_json = {
"model": "general_model",
"stream": False,
"stop": "<|im_end|>",
"messages": inter_q
}
headers = {
'Content-Type': 'application/json'
}
response = requests.post(query_url, headers=headers, json=post_json, timeout=60)
T2 = time.time()
        logger.info('elapsed time: %s ms' % ((T2 - T1) * 1000))
logger.info(f'Ask result:{response.json()}')
if response.status_code == 200:
data_ok = response.json()
first_choice = data_ok['choices'][0]
message_content = first_choice['message']['content']
return message_content
else:
return response.json()
    except Exception as e:
        logger.error(f"Ask raised an error: {e}")
        return "Model running abnormally!"
def vllm_chat_stream(query_url, messages):
inter_q = [
{
"role": "system",
"content": "You are a helpful assistant."
}
] + messages
req_body = {
"model": 'general_model',
"stream": True,
"stop": "<|im_end|>",
"messages": inter_q
}
headers = {'Content-Type': 'application/json;charset=utf-8'}
response = requests.post(query_url, json=req_body, headers=headers, stream=True)
answer = ""
    for chunk in response.iter_lines():
        if chunk:
            chunk = chunk.decode('utf-8')
            chunk = chunk.replace("data: ", "")
            try:
                try:
                    chunk = json.loads(chunk)
                except json.JSONDecodeError:
                    # the final SSE chunk (e.g. "[DONE]") is not JSON; it stays a str and is skipped below
                    logger.info("received non-JSON (final) stream chunk")
                if not isinstance(chunk, str):
                    if chunk['choices'][0]['delta'].get('content', '') != '':
                        answer = answer + chunk['choices'][0]['delta']['content']
                    if chunk['choices'][0]["finish_reason"] != "stop":
                        yield answer
            except Exception as e:
                logger.error("stream output error: %s", e)
    return answer
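
A hypothetical caller for illustration (the vLLM-compatible endpoint below is a placeholder):

messages = [{"role": "user", "content": "你好"}]
url = "http://127.0.0.1:8000/v1/chat/completions"  # placeholder endpoint
print(ask_question(url, messages, stream=False))   # full answer in one response
for partial in ask_question(url, messages, stream=True):
    print(partial)                                 # cumulative answer so far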

configs/cfg.yml Normal file

@@ -0,0 +1,10 @@
# digital human
dighthuman:
dev:
dh_webui: http://100.200.128.72:14041/agentstore/api/v1/multimodal_models/dh/dighthuman
test:
dh_webui: http://100.200.128.72:14041/agentstore/api/v1/multimodal_models/dh/dighthuman
prod:
dh_webui: http://100.200.128.72:14041/agentstore/api/v1/multimodal_models/dh/dighthuman
model_name: HaiRuo-AudioVisual-ST
static_token: 7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c


@@ -0,0 +1,9 @@
dighthuman:
dev:
dh_webui: {{ red_server.dighthuman.dev.dh_webui }}
test:
dh_webui: {{ red_server.dighthuman.test.dh_webui }}
prod:
dh_webui: {{ red_server.dighthuman.prod.dh_webui }}
model_name: {{ red_server.dighthuman.model_name }}
static_token: {{ global.authorization.static_token }}


@@ -0,0 +1,10 @@
red_server:
dighthuman:
dev:
dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
test:
dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
prod:
dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
model_name: HaiRuo-AudioVisual-ST
static_token: 7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c


@@ -0,0 +1,10 @@
red_server:
dighthuman:
dev:
dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
test:
dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
prod:
dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
model_name: HaiRuo-AudioVisual-ST
static_token: 7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c


@@ -0,0 +1,10 @@
red_server:
dighthuman:
dev:
dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
test:
dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
prod:
dh_webui: http://127.0.0.1:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
model_name: HaiRuo-AudioVisual-ST
static_token: 7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c


@@ -0,0 +1,10 @@
red_server:
dighthuman:
dev:
dh_webui: http://100.200.128.83:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
test:
dh_webui: http://100.200.128.83:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
prod:
dh_webui: http://100.200.128.83:14040/agentstore/api/v1/multimodal_models/dh/dighthuman
model_name: HaiRuo-AudioVisual-ST
static_token: 7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c

digit_human.conf.py Normal file

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
import os
path_of_current_file = os.path.abspath(__file__)
path_of_current_dir = os.path.split(path_of_current_file)[0]
# worker_class 'sync' raises errors for this app
worker_class = 'uvicorn.workers.UvicornWorker'
# workers = multiprocessing.cpu_count() * 2 + 1
workers = 1  # number of worker processes
threads = 1  # threads per worker
chdir = path_of_current_dir
worker_connections = 0
timeout = 0  # 0 disables the worker timeout
max_requests = 0  # 0 disables worker recycling
graceful_timeout = 0
loglevel = 'info'
access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s "%(f)s" "%(a)s"'
reload = True
debug = False
bind = "%s:%s" % ("0.0.0.0", 14040)
pidfile = '%s/digithuman.pid' % (path_of_current_dir)
errorlog = '%s/logs/digithuman.log' % (path_of_current_dir)
accesslog = '%s/logs/digithuman_access.log' % (path_of_current_dir)
proc_name = "digithuman_api"

main.py Normal file

@@ -0,0 +1,94 @@
# -*- encoding: utf-8 -*-
'''
@Email : liaoxiju@inspur.com
'''
import edge_tts
import os
import re
import sys
sys.path.append("../")
sys.path.append("../..")
import datetime
from datetime import timedelta
import logging
import yaml
import uvicorn
from common import HAIRUO_ENV
from common import HairuoEnv
from fastapi import FastAPI, Request
from agent_common_utils.logger import get_logger
from agent_common_utils.function_monitor import log_function_call
from model_utils import audio_driven_video
import tempfile
logger = get_logger("digithuman")
dh_audio_driven_video = log_function_call(logger)(audio_driven_video)
#dh_audio_driven_video = audio_driven_video
app = FastAPI()
def upload_to_s3(filepath):
os.system(f's3cmd put --expiry-days 10 {filepath} -r s3://ihp/tmp/liao/service/edge-tts/result/')
def text_to_wav(text):
    save_path = "result"
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)
    if text is None or not isinstance(text, str) or len(text) == 0:
        raise ValueError("input text must be a non-empty string")
    with tempfile.TemporaryDirectory(dir=save_path) as tmp_dir:
        output_stream = edge_tts.Communicate(text.strip(), "zh-CN-XiaoyiNeural")
        wav_save_path = f"{tmp_dir}/text2audio.wav"
        output_stream.save_sync(wav_save_path)
        # upload the temp dir to OSS and return the public url of the wav inside it
        upload_to_s3(tmp_dir)
        return f"https://ihp.oss.cn-north-4.inspurcloudoss.com/tmp/liao/service/edge-tts/{wav_save_path}"
@app.post("/hairuo/digithuman")
async def digithuman(request: Request, req: dict):
'''
    Generate wav audio from the input text, then call dh-webui to generate the video.
'''
logger.info("<digithuman> Verification authorization.")
body_data = await request.body()
body_data = body_data.decode("utf-8")
logger.info("<body> {}".format(body_data))
logger.info("<digithuman> input text.")
try:
text = req.get("text")
        # generate audio with edge-tts, upload it to OSS, and get the audio url
wav_url = text_to_wav(text)
code, video_url = dh_audio_driven_video(wav_url)
if code != 0:
logger.info("<digithuman> dh model error")
resp = {
"code": "1",
"message": "dh model error",
"result": ''
}
return resp
resp = {
"code": "0",
"message": "success",
"result": {"video_url":video_url}
}
return resp
    except Exception as e:
        logger.info(f"<digithuman> model error: {e}")
resp = {
"code": "1",
"message": "dh model call error",
"result": ""
}
return resp
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=14040)

model_utils.py Normal file

@@ -0,0 +1,40 @@
# -*- encoding: utf-8 -*-
'''
@File : model_utils.py
@Time : 2024/07/24 10:28:28
@Author : liangzz1991
@Email : zhaoliang03@inspur.com
'''
import json
import time
import yaml
import requests
from common.config import ConfigBase
from common import HAIRUO_ENV
from common import HairuoEnv
def get_cfg(config):
config = config.dighthuman
config.update(config.get(HAIRUO_ENV, dict()))
for env in HairuoEnv:
config.pop(env, default=None)
ConfigBase.show(config)
return config
configs = ConfigBase.load('configs/cfg.yml')
configs = get_cfg(configs)
STATIC_TOKEN = configs['static_token']
DH_URL = configs['dh_webui']
model_name = configs['model_name']
def audio_driven_video(wav_url):
data = {
"audio_url": wav_url
}
    response = requests.post(DH_URL, data=json.dumps(data), timeout=60, headers={'Content-Type': 'application/json', 'Authorization': STATIC_TOKEN})
if response.status_code != 200:
return 1, ""
else:
video_url = response.json()['result']
return 0, video_url
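
For illustration, a hypothetical direct call (the wav url is a placeholder):

code, video_url = audio_driven_video("https://example.com/input.wav")  # placeholder wav url
print("ok" if code == 0 else "dh-webui error", video_url)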


requirements.txt Normal file (empty)

sadtalker-server/nv.jpg Normal file (binary image, 185 KiB, not shown)


@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
import os
path_of_current_file = os.path.abspath(__file__)
path_of_current_dir = os.path.split(path_of_current_file)[0]
# worker_class 'sync' raises errors for this app
worker_class = 'uvicorn.workers.UvicornWorker'
# workers = multiprocessing.cpu_count() * 2 + 1
workers = 1  # number of worker processes
threads = 1  # threads per worker
chdir = path_of_current_dir
worker_connections = 0
timeout = 0  # 0 disables the worker timeout
max_requests = 0  # 0 disables worker recycling
graceful_timeout = 0
loglevel = 'info'
access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s "%(f)s" "%(a)s"'
reload = True
debug = False
bind = "%s:%s" % ("0.0.0.0", 14041)
pidfile = '%s/sadtalker.pid' % (path_of_current_dir)
errorlog = '%s/sadtalker.log' % (path_of_current_dir)
accesslog = '%s/sadtalker_access.log' % (path_of_current_dir)
proc_name = "sadtalker_api"


@@ -0,0 +1,107 @@
# -*- encoding: utf-8 -*-
'''
@Email : liaoxiju@inspur.com
'''
import os
import re
import sys
import uvicorn
from fastapi import FastAPI, Request
from modelscope.pipelines import pipeline
import requests
import tempfile
import torch
import time
if torch.cuda.is_available():
torch.cuda.set_device(0)
app = FastAPI()
model_path = os.getenv('MODEL_PATH','./wwd123/sadtalker')
inference = pipeline('talking-head', model=model_path, model_revision='v1.0.0', device='cuda', use_gpu=True, cache="./")
def upload_to_s3(filepath):
    timestamp = int(time.time())
    os.system(f's3cmd put --expiry-days 10 {filepath} -r s3://ihp/tmp/liao/service/sadtalker/results/{timestamp}.mp4')
    return f"results/{timestamp}.mp4"
def download_wav(wav_url):
    try:
        if not os.path.exists("results"):
            os.makedirs("results", exist_ok=True)
        res = requests.get(wav_url, timeout=60)
        # mkdtemp creates a directory that persists until we remove it ourselves
        tmp_dir = tempfile.mkdtemp(dir="results")
        wav_path = os.path.join(tmp_dir, "input_audio.wav")
        with open(wav_path, "wb") as fout:
            fout.write(res.content)
        return wav_path, tmp_dir
    except Exception:
        import traceback
        print("download error", traceback.format_exc())
        return None, None
#@app.post("/hairuo/audiodrivenvido")
@app.post("/agentstore/api/v1/multimodal_models/dh/dighthuman")
async def digithuman(request: Request, req: dict):
'''
    Fetch the wav audio from the given url, then drive the talking head with sadtalker.
'''
body_data = await request.body()
body_data = body_data.decode("utf-8")
try:
timestamp = int(time.time())
save_path = f"results_{timestamp}"
if not os.path.exists(save_path):
os.system(f"mkdir -p {save_path}")
audio_url = req.get("audio_url")
        wav_path, tmp_dir = download_wav(audio_url)
        if wav_path is None:
            raise RuntimeError("audio download failed")
source_image = 'nv.jpg'
kwargs = {
'preprocess' : 'crop', # 'crop', 'resize', 'full'
'still_mode' : True,
'use_enhancer' : False,
'batch_size' : 1,
'size' : 256, # 256, 512
'pose_style' : 0,
'exp_scale' : 1,
'result_dir': save_path
}
video_path = inference(source_image, driven_audio=wav_path, **kwargs)
video_name = upload_to_s3(video_path)
        # clean up temporary files
os.system(f"rm -rf {tmp_dir}")
os.system(f"rm {video_path}")
os.system(f"rm -rf {save_path}")
video_url = os.path.join("https://ihp.oss.cn-north-4.inspurcloudoss.com/tmp/liao/service/sadtalker/", video_name)
resp = {
"code": "0",
"message": "success",
"result": video_url
}
return resp
except:
import traceback
print(traceback.format_exc())
resp = {
"code": "1",
"message": "audiodrivenvido model call error",
"result": ""
}
return resp
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=14041)


@@ -0,0 +1,25 @@
#!/bin/bash
sconfig_common
# check whether the command exists
if npu-smi >/dev/null 2>&1; then
    echo "npu-smi found, launching with ASCEND_RT_VISIBLE_DEVICES..."
    # run the launch command here
    ASCEND_RT_VISIBLE_DEVICES=1 gunicorn --env MODEL_PATH="./wwd123/sadtalker" sadtalker_server:app -n sadtalker_api -c sadtalker.conf.py --daemon
    exit 0
else
    echo "npu-smi not found, skipping this branch."
fi
# check whether the command exists
if cnmon >/dev/null 2>&1; then
    echo "cnmon found, launching with MLU_VISIBLE_DEVICES..."
    # run the launch command here
    MLU_VISIBLE_DEVICES=1 gunicorn --env MODEL_PATH="./wwd123/sadtalker" sadtalker_server:app -n sadtalker_api -c sadtalker.conf.py --daemon
    exit 0
else
    echo "cnmon not found, skipping this branch."
fi
CUDA_VISIBLE_DEVICES=1 gunicorn --env MODEL_PATH="./wwd123/sadtalker" sadtalker_server:app -n sadtalker_api -c sadtalker.conf.py --daemon


@@ -0,0 +1 @@
ps -ef | grep sadtalker_api | grep -v grep | awk '{print $2}'| xargs kill -9

start.sh Normal file

@@ -0,0 +1,8 @@
SELF_DIR=$(cd $(dirname "$0"); pwd)
PROJECT_ROOT=${SELF_DIR}/..
export PYTHONPATH=$PROJECT_ROOT:$PYTHONPATH
cd $PROJECT_ROOT/scene-digit-human
# python main.py
mkdir -p logs
python /data/config-manager/generate_service_configs.py --service_config_info_path configs/config-vars.yml --config_path configs/cfg.yml
gunicorn main:app -n digithuman_api -c digit_human.conf.py --daemon

stop.sh Normal file

@@ -0,0 +1 @@
ps -ef | grep digithuman_api | grep -v grep | awk '{print $2}'| xargs kill -9

test.py Normal file

@@ -0,0 +1,32 @@
"""
@Email: liaoxiju@inspur.com
"""
from modelscope.pipelines import pipeline
# Create a pipeline instance for talking head generation using the specified model and revision.
inference = pipeline('talking-head', model='./wwd123/sadtalker', model_revision='v1.0.0')
# Define the input source image and audio file paths.
source_image = "liao.jpg"
driven_audio = "xx_cn.wav"
# Set the output directory where results will be saved.
out_dir = "./results/"
# Configure various parameters for the inference process:
kwargs = {
'preprocess': 'full', # Options are 'crop', 'resize', or 'full'
'still_mode': True,
'use_enhancer': False,
'batch_size': 1,
'size': 256, # Image size can be either 256 or 512 pixels
'pose_style': 0,
'exp_scale': 1,
'result_dir': out_dir
}
# Perform inference to generate the video from the source image and audio.
video_path = inference(source_image=source_image, driven_audio=driven_audio, **kwargs)
# Print the path of the generated video file.
print(f"==>> video_path: {video_path}")