# Gradio demo app for SkyReels V1 (Hunyuan) text-to-video and image-to-video
# generation, intended to run on a Hugging Face ZeroGPU Space.
import spaces
import gradio as gr
import time
import os
import random

# Hide CUDA devices from the main process; GPU work happens inside the
# @spaces.GPU-decorated functions below.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Create the gr.State component *outside* the gr.Blocks context
predictor_state = gr.State(None)

def get_transformer_model_id(task_type: str) -> str:
    if task_type == "i2v":
        return "Skywork/skyreels-v1-Hunyuan-i2v"
    else:
        return "Skywork/skyreels-v1-Hunyuan-t2v"
        
@spaces.GPU(duration=120)
def init_predictor(task_type: str):
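    """Build the SkyReels video pipeline on the GPU allocated by @spaces.GPU.

    Heavy dependencies are imported inside this function so the CPU-only main
    process stays lightweight. Returns a (status message, predictor) pair; the
    predictor is stored in gr.State by the caller.
    """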
    # Heavyweight imports are deferred to this GPU-decorated function.
    import torch
    from skyreelsinfer import TaskType
    from skyreelsinfer.offload import OffloadConfig
    from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
    from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError

    try:
        predictor = SkyReelsVideoInfer(
            task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
            model_id=get_transformer_model_id(task_type),
            quant_model=True,
            is_offload=True,
            offload_config=OffloadConfig(
                high_cpu_memory=True,
                parameters_level=True,
            ),
            use_multiprocessing=False,
        )
        return "Model loaded successfully!", predictor  # Return predictor

    except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
        return f"Error: Model not found. Details: {e}", None
    except Exception as e:
        return f"Error loading model: {e}", None

@spaces.GPU(duration=80)
def generate_video(prompt, seed, image, task_type, predictor): # predictor as argument
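    """Run SkyReels inference for the selected task and export the frames to MP4.

    Returns a (status message, video path) pair; on failure the video path is None.
    """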
    # Import diffusers helpers inside the GPU function as well.
    from diffusers.utils import export_to_video, load_image

    if task_type == "i2v" and not isinstance(image, str):
        return "Error: For i2v, provide image path.", "{}"
    if not isinstance(prompt, str) or not isinstance(seed, (int, float)):
        return "Error: Invalid inputs.", "{}"

    # A seed of -1 requests a random seed.
    if seed == -1:
        random.seed(time.time())
        seed = int(random.randrange(4294967294))

    # Arguments forwarded to predictor.inference().
    kwargs = {
        "prompt": prompt,
        "height": 256,
        "width": 256,
        "num_frames": 24,
        "num_inference_steps": 30,
        "seed": int(seed),
        "guidance_scale": 7.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "bad quality, blur",
        "cfg_for": False,
    }

    if task_type == "i2v":
        if image is None or not os.path.exists(image):
            return "Error: Image not found.", None
        try:
            kwargs["image"] = load_image(image=image)
        except Exception as e:
            return f"Error loading image: {e}", None

    try:
        if predictor is None:
            return "Error: Model not initialized. Click 'Load Model' first.", None

        frames = predictor.inference(kwargs)

        save_dir = f"./result/{task_type}"
        os.makedirs(save_dir, exist_ok=True)
        # Derive a filesystem-safe file name from the prompt.
        safe_prompt = "".join(c if c.isalnum() or c in " ._-" else "_" for c in prompt[:100]).strip()
        video_out_file = f"{save_dir}/{safe_prompt}_{int(seed)}.mp4"
        print(f"Generating video: {video_out_file}")
        export_to_video(frames, video_out_file, fps=24)
        return f"Saved {video_out_file}", video_out_file

    except Exception as e:
        return f"Error: {e}", None

# --- Minimal Gradio Interface ---
with gr.Blocks() as demo:
    # Attach the State created above to this Blocks so the click handlers below can use it.
    predictor_state.render()

    task_type_dropdown = gr.Dropdown(
        choices=["i2v", "t2v"], label="Task", value="t2v", elem_id="task_type"
    )
    load_model_button = gr.Button("Load Model", elem_id="load_button")
    prompt_textbox = gr.Textbox(label="Prompt", elem_id="prompt")
    image_input = gr.Image(type="filepath", label="Input Image (i2v only)", elem_id="image")
    seed_number = gr.Number(label="Seed (-1 for random)", value=-1, precision=0, elem_id="seed")
    generate_button = gr.Button("Generate", elem_id="generate_button")
    output_textbox = gr.Textbox(label="Output", elem_id="output")  # Status messages
    output_video = gr.Video(label="Output Video", elem_id="output_video")  # Generated video

    load_model_button.click(
        fn=init_predictor,
        inputs=[task_type_dropdown],
        outputs=[output_textbox, predictor_state],  # status message, then predictor for the State
    )

    generate_button.click(
        fn=generate_video,
        inputs=[prompt_textbox, seed_number, image_input, task_type_dropdown, predictor_state],
        outputs=[output_textbox, output_video],
    )

demo.launch()