scene-digit-human/sadtalker-server/sadtalker_server.py
2024-12-09 07:31:03 +00:00

108 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- encoding: utf-8 -*-
'''
@Email : liaoxiju@inspur.com
'''
import os
import re
import sys
import uvicorn
from fastapi import FastAPI, Request
from modelscope.pipelines import pipeline
import requests
import tempfile
import torch
import time
# Select the first CUDA device for this process when CUDA is usable.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

# ASGI application object that the route decorators below attach to.
app = FastAPI()

# Model location; the MODEL_PATH environment variable overrides the default.
model_path = os.getenv('MODEL_PATH', './wwd123/sadtalker')

# The talking-head pipeline is built once, at import time, and reused by
# every request.
# NOTE(review): device='cuda' is requested even when the availability check
# above is false — confirm how the pipeline behaves on a CPU-only host.
inference = pipeline(
    'talking-head',
    model=model_path,
    model_revision='v1.0.0',
    device='cuda',
    use_gpu=True,
    cache="./",
)
def upload_to_s3(filepath):
    """Upload a local file to the service's S3 prefix with ``s3cmd``.

    The object is stored as
    ``s3://ihp/tmp/liao/service/sadtalker/results/<unix-timestamp>.mp4``
    with a 10-day expiry.

    Args:
        filepath: Local path of the file to upload.

    Returns:
        The object name relative to the service prefix, i.e.
        ``"results/<unix-timestamp>.mp4"``.

    Raises:
        RuntimeError: If ``s3cmd`` cannot be started or exits with a
            non-zero status, so callers never publish a URL for an object
            that was not actually uploaded.
    """
    import subprocess

    timestamp = int(time.time())
    object_name = f"results/{timestamp}.mp4"

    # Pass an argument list (no shell) so that spaces or shell
    # metacharacters in ``filepath`` cannot break or inject into the command.
    command = [
        "s3cmd",
        "put",
        "--expiry-days",
        "10",
        str(filepath),
        "-r",
        f"s3://ihp/tmp/liao/service/sadtalker/{object_name}",
    ]
    try:
        subprocess.run(command, check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        raise RuntimeError(f"s3cmd upload failed for {filepath!r}") from err
    return object_name
def download_wav(wav_url):
    """Download an audio file into a fresh temporary directory.

    The file is saved as ``input_audio.wav`` inside a newly created
    directory under ``results/`` (relative to the current working directory).

    Args:
        wav_url: URL of the audio file to fetch.

    Returns:
        A ``(wav_path, tmp_dir)`` tuple on success, where ``tmp_dir`` is the
        directory the caller is responsible for removing. ``None`` when the
        URL is empty or the download/write fails (the error is printed).
    """
    import shutil
    import traceback

    if not wav_url:
        print("download error", "empty audio url")
        return None

    tmp_dir = None
    try:
        # NOTE(review): the URL comes from the request payload and is fetched
        # as-is; consider restricting allowed hosts/schemes.
        # Bound the connect/read time so a stalled server cannot hang the
        # request forever, and reject HTTP error responses instead of saving
        # an error page as audio.
        res = requests.get(wav_url, timeout=(10, 60))
        res.raise_for_status()

        os.makedirs("results", exist_ok=True)
        # mkdtemp returns a directory that persists until explicitly removed;
        # a discarded TemporaryDirectory object deletes its directory again
        # as soon as it is finalized.
        tmp_dir = tempfile.mkdtemp(dir="results")
        wav_path = os.path.join(tmp_dir, "input_audio.wav")
        with open(wav_path, "wb") as fout:
            fout.write(res.content)
        return wav_path, tmp_dir
    except Exception:
        print("download error", traceback.format_exc())
        # Do not leave a half-populated directory behind on failure.
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir, ignore_errors=True)
        return None
#@app.post("/hairuo/audiodrivenvido")
@app.post("/agentstore/api/v1/multimodal_models/dh/dighthuman")
async def digithuman(request: Request, req: dict):
    """Fetch the WAV audio behind ``audio_url`` and drive SadTalker with it.

    The audio is downloaded, the talking-head pipeline renders a video from
    the fixed source image ``nv.jpg``, the video is uploaded, and all local
    temporary files are removed whether or not the request succeeds.

    Args:
        request: The incoming request object (not used by the handler body).
        req: Request payload; the ``"audio_url"`` key holds the audio URL.

    Returns:
        ``{"code": "0", "message": "success", "result": <video url>}`` on
        success, otherwise ``{"code": "1", "message": "audiodrivenvido model
        call error", "result": ""}``.
    """
    import shutil
    import traceback

    failure = {
        "code": "1",
        "message": "audiodrivenvido model call error",
        "result": ""
    }

    timestamp = int(time.time())
    save_path = f"results_{timestamp}"
    tmp_dir = None
    video_path = None
    try:
        audio_url = req.get("audio_url")
        downloaded = download_wav(audio_url)
        if downloaded is None:
            # download_wav signals failure with None; report it explicitly
            # instead of failing on tuple unpacking.
            print("audio download failed for url:", audio_url)
            return failure
        wav_path, tmp_dir = downloaded

        os.makedirs(save_path, exist_ok=True)
        source_image = 'nv.jpg'
        kwargs = {
            'preprocess': 'crop',  # 'crop', 'resize', 'full'
            'still_mode': True,
            'use_enhancer': False,
            'batch_size': 1,
            'size': 256,  # 256, 512
            'pose_style': 0,
            'exp_scale': 1,
            'result_dir': save_path
        }
        # NOTE(review): this call runs synchronously inside an async handler,
        # so no other request is served while it is rendering.
        video_path = inference(source_image, driven_audio=wav_path, **kwargs)
        video_name = upload_to_s3(video_path)

        # Plain concatenation: this is a URL, not a filesystem path.
        video_url = (
            "https://ihp.oss.cn-north-4.inspurcloudoss.com/tmp/liao/service/sadtalker/"
            + video_name
        )
        return {
            "code": "0",
            "message": "success",
            "result": video_url
        }
    except Exception:
        print(traceback.format_exc())
        return failure
    finally:
        # Clean up temporary files on every exit path so failed requests do
        # not leave directories behind.
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir, ignore_errors=True)
        if isinstance(video_path, str) and os.path.isfile(video_path):
            try:
                os.remove(video_path)
            except OSError:
                print(traceback.format_exc())
        shutil.rmtree(save_path, ignore_errors=True)
# Script entry point: serve the app with uvicorn when this file is executed
# directly, listening on all network interfaces at port 14041.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=14041,
    )