# --- START OF FILE media.py (FINAL WITH LIVE PROGRESS) ---

# --- LIBRARIES ---
import torch
import gradio as gr
import random
import time
from diffusers import AutoPipelineForText2Image, TextToVideoSDPipeline, EulerAncestralDiscreteScheduler
import gc
import os
import imageio
import numpy as np
import threading
from queue import Queue, Empty as QueueEmpty
from PIL import Image
# --- SECURE AUTHENTICATION FOR HUGGING FACE SPACES ---
from huggingface_hub import login

# This attempts to read HF_TOKEN from the Space's secrets.
# On your local machine it does nothing unless you set the variable yourself, which isn't necessary.
# On the Hugging Face server it will find the secret you just saved.
HF_TOKEN = os.environ.get('HF_TOKEN')
if HF_TOKEN:
    print("✅ Found HF_TOKEN secret. Logging in...")
    try:
        login(token=HF_TOKEN)
        print("✅ Hugging Face Authentication successful.")
    except Exception as e:
        print(f"❌ Hugging Face login failed: {e}")
else:
    print("⚠️ No HF_TOKEN secret found. Gated models may not be available on the deployed app.")
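# If you do want gated models locally, one option (an assumption about your setup,
# not something this app requires) is to export the same variable before launching:
#   HF_TOKEN=hf_your_token_here python media.py
# or to authenticate once with the Hugging Face CLI: `huggingface-cli login`.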
# --- CONFIGURATION & STATE ---
# Pick the GPU when one is available; fp16 keeps VRAM usage manageable on CUDA,
# while fp32 is the safe fallback on CPU. These two names are used by the loader below.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

available_models = {
    "Fast Image (SDXL Turbo)": "stabilityai/sdxl-turbo",
    "Quality Image (SDXL)": "stabilityai/stable-diffusion-xl-base-1.0",
    "Photorealism (Juggernaut)": "RunDiffusion/Juggernaut-XL-v9",
    "Video (Damo-Vilab)": "damo-vilab/text-to-video-ms-1.7b"
}
model_state = {"current_pipe": None, "loaded_model_name": None}
# --- THE FINAL GENERATION FUNCTION WITH LIVE PROGRESS ---
def generate_media_live_progress(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
    # --- Model Loading (Unchanged) ---
    if model_state.get("loaded_model_name") != model_key:
        yield {output_image: None, output_video: None, status_textbox: f"Loading {model_key}..."}
        if model_state.get("current_pipe"):
            del model_state["current_pipe"]; gc.collect(); torch.cuda.empty_cache()
        model_id = available_models[model_key]
        if "Video" in model_key:
            pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
        else:
            pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch_dtype, variant="fp16")
            pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(device)
        if device == "cuda":
            if "Video" not in model_key:
                pipe.enable_model_cpu_offload()
            pipe.enable_vae_slicing()
        model_state["current_pipe"] = pipe
        model_state["loaded_model_name"] = model_key
        print(f"✅ Model loaded on {device.upper()}.")
    pipe = model_state["current_pipe"]
    generator = torch.Generator(device).manual_seed(seed)
    # --- Generation Logic ---
    if "Video" in model_key:
        # For video, we'll keep the simple status updates for now
        yield {output_image: None, output_video: None, status_textbox: "Generating video..."}
        video_frames = pipe(prompt=prompt, num_inference_steps=int(steps), height=320, width=576, num_frames=int(num_frames), generator=generator).frames
        video_frames_5d = np.array(video_frames)
        video_frames_4d = np.squeeze(video_frames_5d)
        video_uint8 = (video_frames_4d * 255).astype(np.uint8)
        list_of_frames = [frame for frame in video_uint8]
        video_path = f"video_{seed}.mp4"
        imageio.mimsave(video_path, list_of_frames, fps=12)
        yield {output_image: None, output_video: video_path, status_textbox: f"Video saved! Seed: {seed}"}
    else:  # Image Generation with Live Progress
        progress_queue = Queue()

        def run_pipe():
            # This function runs in a separate thread
            start_time = time.time()

            def progress_callback(pipe, step, timestep, callback_kwargs):
                # This is called by the pipeline at each step
                elapsed_time = time.time() - start_time
                # Avoid division by zero on the first step
                if elapsed_time > 0:
                    its_per_sec = (step + 1) / elapsed_time
                    progress_queue.put((step + 1, its_per_sec))
                return callback_kwargs
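            # Note: callback_on_step_end (used below) is the per-step hook in recent
            # diffusers releases; it is called as (pipeline, step_index, timestep,
            # callback_kwargs) and must return a dict, typically callback_kwargs
            # unchanged. Older diffusers versions that only know the legacy
            # `callback=` argument will not accept this keyword.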
            try:
                # The final image is still generated using the pipeline's high-quality VAE
                final_image = pipe(
                    prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=int(steps),
                    guidance_scale=float(cfg_scale), width=int(width), height=int(height),
                    generator=generator,
                    callback_on_step_end=progress_callback
                ).images[0]
                progress_queue.put(final_image)  # Put the final result on the queue
            except Exception as e:
                print(f"An error occurred in the generation thread: {e}")
                progress_queue.put(None)  # Signal an error

        # Start the generation in the background
        thread = threading.Thread(target=run_pipe)
        thread.start()
        # In the main thread, listen for updates from the queue and yield to Gradio
        total_steps = int(steps)
        yield {status_textbox: "Generating..."}  # Initial status
        while True:
            try:
                update = progress_queue.get(timeout=1.0)  # Wait for an update
                if isinstance(update, Image.Image):  # It's the final image
                    yield {output_image: update, status_textbox: f"Generation complete! Seed: {seed}"}
                    break
                elif isinstance(update, tuple):  # It's a progress update (step, speed)
                    current_step, its_per_sec = update
                    progress_percent = (current_step / total_steps) * 100
                    steps_remaining = total_steps - current_step
                    eta_seconds = steps_remaining / its_per_sec if its_per_sec > 0 else 0
                    eta_minutes, eta_seconds_rem = divmod(int(eta_seconds), 60)
                    status_text = (
                        f"Generating... {progress_percent:.0f}% ({current_step}/{total_steps}) | "
                        f"{its_per_sec:.2f}it/s | "
                        f"ETA: {eta_minutes:02d}:{eta_seconds_rem:02d}"
                    )
                    yield {status_textbox: status_text}
                elif update is None:  # An error occurred
                    yield {status_textbox: "Error during generation. Check console."}
                    break
            except QueueEmpty:
                if not thread.is_alive():
                    print("⚠️ Generation thread finished unexpectedly.")
                    yield {status_textbox: "Generation failed. Check console for details."}
                    break
        thread.join()
# --- GRADIO UI ---
with gr.Blocks(theme='gradio/soft') as demo:
    # (UI layout is the same, just point to the new function)
    gr.Markdown("# The Generative Media Suite")
    gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182. (note: the speed on the status bar is wrong)")
    seed_state = gr.State(-1)
    with gr.Row():
        with gr.Column(scale=2):
            model_selector = gr.Radio(label="Select Model", choices=list(available_models.keys()), value=list(available_models.keys())[0])
            prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="An astronaut riding a horse on Mars, cinematic...")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, deformed, watermark, text")
            with gr.Accordion("Settings", open=True):
                steps_slider = gr.Slider(1, 100, 30, step=1, label="Inference Steps")
                cfg_slider = gr.Slider(0.0, 15.0, 7.5, step=0.5, label="Guidance Scale (CFG)")
                with gr.Row():
                    width_slider = gr.Slider(256, 1024, 768, step=64, label="Width")
                    height_slider = gr.Slider(256, 1024, 768, step=64, label="Height")
                num_frames_slider = gr.Slider(12, 48, 24, step=4, label="Video Frames", visible=False)
                seed_input = gr.Number(-1, label="Seed (-1 for random)")
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column(scale=3):
            output_image = gr.Image(label="Image Result", interactive=False, height="60vh", visible=True)
            output_video = gr.Video(label="Video Result", interactive=False, height="60vh", visible=False)
            status_textbox = gr.Textbox(label="Status", interactive=False)
    def update_ui_on_model_change(model_key):
        is_video = "Video" in model_key
        is_turbo = "Turbo" in model_key
        return {
            steps_slider: gr.update(interactive=not is_turbo, value=1 if is_turbo else 30),
            cfg_slider: gr.update(interactive=not is_turbo, value=0.0 if is_turbo else 7.5),
            width_slider: gr.update(visible=not is_video),
            height_slider: gr.update(visible=not is_video),
            num_frames_slider: gr.update(visible=is_video),
            output_image: gr.update(visible=not is_video),
            output_video: gr.update(visible=is_video)
        }

    model_selector.change(update_ui_on_model_change, model_selector, [steps_slider, cfg_slider, width_slider, height_slider, num_frames_slider, output_image, output_video])
    click_event = generate_button.click(
        fn=lambda s: (int(s) if s != -1 else random.randint(0, 2**32 - 1)),  # int() because gr.Number returns a float, and manual_seed needs an int
        inputs=seed_input,
        outputs=seed_state,
        queue=False
    ).then(
        fn=generate_media_live_progress,  # Use the new function with progress
        inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
        outputs=[output_image, output_video, status_textbox]
    )

demo.launch()
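
# A rough sketch of the environment this script expects (package list is an
# educated guess, not pinned by this file; imageio-ffmpeg is what lets imageio
# write the MP4):
#   pip install torch diffusers transformers accelerate gradio imageio imageio-ffmpeg
#   python media.py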