# SkyReels / app.py
# NOTE: the original paste included Hugging Face file-viewer chrome
# (author "1inkusFace", commit 792ba0c, "raw / history / blame", 4.18 kB);
# it was not Python code and is preserved here only as this comment.
import spaces
import gradio as gr
import sys
import time
import os
import random
# Hide all CUDA devices from the main (UI) process — presumably so the GPU is
# only touched inside the @spaces.GPU-decorated workers (ZeroGPU pattern);
# TODO(review): confirm this doesn't also blind the spawned GPU worker.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Create the gr.State component *outside* the gr.Blocks context
# NOTE(review): Gradio normally expects components to be created inside a
# Blocks context; this appears to rely on the component being referenced
# (and thus attached) by the event wiring below — verify against the
# installed Gradio version.
predictor_state = gr.State(None)
def get_transformer_model_id(task_type: str) -> str:
    """Resolve a task type to its Hugging Face model repository id.

    "i2v" maps to the image-to-video checkpoint; any other value falls
    back to the text-to-video checkpoint.
    """
    repo_by_task = {"i2v": "Skywork/skyreels-v1-Hunyuan-i2v"}
    return repo_by_task.get(task_type, "Skywork/skyreels-v1-Hunyuan-t2v")
@spaces.GPU(duration=120)
def init_predictor(task_type: str):
    """Build the SkyReels inference engine inside the GPU worker.

    Returns a ``(status_message, predictor)`` pair; on failure the
    predictor slot is ``None`` and the message describes the error.
    """
    # Heavy imports are deferred into the function so they execute in the
    # @spaces.GPU worker process rather than at module import time.
    import torch
    from skyreelsinfer import TaskType
    from skyreelsinfer.offload import OffloadConfig
    from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
    from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError

    resolved_task = TaskType.I2V if task_type == "i2v" else TaskType.T2V
    try:
        engine = SkyReelsVideoInfer(
            task_type=resolved_task,
            model_id=get_transformer_model_id(task_type),
            quant_model=True,
            is_offload=True,
            # CPU offload keeps host-memory pressure manageable for the
            # quantized model.
            offload_config=OffloadConfig(
                high_cpu_memory=True,
                parameters_level=True,
            ),
            use_multiprocessing=False,
        )
    except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
        # Hub-side lookup failures (missing repo/revision/file).
        return f"Error: Model not found. Details: {e}", None
    except Exception as e:
        # Any other load-time failure (OOM, bad weights, ...).
        return f"Error loading model: {e}", None
    return "Model loaded successfully!", engine
@spaces.GPU(duration=80)
def generate_video(prompt, seed, image, task_type, predictor):
    """Generate a short video with the loaded predictor.

    Args:
        prompt: text prompt for the generation.
        seed: integer seed; -1 picks a random one.
        image: filesystem path to a conditioning image (required for i2v).
        task_type: "i2v" or "t2v".
        predictor: engine returned by ``init_predictor`` via gr.State.

    Returns:
        Path to the exported .mp4 file.

    Raises:
        gr.Error: on any validation or generation failure. (The original
        returned 2-tuples like ``(msg, "{}")`` from error paths while the
        success path returned one value and the click handler has a single
        output component — every error path therefore broke the UI contract.
        Raising gr.Error surfaces the message in the Gradio UI instead.)
    """
    # Imports deferred so heavy deps load inside the GPU worker process.
    from diffusers.utils import export_to_video
    from diffusers.utils import load_image
    import os
    import re

    # Validate cheap preconditions before any expensive work.
    if predictor is None:
        raise gr.Error("Error: Model not init.")
    if not isinstance(prompt, str) or not isinstance(seed, (int, float)):
        raise gr.Error("Error: Invalid inputs.")
    if task_type == "i2v":
        if not isinstance(image, str):
            raise gr.Error("Error: For i2v, provide image path.")
        if not os.path.exists(image):
            raise gr.Error("Error: Image not found.")

    if seed == -1:
        random.seed(time.time())
        seed = int(random.randrange(4294967294))

    kwargs = {
        "prompt": prompt,
        "height": 256,
        "width": 256,
        "num_frames": 24,
        "num_inference_steps": 30,
        "seed": int(seed),
        "guidance_scale": 7.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "bad quality, blur",
        "cfg_for": False,
    }
    if task_type == "i2v":
        try:
            kwargs["image"] = load_image(image=image)
        except Exception as e:
            raise gr.Error(f"Error loading image: {e}") from e

    try:
        frames = predictor.inference(kwargs)
        save_dir = f"./result/{task_type}"
        os.makedirs(save_dir, exist_ok=True)
        # Sanitize the prompt before using it as a filename component — the
        # original interpolated raw user text (slashes, quotes, ...) into
        # the path.
        safe_prompt = re.sub(r"[^\w\-. ]", "_", prompt[:100])
        video_out_file = f"{save_dir}/{safe_prompt}_{int(seed)}.mp4"
        print(f"Generating video: {video_out_file}")
        export_to_video(frames, video_out_file, fps=24)
        return video_out_file
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
# --- Minimal Gradio Interface ---
# BUG FIX: the original wired generate_button.click with only three inputs
# [prompt, task_type, predictor] against the five-parameter
# generate_video(prompt, seed, image, task_type, predictor). Positional
# mapping therefore fed the task-type string into `seed` and the predictor
# object into `image`, and `task_type`/`predictor` were never supplied —
# generation could never succeed. Seed and image widgets are added and all
# five inputs are wired in order.
with gr.Blocks() as demo:
    task_type_dropdown = gr.Dropdown(
        choices=["i2v", "t2v"], label="Task", value="t2v", elem_id="task_type"
    )
    load_model_button = gr.Button("Load Model", elem_id="load_button")
    prompt_textbox = gr.Textbox(label="Prompt", elem_id="prompt")
    # -1 requests a random seed (see generate_video).
    seed_number = gr.Number(
        label="Seed (-1 = random)", value=-1, precision=0, elem_id="seed"
    )
    # filepath type matches generate_video's os.path.exists / load_image usage.
    image_input = gr.Image(
        label="Input Image (i2v only)", type="filepath", elem_id="image"
    )
    generate_button = gr.Button("Generate", elem_id="generate_button")
    output_textbox = gr.Textbox(label="Output", elem_id="output")  # model-load status
    output_video = gr.Video(label="Output Video", elem_id="output_video")

    load_model_button.click(
        fn=init_predictor,
        inputs=[task_type_dropdown],
        outputs=[output_textbox, predictor_state],  # status message, then state
    )
    generate_button.click(
        fn=generate_video,
        # Order must match generate_video(prompt, seed, image, task_type, predictor).
        inputs=[
            prompt_textbox,
            seed_number,
            image_input,
            task_type_dropdown,
            predictor_state,
        ],
        outputs=[output_video],
    )
demo.launch()