# NOTE: removed non-code residue from a file-viewer scrape (byte-size header,
# commit hashes, and gutter line numbers) that made this file invalid Python.
import spaces
import gradio as gr
import sys
import time
import os
import random
from PIL import Image 
 # os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Shared State component that carries the loaded predictor into event handlers.
# NOTE(review): gr.State is normally created *inside* a gr.Blocks context;
# creating it here at module level may not register with the app — confirm
# against the Gradio version in use.
predictor_state = gr.State(None)

def init_predictor(task_type: str):
    """Construct the SkyReels video inference engine.

    Args:
        task_type: "i2v" selects image-to-video; any other value selects
            text-to-video. The model checkpoint is fixed either way.

    Returns:
        A ``(status_message, predictor)`` tuple. ``predictor`` is ``None``
        when loading fails, with the failure described in the message.
    """
    # Heavy dependencies are imported lazily so the module can be loaded
    # without them being present.
    import torch
    from skyreelsinfer import TaskType
    from skyreelsinfer.offload import OffloadConfig
    from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
    from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError

    try:
        engine = SkyReelsVideoInfer(
            task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
            model_id="Skywork/skyreels-v1-Hunyuan-i2v",
            quant_model=True,
            is_offload=True,
            offload_config=OffloadConfig(
                high_cpu_memory=True,
                parameters_level=True,
            ),
            use_multiprocessing=False,
        )
    except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
        # The Hub could not resolve the repo/revision/file.
        return f"Error: Model not found. Details: {e}", None
    except Exception as e:
        # Any other failure (OOM, bad weights, ...) is reported, not raised.
        return f"Error loading model: {e}", None

    return "Model loaded successfully!", engine
        
init_predictor('i2v')

@spaces.GPU(duration=80)
def generate_video(prompt, image, predictor):
    """Generate a short video from a text prompt plus a conditioning image.

    Args:
        prompt: Text description of the desired video.
        image: Path to the conditioning image (required for i2v).
        predictor: Loaded SkyReelsVideoInfer instance (from predictor_state).

    Returns:
        Path to the exported ``.mp4`` on success. On invalid input, a
        ``(message, "{}")`` tuple (original error contract preserved).
    """
    from diffusers.utils import export_to_video
    from diffusers.utils import load_image
    import os
    import re

    # Guard clauses for required inputs.
    if image is None:  # was `image == None`; identity check is correct here
        return "Error: For i2v, provide image path.", "{}"
    if not isinstance(prompt, str):
        return "Error: No prompt.", "{}"

    # Fresh random seed on every call (the UI exposes no fixed-seed option).
    random.seed(time.time())
    seed = int(random.randrange(4294967294))

    kwargs = {
        "prompt": prompt,
        "height": 256,
        "width": 256,
        "num_frames": 24,
        "num_inference_steps": 30,
        "seed": seed,
        "guidance_scale": 7.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "bad quality, blur",
        "cfg_for": False,
        "image": load_image(image=image),
    }

    frames = predictor.inference(kwargs)

    # BUG FIX: the original used an undefined name `task_type` here, which
    # raised NameError at runtime. This app only ever runs the i2v task.
    save_dir = "./result/i2v"
    os.makedirs(save_dir, exist_ok=True)

    # Sanitize the prompt before using it as a filename component: a raw
    # prompt may contain '/' or other characters that break the write.
    safe_prompt = re.sub(r"[^\w\- ]", "_", prompt[:100]).strip() or "video"
    video_out_file = f"{save_dir}/{safe_prompt}_{seed}.mp4"
    print(f"Generating video: {video_out_file}")
    export_to_video(frames, video_out_file, fps=24)
    return video_out_file

 
def display_image(file):
    """Open and return the uploaded image for preview; None when no file."""
    if file is None:
        return None
    return Image.open(file.name)
        
# --- Minimal Gradio Interface ---
with gr.Blocks() as demo:

    uploaded_image = gr.File(label="Image Prompt (Required)", file_types=["image"])
    image_preview = gr.Image(label="Image Prompt Preview", interactive=False)
    prompt_box = gr.Textbox(label="Prompt")
    run_button = gr.Button("Generate")
    video_output = gr.Video(label="Output Video")

    # Refresh the preview whenever a new image file is uploaded.
    uploaded_image.change(
        display_image,
        inputs=[uploaded_image],
        outputs=[image_preview],
    )

    # Run the video pipeline; predictor_state supplies the loaded model.
    run_button.click(
        fn=generate_video,
        inputs=[prompt_box, uploaded_image, predictor_state],
        outputs=[video_output],
    )

demo.launch()