Spaces:
Running
Running
#!/usr/bin/env python | |
# -*- coding: UTF-8 -*- | |
''' | |
@Project : OmniTalker | |
@File : app.py | |
@Author : zhongjian.wzj | |
@Date : 2025/4/7 19:55 | |
Copyright (c) 2025, Alibaba Cloud. All rights reserved. | |
''' | |
import os | |
import json | |
import time | |
import random | |
import requests | |
import argparse | |
import uuid | |
import gradio as gr | |
from pathlib import Path | |
# Log this host's public IP at startup. The value is only printed, so a failure
# to reach the IP-echo service must not prevent the app from starting.
try:
    local_ip = requests.get('http://myip.ipip.net', timeout=5).text
except requests.RequestException as err:
    local_ip = f'unavailable ({err})'
print('local_ip: ', local_ip)
# Base URL of the OmniTalker inference backend; override with OMNITALKER_URL.
url = os.getenv('OMNITALKER_URL', "http://localhost:8012")
headers = {"Content-Type": "application/json"}
script_dir = Path(__file__).parent.absolute()
# Local cache of reference videos for the example roles (filled by update_examples).
static_folder = script_dir / "static"
static_folder.mkdir(parents=True, exist_ok=True)
# Generated videos handed back to the UI (kept bounded by auto_remove).
result_folder = script_dir / "result"
result_folder.mkdir(parents=True, exist_ok=True)
def auto_remove(folder, max_files=1000):
    """Delete the oldest files in *folder* so one more can be added within *max_files*.

    Once the folder holds at least ``max_files`` regular files, the oldest ones
    (ordered by ``st_ctime``) are deleted until ``max_files - 1`` remain. That
    leaves room for exactly one new file. Subdirectories are ignored. A missing
    path, or a path that is not a directory, is a no-op. Deletion is
    best-effort: failures are printed, not raised.

    Args:
        folder: Directory to prune (``str`` or ``Path``).
        max_files: File cap the caller wants to respect after adding one file.
            Values ``<= 0`` delete every file in the folder.
    """
    folder = Path(folder)
    if not folder.is_dir():
        return

    stamped = []
    for entry in folder.iterdir():
        try:
            if entry.is_file():
                stamped.append((entry.stat().st_ctime, entry))
        except FileNotFoundError:
            # A concurrent call may delete the file between listing and stat().
            continue

    excess = len(stamped) - max_files + 1
    if excess <= 0:
        return

    # Sort on the timestamp only, so ties keep directory-listing order.
    stamped.sort(key=lambda item: item[0])
    # Slicing clamps to the list length, so a small or negative max_files is safe.
    for _, oldest_file in stamped[:excess]:
        try:
            oldest_file.unlink()
            print(f"remove file: {oldest_file}")
        except PermissionError:
            print(f"permission denied: {oldest_file}")
        except OSError as e:
            print(f"failed: {str(e)}")
def predict(role, content, seed, speed):
    """Ask the OmniTalker backend to render a video and save it under ``result_folder``.

    Args:
        role: Character id, filled from the selected example row.
        content: Text to be spoken.
        seed: Generation seed, forwarded as-is. The UI documents -1 as
            "re-seed automatically"; presumably the backend handles that.
        speed: Speech-rate multiplier, forwarded as-is.

    Returns:
        Path of the newly written ``.mp4`` file.

    Raises:
        gr.Error: If the backend is unreachable or replies with a non-200
            status. Without this, the UI would receive a path to a file that
            was never written.
    """
    payload = {
        "role": role,
        "content": content,
        "seed": seed,
        "speed": speed,
    }
    try:
        response = requests.post(f'{url}/predict', headers=headers, data=json.dumps(payload))
    except requests.RequestException as err:
        raise gr.Error(f"Request to the generation backend failed: {err}") from err
    if response.status_code != 200:
        raise gr.Error(f"Generation backend returned HTTP {response.status_code}")

    # Prune old results first so the folder stays bounded once this file is added.
    auto_remove(result_folder)
    gen_file_path = result_folder / f"result-{uuid.uuid4().hex}.mp4"
    gen_file_path.write_bytes(response.content)
    return gen_file_path
def generate_seed():
    """Draw a fresh random seed and wrap it as a Gradio component update.

    Returns:
        dict: ``{"__type__": "update", "value": seed}``, where ``seed`` is a
        uniformly drawn integer in ``[0, 2**32 - 1]``.
    """
    upper_bound = (1 << 32) - 1
    return dict(__type__="update", value=random.randint(0, upper_bound))
def _download_reference_video(role_id, dest):
    """Fetch the reference clip for *role_id* from the backend and write it to *dest*.

    Returns True if the file was written. Returns False on a non-200 reply or
    a network error; network errors are printed, not raised.
    """
    try:
        # requests' timeout bounds each socket wait, not the whole download.
        response = requests.get(f"{url}/get_video/", params={'role': role_id}, timeout=60)
    except requests.RequestException as err:
        print(f"failed to fetch video for {role_id}: {err}")
        return False
    if response.status_code != 200:
        return False
    dest.write_bytes(response.content)
    return True


def update_examples():
    """Build the ``gr.Examples`` rows from the backend's example roles.

    Reference videos are cached in ``static_folder`` and downloaded only when
    missing. A role whose video cannot be fetched is skipped, and the other
    roles are still listed. This also runs while ``demo`` is being built, so
    backend or JSON errors yield an empty list instead of raising.

    Returns:
        dict: A Gradio update ``{"__type__": "update", "samples": rows}``. Each
        row is ``[role_id, video_path, *role_cfg.values()]``.
    """
    examples_dict = {}
    try:
        response = requests.get(f"{url}/get_examples", timeout=30)
        if response.status_code == 200:
            examples_dict = response.json()
            print(examples_dict.keys())
    except (requests.RequestException, ValueError) as err:
        # ValueError covers a 200 reply whose body is not valid JSON.
        print(f"failed to fetch examples: {err}")
        examples_dict = {}

    examples = []
    for role_id, role_cfg in examples_dict.items():
        ref_video_path = static_folder / f'{role_id}.mp4'
        if not ref_video_path.is_file() and not _download_reference_video(role_id, ref_video_path):
            # Skip just this role so the remaining roles still load.
            continue
        # Row order must match gr.Examples inputs: role, reference video, then the
        # config values (presumably text, seed, speed, in that order; confirm with backend).
        examples.append([role_id, ref_video_path, *role_cfg.values()])
    return {
        "__type__": "update",
        "samples": examples,
    }
def check_http(url, timeout=5):
    """Probe *url* once with a GET and report whether it answered HTTP 200.

    Args:
        url: Address to probe (used here for the backend base URL).
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Returns:
        True on a 200 response. False on any other status, or on any
        ``requests`` network error (the error is printed).
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.exceptions.RequestException as e:
        print(f"Error: {e}")
        return False
    if response.status_code == 200:
        print(f"Succeeded: {response.status_code}")
        return True
    print(f"Failed: {response.status_code}")
    return False
# Wait for the backend before building the UI: the examples are fetched from it
# at build time. Poll up to MAX_CONNECT_TIMES times, 10 s apart.
MAX_CONNECT_TIMES = 100
for try_loop in range(MAX_CONNECT_TIMES):
    # Show 1-based progress so the last attempt reads "100/100".
    print(f'Try: {try_loop + 1}/{MAX_CONNECT_TIMES}')
    if check_http(url):
        break
    # No point sleeping after the final failed attempt.
    if try_loop + 1 < MAX_CONNECT_TIMES:
        time.sleep(10)
else:
    # Retries exhausted: say so, then continue startup as before.
    print(f'Backend {url} unreachable after {MAX_CONNECT_TIMES} attempts; starting UI anyway.')
# Page stylesheet, resolved next to this script (like static/ and result/) so
# startup does not depend on the current working directory.
with open(script_dir / 'style.css', 'r', encoding='utf-8') as f:
    custom_css = f.read()
# Build the Gradio UI.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy whose indentation was lost; confirm against the deployed layout.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# <center> OmniTalker </center>")
    gr.Markdown("### <center> 🏠 [project](https://humanaigc.github.io/omnitalker) 🚀[Paper](https://arxiv.org/abs/2504.02433v1) </center>")
    # Usage notes: each Chinese line is followed by its English counterpart.
    # Content stays at column 0 so Markdown does not render it as a code block.
    gr.Markdown('''
### 步骤 Steps:
1. 选择角色, 等待`参考视频`加载完成 (自定义角色开发中)
Select a character in Examples. **Wait for the `Reference Video` to finish loading.** (Custom upload is currently under development.)
2. 输入`文本` (目前只支持中英文, 限制100字左右)
Enter `text` (Only **Chinese** and **English** are supported so far. Input is limited to about **100** characters for performance.)
3. 生成 (受限于网络和资源推理速度可能达不到1:1, 感谢理解)。
Generate (Due to limitations in network speed and GPU resources, the generation speed may not achieve a real-time 1:1 ratio. Thank you for your understanding.)
### 技巧 Tips:
1. 中文中数字尽量用汉字
In Chinese text, numbers are best written as Chinese characters.
2. 尝试不同的`seed`来获取最好的结果 (设为-1则每次会自动更改)
Try different `Seed` values to achieve the best generation (use -1 for automatic seeding). God may not play dice, but AI does.
3. 适当调整语速`Speed`, 尤其是当英语参考人物讲中文时最好调慢一些(0.9-0.95)
Adjust `Speed` to control the speech rate. Especially when generating Chinese speech for native English speakers, the recommended setting is 0.9-0.95.
''')
    with gr.Group():
        # Top row: the selected role's reference clip next to the generated result.
        with gr.Row():
            with gr.Column():
                reference_video = gr.Video(label='Reference Video', interactive=False)
            with gr.Column():
                output_video = gr.Video(label='Output Video', streaming=False, autoplay=True)
        # Second row: text input on the left; seed, speed and submit on the right.
        with gr.Row(equal_height=True):
            input_text = gr.Textbox(label="Input Text", lines=8, scale=5)
            with gr.Column(scale=1):
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=-1, label="Seed", elem_classes="gradio-number")
                    btn_seed = gr.Button(value="\U0001F3B2", elem_classes="gradio-button")
                speed = gr.Slider(0, 2, value=1, step=0.01, label="Speed", scale=2)
                btn_run = gr.Button('Submit', variant='primary')
    # The role box is read-only; it is filled by clicking an example row below.
    with gr.Row(equal_height=True):
        role = gr.Textbox(label="Role", lines=1, max_lines=1, elem_classes="gradio-textbox", interactive=False)
        btn_refresh = gr.Button(value='\U0001f504', elem_classes="gradio-button")
    # Examples are fetched from the backend once at build time; the refresh
    # button re-queries them through update_examples.
    examples = gr.Examples(
        examples=update_examples()['samples'],
        inputs=[role, reference_video, input_text, seed, speed],
        examples_per_page=10,
    )
    btn_seed.click(generate_seed, inputs=[], outputs=seed)
    btn_refresh.click(update_examples, outputs=[examples.dataset])
    btn_run.click(predict, inputs=[role, input_text, seed, speed], outputs=[output_video])
if __name__ == '__main__':
    # CLI options select the interface and port the Gradio server binds to.
    cli = argparse.ArgumentParser()
    cli.add_argument("-ip", "--server_ip", type=str, default="0.0.0.0")
    cli.add_argument("-p", "--server_port", type=int, default=7860)
    opts = cli.parse_args()
    # Serve locally only; no public Gradio share link is created.
    demo.launch(
        server_name=opts.server_ip,
        server_port=opts.server_port,
        share=False,
    )