AItool's picture
Update app.py
25a2142 verified
raw
history blame
2.78 kB
import os
import shutil
import tempfile
import threading
import subprocess
from pathlib import Path
import gradio as gr
from PIL import Image
# Anchor the bundled RIFE checkout to this file's directory so the app works
# no matter which working directory the process was started from.
RIFE_DIR = Path(__file__).resolve().parent / "rife"  # Local bundled repo
# Directory interpolate() clears before a run and reads img<N>.png frames from.
# NOTE(review): assumes inference_img.py writes "output/" relative to its cwd
# (interpolate runs it with cwd=RIFE_DIR) — confirm against the bundled script.
OUTPUT_DIR = RIFE_DIR / "output"
# Serialises interpolate(): every run clears and reads the same OUTPUT_DIR,
# so two runs must never overlap.
LOCK = threading.Lock()
def _collect_frames(output_dir: Path) -> list:
    """Return the ``img<N>.png`` files in *output_dir*, sorted by N ascending.

    Globbing (rather than probing img1, img2, ... until the first gap) also
    picks up ``img0.png`` when the RIFE script writes one, and ignores any
    file whose name is not exactly ``img`` followed by digits.
    """
    numbered = []
    for path in output_dir.glob("img*.png"):
        index = path.stem[len("img"):]
        if index.isdecimal():
            numbered.append((int(index), path))
    numbered.sort()
    return [path for _, path in numbered]


def _write_gif(frame_paths, gif_path: Path, fps) -> None:
    """Encode *frame_paths* (in order) as an endlessly looping GIF at *gif_path*.

    Each frame is shown for ``int(1000 / fps)`` ms, at least 1 ms; fps values
    below 1 are treated as 1.
    """
    images = []
    for path in frame_paths:
        # The context manager closes the source file; convert() returns an
        # independent in-memory copy, so the frame stays usable afterwards.
        with Image.open(path) as frame:
            images.append(frame.convert("RGBA"))
    duration_ms = max(1, int(1000 / max(1, fps)))
    images[0].save(
        gif_path,
        save_all=True,
        append_images=images[1:],
        optimize=False,
        duration=duration_ms,
        loop=0,  # loop forever
        disposal=2,  # restore to background before drawing the next frame
    )


def interpolate(a: Image.Image, b: Image.Image, fps: int = 14, exp: int = 2):
    """Interpolate between two images with the bundled RIFE repo; return a GIF path.

    Args:
        a: First keyframe. Converted to RGB.
        b: Second keyframe. Converted to RGB and, if its size differs from
            ``a``, resized to ``a.size`` with bicubic resampling.
        fps: Playback rate of the resulting GIF; values below 1 are treated as 1.
        exp: Passed to ``inference_img.py`` as ``--exp`` when it converts to an
            integer >= 1 (Gradio sliders may deliver floats). Otherwise the flag
            is omitted and the script presumably falls back to its own default.

    Returns:
        Path (as ``str``) of the GIF. It is written into a fresh temporary
        directory that is deliberately left in place so the caller can serve it.

    Raises:
        gr.Error: if either image is missing.
        subprocess.CalledProcessError: if the RIFE script exits non-zero.
        RuntimeError: if the script produced no ``img<N>.png`` frames.
    """
    import sys  # local import: only needed to re-use the running interpreter

    if a is None or b is None:
        raise gr.Error("Please provide both Image A and Image B.")

    with LOCK:  # OUTPUT_DIR is shared, so only one run may use it at a time
        # Normalize inputs
        a = a.convert("RGB")
        b = b.convert("RGB")
        if b.size != a.size:
            b = b.resize(a.size, Image.BICUBIC)

        work_dir = Path(tempfile.mkdtemp(prefix="rife_run_"))
        p1 = work_dir / "a.png"
        p2 = work_dir / "b.png"
        a.save(p1, "PNG")
        b.save(p2, "PNG")

        # Remove frames from a previous run. This must succeed; otherwise
        # stale frames could leak into the new GIF.
        if OUTPUT_DIR.exists():
            shutil.rmtree(OUTPUT_DIR)
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

        # The script runs with cwd=RIFE_DIR, so pass an absolute script path.
        # A cwd-relative "rife/inference_img.py" would resolve to rife/rife/...
        rife_dir = RIFE_DIR.resolve()
        cmd = [
            sys.executable,
            str(rife_dir / "inference_img.py"),
            "--img",
            str(p1),
            str(p2),
        ]
        try:
            exp_value = int(exp)
        except (TypeError, ValueError, OverflowError):
            exp_value = None
        if exp_value is not None and exp_value >= 1:
            cmd += ["--exp", str(exp_value)]

        try:
            subprocess.run(cmd, cwd=str(rife_dir), check=True)
            frames = _collect_frames(OUTPUT_DIR)
            if not frames:
                raise RuntimeError("No frames generated.")
            gif_path = work_dir / "interpolation.gif"
            _write_gif(frames, gif_path, fps)
        except BaseException:
            # Nothing will be returned, so drop the whole per-run directory.
            shutil.rmtree(work_dir, ignore_errors=True)
            raise
        finally:
            # Best-effort removal of RIFE's frames on success and on failure.
            shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

        # The staged inputs are no longer needed; only the GIF must survive.
        for staged in (p1, p2):
            try:
                staged.unlink()
            except OSError:
                pass
        return str(gif_path)
# Gradio UI
TITLE = "🔥 RIFE Interpolation Demo (PyTorch, Local)"
with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
    gr.Markdown(f"# {TITLE}")
    with gr.Row():
        with gr.Column():
            img_a = gr.Image(type="pil", label="Image A")
            img_b = gr.Image(type="pil", label="Image B")
        with gr.Column():
            fps = gr.Slider(6, 30, value=14, step=1, label="FPS")
            exp = gr.Slider(1, 4, value=2, step=1, label="Interpolation exponent")
            run = gr.Button("Interpolate", variant="primary")
            gif_out = gr.Image(type="filepath", label="Result GIF")
    run.click(interpolate, inputs=[img_a, img_b, fps, exp], outputs=[gif_out])

# One job at a time (interpolate also serialises on LOCK), at most 8 waiting.
# Gradio 4 replaced queue(concurrency_count=...) with default_concurrency_limit;
# Gradio 3.x does not know the new keyword and raises TypeError, so fall back.
try:
    demo.queue(default_concurrency_limit=1, max_size=8)
except TypeError:
    demo.queue(concurrency_count=1, max_size=8)

if __name__ == "__main__":
    demo.launch()