cheeseman182 committed on
Commit de3c817 · verified · 1 Parent(s): df92597

Update media.py

Files changed (1):
  1. media.py +121 -71
media.py CHANGED
@@ -1,111 +1,159 @@
 
 
  # --- LIBRARIES ---
  import torch
  import gradio as gr
  import random
  import time
- from diffusers import AutoPipelineForText2Image, TextToVideoSDPipeline
  import gc
  import os
  import imageio

- # --- DYNAMIC HARDWARE DETECTION (THE FIX) ---
- # Check if a CUDA-enabled GPU is available, otherwise use the CPU
- if torch.cuda.is_available():
-     device = "cuda"
-     torch_dtype = torch.float16  # Use float16 for GPU
-     print("✅ GPU detected. Using CUDA.")
- else:
-     device = "cpu"
-     torch_dtype = torch.float32  # Use float32 for CPU
-     print("⚠️ No GPU detected. Using CPU. Performance will be slower.")
- # --- AUTHENTICATION FOR HUGGING FACE SPACES ---
- try:
-     from huggingface_hub import login
-     HF_TOKEN = os.environ.get('HF_TOKEN')
-     if HF_TOKEN:
          login(token=HF_TOKEN)
          print("✅ Hugging Face Authentication successful.")
-     else:
-         print("⚠️ Hugging Face token not found in Space Secrets. Gated models may not be available.")
- except ImportError:
-     print("Could not import huggingface_hub. Please ensure it's in requirements.txt")
 
  # --- CONFIGURATION & STATE ---
  available_models = {
      "Fast Image (SDXL Turbo)": "stabilityai/sdxl-turbo",
      "Quality Image (SDXL)": "stabilityai/stable-diffusion-xl-base-1.0",
      "Video (Damo-Vilab)": "damo-vilab/text-to-video-ms-1.7b"
  }
  model_state = { "current_pipe": None, "loaded_model_name": None }
 
-
- # --- CORE GENERATION FUNCTION ---
- def generate_media(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
      if model_state.get("loaded_model_name") != model_key:
-         print(f"Switching to {model_key}. Unloading previous model...")
-         yield {status_textbox: f"Unloading previous model..."}
          if model_state.get("current_pipe"):
-             del model_state["current_pipe"]
-             gc.collect()
-             if device == "cuda":
-                 torch.cuda.empty_cache()
-
          model_id = available_models[model_key]
-         print(f"Loading {model_id}...")
-         yield {status_textbox: f"Loading {model_id}... This can take a minute."}
-
-         # Adapt model loading based on hardware
-         if "Image" in model_key:
-             pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch_dtype, variant="fp16" if device == "cuda" else "fp32")
-         elif "Video" in model_key:
              pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
-
-         # Move pipe to the detected device
          pipe.to(device)

-         # CPU offloading only makes sense on a GPU setup
-         if device == "cuda" and "Turbo" not in model_key and "Video" not in model_key:
-             pipe.enable_model_cpu_offload()
-
          model_state["current_pipe"] = pipe
          model_state["loaded_model_name"] = model_key
-         print(f"✅ Model loaded successfully on {device.upper()}.")

      pipe = model_state["current_pipe"]
      generator = torch.Generator(device).manual_seed(seed)
-     yield {status_textbox: f"Generating with {model_key} on {device.upper()}..."}
-
-     if "Image" in model_key:
-         print("Generating image...")
-         if "Turbo" in model_key:
-             num_steps, guidance_scale = 1, 0.0
-         else:
-             num_steps, guidance_scale = int(steps), float(cfg_scale)
-
-         image = pipe(
-             prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_steps,
-             guidance_scale=guidance_scale, width=int(width), height=int(height), generator=generator
-         ).images[0]
-         print("✅ Image generation complete.")
-         yield {output_image: image, output_video: None, status_textbox: f"Seed used: {seed}"}
-
-     elif "Video" in model_key:
-         print("Generating video...")
          video_frames = pipe(prompt=prompt, num_inference_steps=int(steps), height=320, width=576, num_frames=int(num_frames), generator=generator).frames

-         video_path = f"/tmp/video_{seed}.mp4"
-         imageio.mimsave(video_path, video_frames, fps=12)
-         print(f"✅ Video saved to {video_path}")
-         yield {output_image: None, output_video: video_path, status_textbox: f"Seed used: {seed}"}
-
- # --- GRADIO USER INTERFACE (No changes needed here) ---
  with gr.Blocks(theme='gradio/soft') as demo:
      gr.Markdown("# The Generative Media Suite")
-     # ... (rest of the UI code is identical to before)
-     gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182.")
      seed_state = gr.State(-1)
      with gr.Row():
          with gr.Column(scale=2):
@@ -125,6 +173,7 @@ with gr.Blocks(theme='gradio/soft') as demo:
      output_image = gr.Image(label="Image Result", interactive=False, height="60vh", visible=True)
      output_video = gr.Video(label="Video Result", interactive=False, height="60vh", visible=False)
      status_textbox = gr.Textbox(label="Status", interactive=False)

      def update_ui_on_model_change(model_key):
          is_video = "Video" in model_key
          is_turbo = "Turbo" in model_key
@@ -138,13 +187,14 @@ with gr.Blocks(theme='gradio/soft') as demo:
              output_video: gr.update(visible=is_video)
          }
      model_selector.change(update_ui_on_model_change, model_selector, [steps_slider, cfg_slider, width_slider, height_slider, num_frames_slider, output_image, output_video])

      click_event = generate_button.click(
          fn=lambda s: (s if s != -1 else random.randint(0, 2**32 - 1)),
          inputs=seed_input,
          outputs=seed_state,
          queue=False
      ).then(
-         fn=generate_media,
          inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
          outputs=[output_image, output_video, status_textbox]
      )
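A note on the updated file shown next: the `# --- DYNAMIC HARDWARE DETECTION ---` block removed above is what defined the module-level `device` and `torch_dtype`, yet the new listing still references both names (`pipe.to(device)`, `torch_dtype=torch_dtype`, `torch.Generator(device)`) and no replacement definition appears in the lines shown. If the committed file really omits it, the first generation call would raise a NameError; a minimal sketch of the kind of definition that would still be needed (an assumption on my part, mirroring the removed block rather than anything confirmed by the commit) is:

import torch

# Assumed sketch: restore the device/dtype globals the rest of media.py references.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16   # half precision on GPU
else:
    device = "cpu"
    torch_dtype = torch.float32   # full precision on CPU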
 
+ # --- START OF FILE media.py (FINAL WITH LIVE PROGRESS) ---
+
  # --- LIBRARIES ---
  import torch
  import gradio as gr
  import random
  import time
+ from diffusers import AutoPipelineForText2Image, TextToVideoSDPipeline, EulerAncestralDiscreteScheduler
  import gc
  import os
  import imageio
+ import numpy as np
+ import threading
+ from queue import Queue, Empty as QueueEmpty
+ from PIL import Image

+ # --- SECURE AUTHENTICATION FOR HUGGING FACE SPACES ---
+ import os
+ from huggingface_hub import login

+ # This code will attempt to read the HF_TOKEN from the Space's secrets.
+ # On your local machine, this will do nothing unless you set it up, which isn't necessary.
+ # On the Hugging Face server, it will find the secret you just saved.
+ HF_TOKEN = os.environ.get('HF_TOKEN')

+ if HF_TOKEN:
+     print("✅ Found HF_TOKEN secret. Logging in...")
+     try:
          login(token=HF_TOKEN)
          print("✅ Hugging Face Authentication successful.")
+     except Exception as e:
+         print(f" Hugging Face login failed: {e}")
+ else:
+     print("⚠️ No HF_TOKEN secret found. Gated models may not be available on the deployed app.")

  # --- CONFIGURATION & STATE ---
  available_models = {
      "Fast Image (SDXL Turbo)": "stabilityai/sdxl-turbo",
      "Quality Image (SDXL)": "stabilityai/stable-diffusion-xl-base-1.0",
+     "Photorealism (Juggernaut)": "RunDiffusion/Juggernaut-XL-v9",
      "Video (Damo-Vilab)": "damo-vilab/text-to-video-ms-1.7b"
  }
  model_state = { "current_pipe": None, "loaded_model_name": None }
+ # --- THE FINAL GENERATION FUNCTION WITH LIVE PROGRESS ---
+ def generate_media_live_progress(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
+     # --- Model Loading (Unchanged) ---
      if model_state.get("loaded_model_name") != model_key:
+         yield {output_image: None, output_video: None, status_textbox: f"Loading {model_key}..."}
          if model_state.get("current_pipe"):
+             del model_state["current_pipe"]; gc.collect(); torch.cuda.empty_cache()
          model_id = available_models[model_key]
+         if "Video" in model_key:
              pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
+         else:
+             pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch_dtype, variant="fp16")
+
+         pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
          pipe.to(device)

+         if device == "cuda":
+             if "Video" not in model_key: pipe.enable_model_cpu_offload()
+             pipe.enable_vae_slicing()
          model_state["current_pipe"] = pipe
          model_state["loaded_model_name"] = model_key
+         print(f"✅ Model loaded on {device.upper()}.")

      pipe = model_state["current_pipe"]
      generator = torch.Generator(device).manual_seed(seed)
+
+     # --- Generation Logic ---
+     if "Video" in model_key:
+         # For video, we'll keep the simple status updates for now
+         yield {output_image: None, output_video: None, status_textbox: "Generating video..."}
          video_frames = pipe(prompt=prompt, num_inference_steps=int(steps), height=320, width=576, num_frames=int(num_frames), generator=generator).frames
+         video_frames_5d = np.array(video_frames)
+         video_frames_4d = np.squeeze(video_frames_5d)
+         video_uint8 = (video_frames_4d * 255).astype(np.uint8)
+         list_of_frames = [frame for frame in video_uint8]
+         video_path = f"video_{seed}.mp4"
+         imageio.mimsave(video_path, list_of_frames, fps=12)
+         yield {output_image: None, output_video: video_path, status_textbox: f"Video saved! Seed: {seed}"}
+
+     else:  # Image Generation with Live Progress
+         progress_queue = Queue()

+         def run_pipe():
+             # This function runs in a separate thread
+             start_time = time.time()
+
+             def progress_callback(pipe, step, timestep, callback_kwargs):
+                 # This is called by the pipeline at each step
+                 elapsed_time = time.time() - start_time
+                 # Avoid division by zero on the first step
+                 if elapsed_time > 0:
+                     its_per_sec = (step + 1) / elapsed_time
+                     progress_queue.put((step + 1, its_per_sec))
+                 return callback_kwargs
+
+             try:
+                 # The final image is still generated using the pipeline's high-quality VAE
+                 final_image = pipe(
+                     prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=int(steps),
+                     guidance_scale=float(cfg_scale), width=int(width), height=int(height),
+                     generator=generator,
+                     callback_on_step_end=progress_callback
+                 ).images[0]
+                 progress_queue.put(final_image)  # Put the final result on the queue
+             except Exception as e:
+                 print(f"An error occurred in the generation thread: {e}")
+                 progress_queue.put(None)  # Signal an error
+
+         # Start the generation in the background
+         thread = threading.Thread(target=run_pipe)
+         thread.start()
+
+         # In the main thread, listen for updates from the queue and yield to Gradio
+         total_steps = int(steps)
+         yield {status_textbox: "Generating..."}  # Initial status
+
+         while True:
+             try:
+                 update = progress_queue.get(timeout=1.0)  # Wait for an update
+
+                 if isinstance(update, Image.Image):  # It's the final image
+                     yield {output_image: update, status_textbox: f"Generation complete! Seed: {seed}"}
+                     break
+                 elif isinstance(update, tuple):  # It's a progress update (step, speed)
+                     current_step, its_per_sec = update
+                     progress_percent = (current_step / total_steps) * 100
+                     steps_remaining = total_steps - current_step
+                     eta_seconds = steps_remaining / its_per_sec if its_per_sec > 0 else 0
+                     eta_minutes, eta_seconds_rem = divmod(int(eta_seconds), 60)
+
+                     status_text = (
+                         f"Generating... {progress_percent:.0f}% ({current_step}/{total_steps}) | "
+                         f"{its_per_sec:.2f}it/s | "
+                         f"ETA: {eta_minutes:02d}:{eta_seconds_rem:02d}"
+                     )
+                     yield {status_textbox: status_text}
+                 elif update is None:  # An error occurred
+                     yield {status_textbox: "Error during generation. Check console."}
+                     break
+             except QueueEmpty:
+                 if not thread.is_alive():
+                     print("⚠️ Generation thread finished unexpectedly.")
+                     yield {status_textbox: "Generation failed. Check console for details."}
+                     break
+
+         thread.join()
 
+ # --- GRADIO UI ---
  with gr.Blocks(theme='gradio/soft') as demo:
+     # (UI layout is the same, just point to the new function)
      gr.Markdown("# The Generative Media Suite")
+     gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182. (note: the speed on the status bar is wrong)")
      seed_state = gr.State(-1)
      with gr.Row():
          with gr.Column(scale=2):

      output_image = gr.Image(label="Image Result", interactive=False, height="60vh", visible=True)
      output_video = gr.Video(label="Video Result", interactive=False, height="60vh", visible=False)
      status_textbox = gr.Textbox(label="Status", interactive=False)
+
      def update_ui_on_model_change(model_key):
          is_video = "Video" in model_key
          is_turbo = "Turbo" in model_key

              output_video: gr.update(visible=is_video)
          }
      model_selector.change(update_ui_on_model_change, model_selector, [steps_slider, cfg_slider, width_slider, height_slider, num_frames_slider, output_image, output_video])
+
      click_event = generate_button.click(
          fn=lambda s: (s if s != -1 else random.randint(0, 2**32 - 1)),
          inputs=seed_input,
          outputs=seed_state,
          queue=False
      ).then(
+         fn=generate_media_live_progress,  # Use the new function with progress
          inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
          outputs=[output_image, output_video, status_textbox]
      )
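For context on the `.then(fn=generate_media_live_progress, ...)` wiring: Gradio treats a generator function as a streaming handler, so each `yield` of a dict keyed by output components pushes an intermediate update (here the status text) to the UI before the final image or video is yielded. A stripped-down sketch of that same pattern, using hypothetical names (`count_up`, `status`, `result`) and assuming a reasonably recent Gradio release:

import time
import gradio as gr

def count_up(n):
    # Each yield sends a partial update; components left out of the dict are unchanged.
    for i in range(int(n)):
        time.sleep(0.5)
        yield {status: f"step {i + 1}/{int(n)}"}
    yield {status: "done", result: f"counted to {int(n)}"}

with gr.Blocks() as demo:
    n_box = gr.Number(value=5, label="Steps")
    status = gr.Textbox(label="Status")
    result = gr.Textbox(label="Result")
    gr.Button("Run").click(count_up, inputs=n_box, outputs=[status, result])

demo.launch()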