File size: 2,782 Bytes
7e1f9a4
 
 
 
 
 
 
 
 
 
25a2142
7e1f9a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25a2142
7e1f9a4
 
 
 
25a2142
 
7e1f9a4
 
 
 
25a2142
7e1f9a4
 
 
 
 
 
 
 
 
25a2142
7e1f9a4
25a2142
7e1f9a4
 
 
 
 
 
 
 
 
 
 
 
 
25a2142
7e1f9a4
 
 
 
 
 
 
25a2142
 
7e1f9a4
 
 
 
 
 
 
 
25a2142
7e1f9a4
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import shutil
import tempfile
import threading
import subprocess
from pathlib import Path

import gradio as gr
from PIL import Image

# Local bundled RIFE checkout. Anchored to this file's directory so the app
# does not depend on the process working directory; falls back to the working
# directory when the module has no __file__ (e.g. an interactive session).
RIFE_DIR = (
    Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd()
) / "rife"
# Where interpolate() looks for the frames written by the RIFE script.
OUTPUT_DIR = RIFE_DIR / "output"
# Serializes interpolation runs: every run shares OUTPUT_DIR.
LOCK = threading.Lock()

def _collect_frames(output_dir):
    """Return the ``img<N>.png`` files in *output_dir*, ordered numerically by N.

    Files whose stem is not ``img`` followed by digits are ignored.
    """
    import re

    numbered = []
    for path in output_dir.glob("img*.png"):
        match = re.fullmatch(r"img(\d+)", path.stem)
        if match:
            numbered.append((int(match.group(1)), path))
    return [path for _, path in sorted(numbered)]


def interpolate(a: Image.Image, b: Image.Image, fps: int = 14, exp: int = 2):
    """Interpolate between two images with RIFE and return the path of a GIF.

    Both images are normalized to RGB (``b`` is resized to ``a``'s size if
    needed), saved as PNGs in a fresh temporary directory, and passed to the
    bundled ``inference_img.py`` script. The numbered frames it leaves in
    ``OUTPUT_DIR`` are assembled into a looping GIF saved in the temporary
    directory.

    Args:
        a: Start image.
        b: End image.
        fps: GIF playback rate; each frame lasts ``1000 / fps`` ms (fps < 1 is
            treated as 1).
        exp: Forwarded to the script as ``--exp``. Ints and integral floats
            (as a Gradio slider may deliver) are accepted; values below 1 or
            non-numeric values omit the flag so the script's default applies.

    Returns:
        Path (as ``str``) of the generated GIF.

    Raises:
        ValueError: If either image is missing.
        subprocess.CalledProcessError: If the RIFE script exits non-zero.
        RuntimeError: If the script produced no frames.
    """
    import sys

    if a is None or b is None:
        raise ValueError("Please provide both Image A and Image B.")

    # Slider values may arrive as floats (e.g. 2.0); an isinstance(int) check
    # would silently drop the user's choice, so coerce instead.
    try:
        exp = int(exp)
    except (TypeError, ValueError, OverflowError):
        exp = None

    with LOCK:
        # Normalize inputs
        a = a.convert("RGB")
        b = b.convert("RGB")
        if b.size != a.size:
            b = b.resize(a.size, Image.BICUBIC)

        work_dir = Path(tempfile.mkdtemp(prefix="rife_run_"))
        p1 = work_dir / "a.png"
        p2 = work_dir / "b.png"
        a.save(p1, "PNG")
        b.save(p2, "PNG")

        # Remove frames from any previous run so only this run's are collected.
        shutil.rmtree(OUTPUT_DIR, ignore_errors=True)
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

        try:
            # The script runs with cwd=RIFE_DIR, so pass its path absolutely:
            # a relative "rife/inference_img.py" would resolve inside RIFE_DIR
            # as rife/rife/inference_img.py. sys.executable keeps the same
            # interpreter (and installed packages) as this app.
            script = (RIFE_DIR / "inference_img.py").resolve()
            cmd = [sys.executable or "python3", str(script), "--img", str(p1), str(p2)]
            if exp is not None and exp >= 1:
                cmd += ["--exp", str(exp)]
            subprocess.run(cmd, cwd=str(RIFE_DIR), check=True)

            # Every img<N>.png in numeric order. Upstream RIFE numbers frames
            # from img0 (frame A); the former img1-based scan skipped it.
            # NOTE(review): confirm against the bundled inference_img.py.
            frames = _collect_frames(OUTPUT_DIR)
            if not frames:
                raise RuntimeError("No frames generated.")

            # Build GIF (context manager releases each file handle promptly).
            images = []
            for frame_path in frames:
                with Image.open(frame_path) as im:
                    images.append(im.convert("RGBA"))
            duration_ms = max(1, int(1000 / max(1, fps)))
            gif_path = work_dir / "interpolation.gif"
            images[0].save(
                gif_path,
                save_all=True,
                append_images=images[1:],
                optimize=False,
                duration=duration_ms,
                loop=0,
                disposal=2,
            )
            return str(gif_path)
        finally:
            # Best-effort cleanup on success and failure alike: only the GIF
            # is needed by the caller.
            shutil.rmtree(OUTPUT_DIR, ignore_errors=True)
            for src in (p1, p2):
                try:
                    src.unlink()
                except OSError:
                    pass

# --- Gradio UI ---------------------------------------------------------------
TITLE = "🔥 RIFE Interpolation Demo (PyTorch, Local)"

with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
    gr.Markdown(value=f"# {TITLE}")

    with gr.Row():
        # Left column: the two endpoint images, delivered as PIL images.
        with gr.Column():
            img_a = gr.Image(label="Image A", type="pil")
            img_b = gr.Image(label="Image B", type="pil")
        # Right column: playback / interpolation controls and the trigger.
        with gr.Column():
            fps = gr.Slider(minimum=6, maximum=30, value=14, step=1, label="FPS")
            exp = gr.Slider(
                minimum=1,
                maximum=4,
                value=2,
                step=1,
                label="Interpolation exponent",
            )
            run = gr.Button("Interpolate", variant="primary")

    # interpolate() returns a file path, shown here as the result.
    gif_out = gr.Image(label="Result GIF", type="filepath")

    run.click(
        fn=interpolate,
        inputs=[img_a, img_b, fps, exp],
        outputs=[gif_out],
    )
    # Process one request at a time, keeping at most 8 queued.
    demo.queue(concurrency_count=1, max_size=8)

if __name__ == "__main__":
    demo.launch()