Mrwrichard committed on
Commit
96808ee
·
verified ·
1 Parent(s): cfb1f19

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +195 -0
  2. style.css +91 -0
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: UTF-8 -*-
3
+ '''
4
+ @Project : OmniTalker_HF
5
+ @File : app_api.py
6
+ @Author : zhongjian.wzj
7
+ @Date : 2025/4/7 19:55
8
+ Copyright (c) 2025, Alibaba Cloud. All rights reserved.
9
+ '''
10
+ import os
11
+ import json
12
+ import time
13
+ import random
14
+ import requests
15
+ import argparse
16
+ import uuid
17
+ import gradio as gr
18
+ from pathlib import Path
19
+
20
# Best-effort diagnostic: ask an external echo service which address this
# host is seen from. The service is outside our control, so a failed lookup
# is reported and ignored instead of aborting start-up.
try:
    local_ip = requests.get('http://myip.ipip.net', timeout=5).text
except requests.exceptions.RequestException as e:
    local_ip = ''
    print(f'local_ip lookup failed: {e}')
print('local_ip: ', local_ip)
22
+
23
# Base URL of the OmniTalker backend; override it with the OMNITALKER_URL
# environment variable.
url = os.environ.get('OMNITALKER_URL', "http://localhost:8012")
# Headers for JSON requests sent to the backend.
headers = {"Content-Type": "application/json"}

# Working folders live next to this script and are created on start-up:
# "static" caches downloaded reference videos, "result" holds generated ones.
script_dir = Path(__file__).parent.absolute()
static_folder = script_dir / "static"
result_folder = script_dir / "result"
for _folder in (static_folder, result_folder):
    _folder.mkdir(parents=True, exist_ok=True)
31
+
32
+
33
def auto_remove(folder, max_files=1000):
    """Prune the oldest files of *folder* once it holds ``max_files`` or more.

    Only regular files directly inside *folder* are considered; sub-folders
    are left alone. When the file count reaches ``max_files``, the oldest
    entries (ordered by ``st_ctime``) are deleted until at most
    ``max_files - 1`` remain, which leaves room for one new file.

    Args:
        folder: Directory to prune (``str`` or ``Path``). A missing path or a
            path that is not a directory is silently ignored.
        max_files: File count at which pruning starts. Values of zero or
            below remove every file.

    Returns:
        None. Deletion failures are printed and skipped.
    """
    folder = Path(folder)

    if not folder.exists() or not folder.is_dir():
        return

    files = [p for p in folder.iterdir() if p.is_file()]

    if not files or len(files) < max_files:
        return

    files.sort(key=lambda x: x.stat().st_ctime)

    # Slice instead of indexing: when max_files <= 0 the excess count is
    # larger than the list, and indexing would raise IndexError.
    excess = len(files) - max_files + 1
    for oldest_file in files[:excess]:
        try:
            oldest_file.unlink()
            print(f"remove file: {oldest_file}")
        except PermissionError:
            print(f"permission denied: {oldest_file}")
        except Exception as e:
            print(f"failed: {str(e)}")
56
+
57
+
58
def predict(role, content, seed, speed):
    """Request a generated video from the backend and save it locally.

    The four arguments are posted as JSON to ``{url}/predict``. On success
    the response body is written to a uniquely named ``.mp4`` file inside
    ``result_folder``, after old results have been pruned with
    ``auto_remove``.

    Args:
        role: Value sent as the ``role`` field.
        content: Value sent as the ``content`` field.
        seed: Value sent as the ``seed`` field.
        speed: Value sent as the ``speed`` field.

    Returns:
        Path of the saved video file.

    Raises:
        gr.Error: If the backend cannot be reached or answers with a status
            other than 200.
    """
    payload = {
        "role": role,
        "content": content,
        "seed": seed,
        "speed": speed,
    }

    # NOTE(review): no timeout is set because generation time is not known
    # from this file; confirm an upper bound before adding one.
    try:
        response = requests.post(f'{url}/predict', headers=headers, data=json.dumps(payload))
    except requests.exceptions.RequestException as e:
        # Show connection problems in the UI instead of a raw traceback.
        raise gr.Error(f"Failed to reach the OmniTalker service: {e}") from e

    if response.status_code != 200:
        raise gr.Error(f"OmniTalker service returned HTTP {response.status_code}")

    # Prune before writing so the folder stays below the file cap.
    auto_remove(result_folder)

    gen_file_path = result_folder / f"result-{uuid.uuid4().hex}.mp4"
    gen_file_path.write_bytes(response.content)

    return gen_file_path
79
+
80
+
81
def generate_seed():
    """Draw a random seed and wrap it in a Gradio update payload.

    Returns:
        dict: ``{"__type__": "update", "value": seed}`` where ``seed`` is an
        integer in ``[0, 2**32 - 1]``.
    """
    upper_bound = 2**32 - 1
    update = {"__type__": "update"}
    update["value"] = random.randint(0, upper_bound)
    return update
87
+
88
+
89
def update_examples():
    """Fetch the example roles from the backend and cache their reference videos.

    Queries ``{url}/get_examples`` for a mapping of role id to role settings.
    For every role whose ``<role_id>.mp4`` is not yet present in
    ``static_folder``, the video is downloaded from ``{url}/get_video/``.
    Network and JSON-decoding failures are printed and treated like a
    non-200 answer, so the caller still receives a (possibly shorter) list.

    Returns:
        dict: Gradio update payload ``{"__type__": "update", "samples": rows}``
        where each row is ``[role_id, ref_video_path, *role settings values]``.
    """
    # Seconds to wait for the backend to connect or send data; without a
    # timeout an unresponsive backend would block this call indefinitely.
    request_timeout = 30

    examples_dict = {}
    try:
        response = requests.get(f"{url}/get_examples", timeout=request_timeout)
        if response.status_code == 200:
            examples_dict = response.json()
            print(examples_dict.keys())
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a body that is not valid JSON.
        print(f"Error: {e}")

    examples = []
    for role_id, role_cfg in examples_dict.items():
        ref_video_path = static_folder / f'{role_id}.mp4'
        if not ref_video_path.is_file():
            try:
                response = requests.get(
                    f"{url}/get_video/",
                    params={'role': role_id},
                    timeout=request_timeout,
                )
            except requests.exceptions.RequestException as e:
                print(f"Error: {e}")
                break
            if response.status_code != 200:
                # Stop at the first failed download; later roles are not listed.
                break
            ref_video_path.write_bytes(response.content)

        examples.append([role_id, ref_video_path, *role_cfg.values()])

    return {
        "__type__": "update",
        "samples": examples,
    }
115
+
116
+
117
def check_http(url, timeout=5):
    """Report whether a GET request to *url* answers with HTTP 200.

    Args:
        url: Address to probe.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        bool: ``True`` for a 200 response; ``False`` for any other status or
        when the request raises a ``requests`` exception. The outcome is
        printed in every case.
    """
    try:
        status = requests.get(url, timeout=timeout).status_code
    except requests.exceptions.RequestException as e:
        print(f"Error: {e}")
        return False

    reachable = status == 200
    if reachable:
        print(f"Succeed: {status}")
    else:
        print(f"Faild: {status}")
    return reachable
129
+
130
# Wait for the backend before building the UI: probe it up to
# MAX_CONNECT_TIMES times, pausing 10 seconds after every failed attempt.
MAX_CONNECT_TIMES = 100
for try_loop in range(MAX_CONNECT_TIMES):
    print(f'Try: {try_loop}/{MAX_CONNECT_TIMES}')
    if not check_http(url):
        time.sleep(10)
        continue
    break
136
+
137
# Load the stylesheet from the folder containing this script rather than
# the current working directory, so the app also starts when launched from
# elsewhere; the encoding is pinned so it does not depend on the locale.
custom_css = (script_dir / 'style.css').read_text(encoding='utf-8')
139
# Build the Gradio UI. `demo` is the Blocks app that the __main__ section launches.
# NOTE(review): indentation is not visible in this view, so the exact nesting
# of the layout containers below cannot be confirmed from here.
+ with gr.Blocks(css=custom_css) as demo:
140
+ gr.Markdown("# <center> OmniTalker </center>")
141
+
142
+ gr.Markdown('''
143
+ ### Steps:
144
+ 1. Select a character in Examples. Waiting for `Reference Video` to load the video.
145
+ 2. Enter text and generate (Only Chinese and English are supported by far).
146
+
147
+ ### Tips:
148
+ 1. Try different `Seed` to achieve the best generation. God may not play dice, but AI does.
149
+ 2. Adjust `Speed` to control the speech rate. Especially when generating Chinese speech for native English speakers, recommended setting is 0.9~0.95.
150
+ 3. Due to limitations in network speed and GPU resources, the generation speed may not achieve a real-time 1:1 ratio. Appreciate.
151
+ ''')
152
+
153
# Two video panes: the non-editable reference video and the generated output.
+ with gr.Group():
154
+ with gr.Row():
155
+ with gr.Column():
156
+ reference_video = gr.Video(label='Reference Video', interactive=False)
157
+ with gr.Column():
158
+ output_video = gr.Video(label='Output Video', streaming=False, autoplay=True)
159
+
160
# Generation inputs: the text, a seed (initially -1) with a button that draws
# a random one, a speed slider over 0..2 in steps of 0.01, and the submit button.
+ with gr.Row(equal_height=True):
161
+ input_text = gr.Textbox(label="Input Text", lines=8, scale=5)
162
+ with gr.Column(scale=1):
163
+ with gr.Row(equal_height=True):
164
+ seed = gr.Number(value=-1, label="Seed", elem_classes="gradio-number")
165
+ btn_seed = gr.Button(value="\U0001F3B2", elem_classes="gradio-button")
166
+ speed = gr.Slider(0, 2, value=1, step=0.01, label="Speed", scale=2)
167
+ btn_run = gr.Button('Submit', variant='primary')
168
+
169
# Non-editable role field plus a button that reloads the example list.
+ with gr.Row(equal_height=True):
170
+ role = gr.Textbox(label="Role", lines=1, max_lines=1, elem_classes="gradio-textbox", interactive=False)
171
+ btn_refresh = gr.Button(value='\U0001f504', elem_classes="gradio-button")
172
+
173
# update_examples() is called here, while the UI is being built, so the
# backend is queried at import time. Each example row fills the five
# components listed in `inputs`.
+ examples = gr.Examples(
174
+ examples=update_examples()['samples'],
175
+ inputs=[role, reference_video, input_text, seed, speed],
176
+ examples_per_page=10,
177
+ )
178
+
179
# Event wiring: new random seed, reload of the examples dataset, and generation.
+ btn_seed.click(generate_seed, inputs=[], outputs=seed)
180
+ btn_refresh.click(update_examples, outputs=[examples.dataset])
181
+ btn_run.click(predict, inputs=[role, input_text, seed, speed], outputs=[output_video])
182
+
183
+
184
+
185
if __name__ == '__main__':
    # Command-line options: bind address and port of the Gradio server.
    cli = argparse.ArgumentParser()
    cli.add_argument("-ip", "--server_ip", type=str, default="0.0.0.0")
    cli.add_argument("-p", "--server_port", type=int, default=7860)
    options = cli.parse_args()

    demo.launch(
        server_name=options.server_ip,
        server_port=options.server_port,
        share=False,
    )
style.css ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Layout overrides for the OmniTalker Gradio demo.
   app.py reads this file and passes it to gr.Blocks(css=...). */

/* Theme variables and page width, for both light and dark mode. */
:root,
.dark {
  --checkbox-label-gap: 0.25em 0.1em;
  --section-header-text-size: 12pt;
  --block-background-fill: transparent;
  width: 50%;
  min-width: 1024px;
  justify-self: center;
  align-items: center;
}

/* Remove block padding everywhere except inside accordions. */
.block.padded:not(.gradio-accordion) {
  padding: 0 !important;
}

/* Let the container grow beyond Gradio's default maximum width. */
div.gradio-container {
  max-width: unset !important;
}

.hidden {
  display: none !important;
}

.compact {
  background: transparent !important;
  padding: 0 !important;
}

/* Flatten forms and groups: no borders, shadows or backgrounds. */
div.form {
  border-width: 0;
  box-shadow: none;
  background: transparent;
  overflow: visible;
  gap: 0.5em;
}

div.gradio-group,
div.styler {
  border-width: 0 !important;
  background: none;
}

.gap.compact {
  padding: 0;
  gap: 0.2em 0;
}

div.compact {
  gap: 1em;
}

/* Borderless input widgets. */
.block.gradio-dropdown,
.block.gradio-slider,
.block.gradio-textbox,
.block.gradio-number {
  border-width: 0 !important;
  box-shadow: none !important;
}

.gradio-slider input[type="number"] {
  width: 6em;
}

.gradio-slider {
  margin-top: 0.75em;
}

/* Disabled rule, kept for reference:
.gradio-dropdown label span:not(.has-info),
.gradio-textbox label span:not(.has-info),
.gradio-number label span:not(.has-info)
{
  margin-bottom: 0;
} */

/* Disabled rule, kept for reference:
.gradio-dropdown label span,
.gradio-textbox label span,
.gradio-number label span
{
  margin-bottom: 0 !important;
} */

/* Small square icon buttons (app.py uses this class for the seed and
   refresh buttons). */
.gradio-button {
  max-width: 2.2em;
  min-width: 2.2em !important;
  height: 2.4em;
  align-self: end;
  line-height: 1em;
  border-radius: 0.5em;
}
91
+