ginipick committed on
Commit
d57c019
1 Parent(s): e7ca6c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -527
app.py CHANGED
@@ -1,528 +1,2 @@
1
- # 1. 먼저 로깅 설정
2
- import logging
3
- logging.basicConfig(level=logging.INFO)
4
- logger = logging.getLogger(__name__)
5
-
6
- # 2. spaces를 먼저 import
7
- import spaces
8
-
9
- # 3. 나머지 imports
10
  import os
11
- import time
12
- from datetime import datetime
13
- import gradio as gr
14
- import torch
15
- import requests
16
- from pathlib import Path
17
- import cv2
18
- from PIL import Image
19
- import json
20
- import torchaudio
21
- import tempfile
22
-
23
- # 4. GPU 초기화 설정
24
- if torch.cuda.is_available():
25
- device = torch.device('cuda')
26
- logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
27
- else:
28
- device = torch.device('cpu')
29
- logger.warning("GPU not available, using CPU")
30
-
31
- try:
32
- import mmaudio
33
- except ImportError:
34
- os.system("pip install -e .")
35
- import mmaudio
36
-
37
- # 나머지 imports
38
- from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
39
- setup_eval_logging)
40
- from mmaudio.model.flow_matching import FlowMatching
41
- from mmaudio.model.networks import MMAudio, get_my_mmaudio
42
- from mmaudio.model.sequence_config import SequenceConfig
43
- from mmaudio.model.utils.features_utils import FeaturesUtils
44
-
45
- # 번역 모델 import
46
- from transformers import pipeline
47
- translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
48
-
49
- # API 설정
50
- CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
51
- REPLICATE_API_TOKEN = os.getenv("API_KEY")
52
-
53
- # 오디오 모델 설정
54
- dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
55
-
56
- # 5. get_model 함수 정의
57
- def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
58
- seq_cfg = model.seq_cfg
59
-
60
- net: MMAudio = get_my_mmaudio(model.model_name).to(device)
61
- if torch.cuda.is_available():
62
- net = net.to(dtype)
63
- net.eval()
64
-
65
- net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
66
- logger.info(f'Loaded weights from {model.model_path}')
67
-
68
- feature_utils = FeaturesUtils(
69
- tod_vae_ckpt=model.vae_path,
70
- synchformer_ckpt=model.synchformer_ckpt,
71
- enable_conditions=True,
72
- mode=model.mode,
73
- bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
74
- need_vae_encoder=False
75
- ).to(device)
76
-
77
- if torch.cuda.is_available():
78
- feature_utils = feature_utils.to(dtype)
79
- feature_utils.eval()
80
-
81
- return net, feature_utils, seq_cfg
82
-
83
- # 6. 모델 초기화
84
- model: ModelConfig = all_model_cfg['large_44k_v2']
85
- model.download_if_needed()
86
- output_dir = Path('./output/gradio')
87
-
88
- setup_eval_logging()
89
- net, feature_utils, seq_cfg = get_model()
90
-
91
- @spaces.GPU(duration=30)
92
- @torch.inference_mode()
93
- def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
94
- seed: int = -1, num_steps: int = 15,
95
- cfg_strength: float = 4.0, target_duration: float = None):
96
- try:
97
- logger.info("Starting audio generation process")
98
- if torch.cuda.is_available():
99
- torch.cuda.empty_cache()
100
-
101
-
102
- # 비디오 길이 확인
103
- cap = cv2.VideoCapture(video_path)
104
- fps = cap.get(cv2.CAP_PROP_FPS)
105
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
106
- video_duration = total_frames / fps
107
- cap.release()
108
-
109
- # 실제 비디오 길이를 target_duration으로 사용
110
- target_duration = video_duration
111
- logger.info(f"Video duration: {target_duration} seconds")
112
-
113
- rng = torch.Generator(device=device)
114
- if seed >= 0:
115
- rng.manual_seed(seed)
116
- else:
117
- rng.seed()
118
-
119
- fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
120
-
121
- # 비디오 길이에 맞춰 load_video 호출
122
- video_info = load_video(video_path, duration_sec=target_duration)
123
-
124
- if video_info is None:
125
- logger.error("Failed to load video")
126
- return video_path
127
-
128
- clip_frames = video_info.clip_frames
129
- sync_frames = video_info.sync_frames
130
- actual_duration = video_info.duration_sec
131
-
132
- if clip_frames is None or sync_frames is None:
133
- logger.error("Failed to extract frames from video")
134
- return video_path
135
-
136
- # 실제 비디오 프레임 수에 맞춰 조정
137
- clip_frames = clip_frames[:int(actual_duration * video_info.fps)]
138
- sync_frames = sync_frames[:int(actual_duration * video_info.fps)]
139
-
140
- clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
141
- sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)
142
-
143
- # sequence config 업데이트
144
- seq_cfg.duration = actual_duration
145
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
146
-
147
- logger.info(f"Generating audio for {actual_duration} seconds...")
148
-
149
-
150
- logger.info("Generating audio...")
151
- with torch.cuda.amp.autocast():
152
- audios = generate(clip_frames,
153
- sync_frames,
154
- [prompt],
155
- negative_text=[negative_prompt],
156
- feature_utils=feature_utils,
157
- net=net,
158
- fm=fm,
159
- rng=rng,
160
- cfg_strength=cfg_strength)
161
-
162
- if audios is None:
163
- logger.error("Failed to generate audio")
164
- return video_path
165
-
166
- audio = audios.float().cpu()[0]
167
-
168
- output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
169
- logger.info(f"Creating final video with audio at {output_path}")
170
-
171
- make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)
172
-
173
- torch.cuda.empty_cache()
174
-
175
- if not os.path.exists(output_path):
176
- logger.error("Failed to create output video")
177
- return video_path
178
-
179
- logger.info(f'Successfully saved video with audio to {output_path}')
180
- return output_path
181
-
182
- except Exception as e:
183
- logger.error(f"Error in video_to_audio: {str(e)}")
184
- torch.cuda.empty_cache()
185
- return video_path
186
-
187
- def upload_to_catbox(file_path):
188
- """catbox.moe API를 사용하여 파일 업로드"""
189
- try:
190
- logger.info(f"Preparing to upload file: {file_path}")
191
- url = "https://catbox.moe/user/api.php"
192
-
193
- mime_types = {
194
- '.jpg': 'image/jpeg',
195
- '.jpeg': 'image/jpeg',
196
- '.png': 'image/png',
197
- '.gif': 'image/gif',
198
- '.webp': 'image/webp',
199
- '.jfif': 'image/jpeg'
200
- }
201
-
202
- file_extension = Path(file_path).suffix.lower()
203
-
204
- if file_extension not in mime_types:
205
- try:
206
- img = Image.open(file_path)
207
- if img.mode != 'RGB':
208
- img = img.convert('RGB')
209
-
210
- new_path = file_path.rsplit('.', 1)[0] + '.png'
211
- img.save(new_path, 'PNG')
212
- file_path = new_path
213
- file_extension = '.png'
214
- logger.info(f"Converted image to PNG: {file_path}")
215
- except Exception as e:
216
- logger.error(f"Failed to convert image: {str(e)}")
217
- return None
218
-
219
- files = {
220
- 'fileToUpload': (
221
- os.path.basename(file_path),
222
- open(file_path, 'rb'),
223
- mime_types.get(file_extension, 'application/octet-stream')
224
- )
225
- }
226
-
227
- data = {
228
- 'reqtype': 'fileupload',
229
- 'userhash': CATBOX_USER_HASH
230
- }
231
-
232
- response = requests.post(url, files=files, data=data)
233
-
234
- if response.status_code == 200 and response.text.startswith('http'):
235
- file_url = response.text
236
- logger.info(f"File uploaded successfully: {file_url}")
237
- return file_url
238
- else:
239
- raise Exception(f"Upload failed: {response.text}")
240
-
241
- except Exception as e:
242
- logger.error(f"File upload error: {str(e)}")
243
- return None
244
- finally:
245
- if 'new_path' in locals() and os.path.exists(new_path):
246
- try:
247
- os.remove(new_path)
248
- except:
249
- pass
250
-
251
- def add_watermark(video_path):
252
- """OpenCV를 사용하여 비디오에 워터마크 추가"""
253
- try:
254
- cap = cv2.VideoCapture(video_path)
255
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
256
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
257
- fps = int(cap.get(cv2.CAP_PROP_FPS))
258
-
259
- text = "GiniGEN.AI"
260
- font = cv2.FONT_HERSHEY_SIMPLEX
261
- font_scale = height * 0.05 / 30
262
- thickness = 2
263
- color = (255, 255, 255)
264
-
265
- (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
266
- margin = int(height * 0.02)
267
- x_pos = width - text_width - margin
268
- y_pos = height - margin
269
-
270
- output_path = "watermarked_output.mp4"
271
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
272
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
273
-
274
- while cap.isOpened():
275
- ret, frame = cap.read()
276
- if not ret:
277
- break
278
- cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
279
- out.write(frame)
280
-
281
- cap.release()
282
- out.release()
283
-
284
- return output_path
285
-
286
- except Exception as e:
287
- logger.error(f"Error adding watermark: {str(e)}")
288
- return video_path
289
-
290
- def generate_video(image, prompt):
291
- logger.info("Starting video generation with API")
292
- try:
293
- API_KEY = os.getenv("API_KEY", "").strip()
294
- if not API_KEY:
295
- return "API key not properly configured"
296
-
297
- temp_dir = "temp_videos"
298
- os.makedirs(temp_dir, exist_ok=True)
299
-
300
- image_url = None
301
- if image:
302
- image_url = upload_to_catbox(image)
303
- if not image_url:
304
- return "Failed to upload image"
305
- logger.info(f"Input image URL: {image_url}")
306
-
307
- generation_url = "https://api.minimaxi.chat/v1/video_generation"
308
- headers = {
309
- 'authorization': f'Bearer {API_KEY}',
310
- 'Content-Type': 'application/json'
311
- }
312
-
313
- payload = {
314
- "model": "video-01",
315
- "prompt": prompt if prompt else "",
316
- "prompt_optimizer": True
317
- }
318
-
319
- if image_url:
320
- payload["first_frame_image"] = image_url
321
-
322
- logger.info(f"Sending request with payload: {payload}")
323
-
324
- response = requests.post(generation_url, headers=headers, json=payload)
325
-
326
- if not response.ok:
327
- error_msg = f"Failed to create video generation task: {response.text}"
328
- logger.error(error_msg)
329
- return error_msg
330
-
331
- response_data = response.json()
332
- task_id = response_data.get('task_id')
333
- if not task_id:
334
- return "Failed to get task ID from response"
335
-
336
- query_url = "https://api.minimaxi.chat/v1/query/video_generation"
337
- max_attempts = 30
338
- attempt = 0
339
-
340
- while attempt < max_attempts:
341
- time.sleep(10)
342
- query_response = requests.get(
343
- f"{query_url}?task_id={task_id}",
344
- headers={'authorization': f'Bearer {API_KEY}'}
345
- )
346
-
347
- if not query_response.ok:
348
- attempt += 1
349
- continue
350
-
351
- status_data = query_response.json()
352
- status = status_data.get('status')
353
-
354
- if status == 'Success':
355
- file_id = status_data.get('file_id')
356
- if not file_id:
357
- return "Failed to get file ID"
358
-
359
- retrieve_url = "https://api.minimaxi.chat/v1/files/retrieve"
360
- params = {'file_id': file_id}
361
-
362
- file_response = requests.get(
363
- retrieve_url,
364
- headers={'authorization': f'Bearer {API_KEY}'},
365
- params=params
366
- )
367
-
368
- if not file_response.ok:
369
- return "Failed to retrieve video file"
370
-
371
- try:
372
- file_data = file_response.json()
373
- download_url = file_data.get('file', {}).get('download_url')
374
- if not download_url:
375
- return "Failed to get download URL"
376
-
377
- result_info = {
378
- "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
379
- "input_image": image_url,
380
- "output_video_url": download_url,
381
- "prompt": prompt
382
- }
383
- logger.info(f"Video generation result: {json.dumps(result_info, indent=2)}")
384
-
385
- video_response = requests.get(download_url)
386
- if not video_response.ok:
387
- return "Failed to download video"
388
-
389
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
390
- output_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
391
-
392
- with open(output_path, 'wb') as f:
393
- f.write(video_response.content)
394
-
395
- final_path = add_watermark(output_path)
396
-
397
- # 비디오 길이 확인
398
- cap = cv2.VideoCapture(final_path)
399
- fps = cap.get(cv2.CAP_PROP_FPS)
400
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
401
- video_duration = total_frames / fps
402
- cap.release()
403
-
404
- logger.info(f"Original video duration: {video_duration} seconds")
405
-
406
- # 오디오 처리 추가
407
- try:
408
- logger.info("Starting audio generation process")
409
- final_path_with_audio = video_to_audio(
410
- final_path,
411
- prompt=prompt,
412
- negative_prompt="music",
413
- seed=-1,
414
- num_steps=20,
415
- cfg_strength=4.5
416
- # target_duration 제거 - 자동으로 비디오 길이 사용
417
- )
418
-
419
- if final_path_with_audio != final_path:
420
- logger.info("Audio generation successful")
421
- try:
422
- if output_path != final_path:
423
- os.remove(output_path)
424
- if final_path != final_path_with_audio:
425
- os.remove(final_path)
426
- except Exception as e:
427
- logger.warning(f"Error cleaning up temporary files: {str(e)}")
428
-
429
- return final_path_with_audio
430
- else:
431
- logger.warning("Audio generation skipped, using original video")
432
- return final_path
433
-
434
- except Exception as e:
435
- logger.error(f"Error in audio processing: {str(e)}")
436
- return final_path # 오디오 처리 실패 시 워터마크만 된 비디오 반환
437
-
438
- except Exception as e:
439
- logger.error(f"Error processing video file: {str(e)}")
440
- return "Error processing video file"
441
-
442
- elif status == 'Fail':
443
- return "Video generation failed"
444
-
445
- attempt += 1
446
-
447
- return "Timeout waiting for video generation"
448
-
449
- except Exception as e:
450
- logger.error(f"Error in video generation: {str(e)}")
451
- return f"Error in video generation process: {str(e)}"
452
-
453
- css = """
454
- footer {
455
- visibility: hidden;
456
- }
457
- .gradio-container {max-width: 1200px !important}
458
- """
459
-
460
-
461
- with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
462
- gr.HTML('<div class="title">🎥 Dokdo Multimodal✨ "Prompt guide for automated video and sound synthesis from images" </div>')
463
- gr.HTML('<div class="title">😄 Explore: <a href="https://huggingface.co/spaces/ginigen/theater" target="_blank">https://huggingface.co/spaces/ginigen/theater</a></div>')
464
-
465
- with gr.Row():
466
- with gr.Column(scale=3):
467
- video_prompt = gr.Textbox(
468
- label="Video Description",
469
- placeholder="Enter video description...",
470
- lines=3
471
- )
472
- upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
473
- video_generate_btn = gr.Button("🎬 Generate Video")
474
-
475
- with gr.Column(scale=4):
476
- video_output = gr.Video(label="Generated Video")
477
-
478
-
479
-
480
-
481
- # process_and_generate_video 함수 수정
482
- def process_and_generate_video(image, prompt):
483
- if image is None:
484
- return "Please upload an image"
485
-
486
- try:
487
- # 한글 프롬프트 감지 및 번역
488
- contains_korean = any(ord('가') <= ord(char) <= ord('힣') for char in prompt)
489
- if contains_korean:
490
- translated = translator(prompt)[0]['translation_text']
491
- logger.info(f"Translated prompt from '{prompt}' to '{translated}'")
492
- prompt = translated
493
-
494
- img = Image.open(image)
495
- if img.mode != 'RGB':
496
- img = img.convert('RGB')
497
-
498
- temp_path = f"temp_{int(time.time())}.png"
499
- img.save(temp_path, 'PNG')
500
-
501
- result = generate_video(temp_path, prompt)
502
-
503
- try:
504
- os.remove(temp_path)
505
- except:
506
- pass
507
-
508
- return result
509
-
510
- except Exception as e:
511
- logger.error(f"Error processing image: {str(e)}")
512
- return "Error processing image"
513
-
514
-
515
- video_generate_btn.click(
516
- process_and_generate_video,
517
- inputs=[upload_image, video_prompt],
518
- outputs=video_output
519
- )
520
-
521
- if __name__ == "__main__":
522
- # GPU 초기화 확인
523
- if torch.cuda.is_available():
524
- logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
525
- else:
526
- logger.warning("GPU not available, using CPU")
527
-
528
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ exec(os.environ.get('APP'))