import subprocess
import sys
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "-e",
    "./extern/CUT3R/src/croco/models/curope"
])
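# The call above builds CUT3R's CuRoPE CUDA extension in editable mode at startup; it
# assumes a CUDA toolchain is available on the host and makes the first launch of the
# Space noticeably slower than subsequent ones.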
from typing import List, Literal
from pathlib import Path
from functools import partial

import spaces
import gradio as gr
import numpy as np
import torch
from omegaconf import OmegaConf

from modeling.pipeline import VMemPipeline
from diffusers.utils import export_to_video
from scipy.spatial.transform import Rotation, Slerp
from navigation import Navigator
from PIL import Image
from utils import tensor_to_pil, encode_vae_image, encode_image, get_default_intrinsics, load_img_and_K, transform_img_and_K

import os
import glob
CONFIG_PATH = "configs/inference/inference.yaml"
CONFIG = OmegaConf.load(CONFIG_PATH)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = VMemPipeline(CONFIG, DEVICE)

NAVIGATORS = []
NAVIGATION_FPS = 3
WIDTH = 576
HEIGHT = 576
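# MODEL and NAVIGATORS are module-level globals: the pipeline is loaded once at import
# time, and a single Navigator instance is (re)created per navigation session, so the
# demo effectively serves one interactive session at a time (see max_threads=1 in
# demo.launch at the bottom of this file).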
IMAGE_PATHS = ['test_samples/oxford.jpg',
               'test_samples/open_door.jpg',
               'test_samples/living_room.jpg',
               'test_samples/living_room_2.jpeg',
               'test_samples/arc_de_tromphe.jpeg',
               'test_samples/changi.jpg',
               'test_samples/jesus.jpg',]
# If no images found, create placeholders
if not IMAGE_PATHS:
    def create_placeholder_images(num_samples=5, height=HEIGHT, width=WIDTH):
        """Create placeholder images for the demo"""
        images = []
        for i in range(num_samples):
            img = np.zeros((height, width, 3), dtype=np.uint8)
            for h in range(height):
                for w in range(width):
                    img[h, w, 0] = int(255 * h / height)             # Red gradient
                    img[h, w, 1] = int(255 * w / width)              # Green gradient
                    img[h, w, 2] = int(255 * (i + 1) / num_samples)  # Blue varies by image
            images.append(img)
        return images

    # Create placeholder video frames and poses
    def create_placeholder_video_and_poses(num_samples=5, num_frames=1, height=HEIGHT, width=WIDTH):
        """Create placeholder videos and poses for the demo"""
        videos = []
        poses = []
        for i in range(num_samples):
            # Create a simple video (just one frame initially for each sample)
            frames = []
            for j in range(num_frames):
                # Create a gradient frame
                img = np.zeros((height, width, 3), dtype=np.uint8)
                for h in range(height):
                    for w in range(width):
                        img[h, w, 0] = int(255 * h / height)             # Red gradient
                        img[h, w, 1] = int(255 * w / width)              # Green gradient
                        img[h, w, 2] = int(255 * (i + 1) / num_samples)  # Blue varies by video
                # Convert to torch tensor [C, H, W] with normalized values
                frame = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
                frames.append(frame)
            video = torch.stack(frames)
            videos.append(video)

            # Create placeholder poses: one 4x4 identity matrix per frame, shape [num_frames, 4, 4]
            # pose = torch.eye(4).flatten()[:-4]  # (flattened 3x4 variant, kept for reference)
            poses.append(torch.eye(4).unsqueeze(0).repeat(num_frames, 1, 1))
        return videos, poses

    first_frame_list = create_placeholder_images(num_samples=5)
    video_list, poses_list = create_placeholder_video_and_poses(num_samples=5)
# Function to load image from path
def load_image_for_navigation(image_path):
    """Load image from path and prepare for navigation"""
    # Load image and get default intrinsics
    image, _ = load_img_and_K(image_path, None, K=None, device=DEVICE)

    # Transform image to the target size
    config = OmegaConf.load(CONFIG_PATH)
    image, _ = transform_img_and_K(image, (config.model.height, config.model.width), mode="crop", K=None)

    # Create initial video with single frame and pose
    video = image
    pose = torch.eye(4).unsqueeze(0)  # [1, 4, 4]

    return {
        "image": tensor_to_pil(image),
        "video": video,
        "pose": pose
    }
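# Illustrative usage (a sketch, not part of the app's control flow): the returned dict
# feeds the Gradio session states directly, e.g.
#
#   nav_init = load_image_for_navigation("test_samples/oxford.jpg")
#   first_view = nav_init["image"]                      # PIL image shown as "Current View"
#   video, pose = nav_init["video"], nav_init["pose"]   # image tensor and a [1, 4, 4] identity pose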
class CustomProgressBar:
    def __init__(self, pbar):
        self.pbar = pbar

    def set_postfix(self, **kwargs):
        pass

    def __getattr__(self, attr):
        return getattr(self.pbar, attr)
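# CustomProgressBar is a thin wrapper around a tqdm-style progress bar that silently
# drops set_postfix() updates while delegating every other attribute via __getattr__.
# It is not referenced elsewhere in this file, so it is presumably consumed by the
# pipeline or Navigator code.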
def get_duration_navigate_video(video: torch.Tensor,
                                poses: torch.Tensor,
                                x_angle: float,
                                y_angle: float,
                                distance: float
                                ):
    # Estimate processing time based on navigation complexity and number of frames
    base_duration = 15  # Base duration in seconds

    # Add time for more complex navigation operations
    if abs(x_angle) > 20 or abs(y_angle) > 30:
        base_duration += 10  # More time for sharp turns
    if distance > 100:
        base_duration += 10  # More time for longer distances

    # Add time proportional to existing video length (more frames = more processing)
    base_duration += min(10, len(video))

    return base_duration
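# get_duration_navigate_video mirrors navigate_video's signature and returns a rough
# cost estimate in seconds. Its shape matches the callable accepted by
# `spaces.GPU(duration=...)` scheduling (an assumption: no such decorator appears in
# this file, so here it only documents the expected cost), e.g. hypothetically:
#
#   @spaces.GPU(duration=get_duration_navigate_video)
#   def navigate_video(...):
#       ...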
def navigate_video(
    video: torch.Tensor,
    poses: torch.Tensor,
    x_angle: float,
    y_angle: float,
    distance: float,
):
    """
    Generate new video frames by navigating in the 3D scene.

    This function uses the Navigator class from navigation.py to handle movement:
    - y_angle parameter controls left/right turning (turn_left/turn_right methods)
    - distance parameter controls forward movement (move_forward method)
    - x_angle parameter controls vertical angle (not directly implemented in Navigator)

    Each Navigator instance is stored based on the video session to maintain state.
    """
    try:
        # Convert first frame to PIL Image for navigator
        initial_frame = tensor_to_pil(video[0])

        # Initialize the navigator for this session if not already done
        if len(NAVIGATORS) == 0:
            # Create a new navigator instance
            NAVIGATORS.append(Navigator(MODEL, step_size=0.1, num_interpolation_frames=4))

            # Get the initial pose and convert to numpy
            initial_pose = poses[0].cpu().numpy().reshape(4, 4)

            # Default camera intrinsics if not available
            initial_K = np.array(get_default_intrinsics()[0])

            # Initialize the navigator
            NAVIGATORS[0].initialize(initial_frame, initial_pose, initial_K)

        navigator = NAVIGATORS[0]

        # Generate new frames based on navigation commands
        new_frames = []

        # First handle any x-angle (vertical angle) adjustments
        # Note: This is approximated as Navigator doesn't directly support this
        if abs(x_angle) > 0:
            # Implementation for x-angle could be added here
            # For now, we'll skip this as it's not directly supported
            pass

        # Next handle y-angle (turning left/right)
        if abs(y_angle) > 0:
            # Use Navigator's turn methods
            if y_angle > 0:
                new_frames = navigator.turn_left(abs(y_angle // 2))
            else:
                new_frames = navigator.turn_right(abs(y_angle // 2))

        # Finally handle distance (moving forward)
        elif distance > 0:
            # Calculate number of steps based on distance
            steps = max(1, int(distance / 10))
            new_frames = navigator.move_forward(steps)
        elif distance < 0:
            # Handle moving backward if needed
            steps = max(1, int(abs(distance) / 10))
            new_frames = navigator.move_backward(steps)

        if not new_frames:
            # If no new frames were generated, return the current state
            return (
                video,
                poses,
                tensor_to_pil(video[-1]),
                export_to_video([tensor_to_pil(video[i]) for i in range(len(video))], fps=NAVIGATION_FPS),
                [(tensor_to_pil(video[i]), f"t={i}") for i in range(len(video))],
            )

        # Convert PIL images to tensors
        new_frame_tensors = []
        for frame in new_frames:
            # Convert PIL Image to tensor [C, H, W]
            frame_np = np.array(frame) / 255.0
            # Convert to [-1, 1] range to match the expected format
            frame_tensor = torch.from_numpy(frame_np.transpose(2, 0, 1)).float() * 2.0 - 1.0
            new_frame_tensors.append(frame_tensor)
        new_frames_tensor = torch.stack(new_frame_tensors)

        # Get the updated camera poses from the navigator
        current_pose = navigator.current_pose
        new_poses = torch.from_numpy(current_pose).float().unsqueeze(0).repeat(len(new_frames), 1, 1)

        # Reshape the poses to match the expected format
        new_poses = new_poses.view(len(new_frames), 4, 4)

        # Concatenate new frames and poses with existing ones
        updated_video = torch.cat([video.cpu(), new_frames_tensor], dim=0)
        updated_poses = torch.cat([poses.cpu(), new_poses], dim=0)

        # Create output images for gallery
        all_images = [(tensor_to_pil(updated_video[i]), f"t={i}") for i in range(len(updated_video))]
        updated_video_pil = [tensor_to_pil(updated_video[i]) for i in range(len(updated_video))]

        return (
            updated_video,
            updated_poses,
            tensor_to_pil(updated_video[-1]),                        # Current view
            export_to_video(updated_video_pil, fps=NAVIGATION_FPS),  # Video
            all_images,                                              # Gallery
        )

    except Exception as e:
        print(f"Error in navigate_video: {e}")
        gr.Warning(f"Navigation error: {e}")
        # Return the original inputs to avoid crashes
        current_frame = tensor_to_pil(video[-1]) if len(video) > 0 else None
        all_frames = [(tensor_to_pil(video[i]), f"t={i}") for i in range(len(video))]
        video_frames = [tensor_to_pil(video[i]) for i in range(len(video))]
        video_output = export_to_video(video_frames, fps=NAVIGATION_FPS) if video_frames else None
        return video, poses, current_frame, video_output, all_frames
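# Illustrative call (a sketch, assuming a session has already been initialised from
# load_image_for_navigation): each UI button invokes navigate_video with exactly one
# non-zero control, e.g. a 10-degree left turn is
#
#   video, poses, current_view, video_path, gallery = navigate_video(
#       video, poses, x_angle=0, y_angle=10, distance=0)
#
# and the five return values map one-to-one onto the five Gradio outputs wired up below.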
def undo_navigation(
    video: torch.Tensor,
    poses: torch.Tensor,
):
    """
    Undo the last navigation step by removing the last set of frames.

    Uses the Navigator's undo method which in turn uses the pipeline's undo_latest_move
    to properly handle surfels and state management.
    """
    if len(NAVIGATORS) > 0:
        navigator = NAVIGATORS[0]

        # Call the Navigator's undo method to handle the operation
        success = navigator.undo()

        if success:
            # Since the navigator has handled the frame removal internally,
            # we need to update our video and poses tensors to match
            updated_video = video[:len(navigator.frames)]
            updated_poses = poses[:len(navigator.frames)]

            # Create gallery images
            all_images = [(tensor_to_pil(updated_video[i]), f"t={i}") for i in range(len(updated_video))]

            return (
                updated_video,
                updated_poses,
                tensor_to_pil(updated_video[-1]),
                export_to_video([tensor_to_pil(updated_video[i]) for i in range(len(updated_video))], fps=NAVIGATION_FPS),
                all_images,
            )
        else:
            gr.Warning("You have no moves left to undo!")
    else:
        gr.Warning("No navigation session available!")

    # If undo wasn't successful or no navigator exists, return original state
    all_images = [(tensor_to_pil(video[i]), f"t={i}") for i in range(len(video))]
    return (
        video,
        poses,
        tensor_to_pil(video[-1]),
        export_to_video([tensor_to_pil(video[i]) for i in range(len(video))], fps=NAVIGATION_FPS),
        all_images,
    )
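# After an undo, the navigator's frame list is treated as the source of truth: the video
# and pose tensors are simply truncated to len(navigator.frames), which assumes they have
# stayed in sync with the navigator throughout the session.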
def render_demonstrate(
    s: Literal["Selection", "Generation"],
    idx: int,
    demonstrate_stage: gr.State,
    demonstrate_selected_index: gr.State,
    demonstrate_current_video: gr.State,
    demonstrate_current_poses: gr.State
):
    gr.Markdown(
        """
        ## Single Image → Consistent Scene Navigation
        > #### _Select an image and navigate through the scene by controlling camera movements._
        """,
        elem_classes=["task-title"]
    )

    match s:
case "Selection": | |
with gr.Group(): | |
# Add upload functionality | |
with gr.Group(elem_classes=["gradio-box"]): | |
gr.Markdown("### Upload Your Own Image") | |
gr.Markdown("_Upload an image to navigate through its 3D scene_") | |
with gr.Row(): | |
with gr.Column(scale=3): | |
upload_image = gr.Image( | |
label="Upload an image", | |
type="filepath", | |
height=300, | |
elem_id="upload-image" | |
) | |
with gr.Column(scale=1): | |
gr.Markdown("#### Instructions:") | |
gr.Markdown("1. Upload a clear, high-quality image") | |
gr.Markdown("2. Images with distinct visual features work best") | |
gr.Markdown("3. Landscape or architectural scenes are ideal") | |
upload_btn = gr.Button("Start Navigation", variant="primary", size="lg") | |
def process_uploaded_image(image_path): | |
if image_path is None: | |
gr.Warning("Please upload an image first") | |
return "Selection", None, None, None | |
try: | |
# Load image and prepare for navigation | |
result = load_image_for_navigation(image_path) | |
# Clear any existing navigators | |
global NAVIGATORS | |
NAVIGATORS = [] | |
return ( | |
"Generation", | |
None, # No predefined index for uploaded images | |
result["video"], | |
result["pose"], | |
) | |
except Exception as e: | |
print(f"Error in process_uploaded_image: {e}") | |
gr.Warning(f"Error processing uploaded image: {e}") | |
return "Selection", None, None, None | |
upload_btn.click( | |
fn=process_uploaded_image, | |
inputs=[upload_image], | |
outputs=[demonstrate_stage, demonstrate_selected_index, demonstrate_current_video, demonstrate_current_poses] | |
) | |
gr.Markdown("### Or Choose From Our Examples") | |
# Define image captions | |
image_captions = { | |
'test_samples/oxford.jpg': 'Oxford University', | |
'test_samples/open_door.jpg': 'Bedroom Interior', | |
'test_samples/living_room.jpg': 'Living Room', | |
'test_samples/living_room_2.jpeg': 'Living Room 2', | |
'test_samples/arc_de_tromphe.jpeg': 'Arc de Triomphe', | |
'test_samples/jesus.jpg': 'Jesus College', | |
'test_samples/changi.jpg': 'Changi Airport', | |
} | |
# Load all images for the gallery with captions | |
gallery_images = [] | |
for img_path in IMAGE_PATHS: | |
try: | |
# Get caption or default to basename | |
caption = image_captions.get(img_path, os.path.basename(img_path)) | |
gallery_images.append((img_path, caption)) | |
except Exception as e: | |
print(f"Error loading image {img_path}: {e}") | |
# Show image gallery for selection | |
demonstrate_image_gallery = gr.Gallery( | |
value=gallery_images, | |
label="Select an Image to Start Navigation", | |
columns=len(gallery_images), | |
height=400, | |
allow_preview=True, | |
preview=False, | |
elem_id="navigation-gallery" | |
) | |
gr.Markdown("_Click on an image to begin navigation_") | |
def start_navigation(evt: gr.SelectData): | |
try: | |
# Get the selected image path | |
selected_path = IMAGE_PATHS[evt.index] | |
# Load image and prepare for navigation | |
result = load_image_for_navigation(selected_path) | |
# Clear any existing navigators | |
global NAVIGATORS | |
NAVIGATORS = [] | |
return ( | |
"Generation", | |
evt.index, | |
result["video"], | |
result["pose"], | |
) | |
except Exception as e: | |
print(f"Error in start_navigation: {e}") | |
gr.Warning(f"Error starting navigation: {e}") | |
return "Selection", None, None, None | |
demonstrate_image_gallery.select( | |
fn=start_navigation, | |
inputs=None, | |
outputs=[demonstrate_stage, demonstrate_selected_index, demonstrate_current_video, demonstrate_current_poses] | |
) | |
case "Generation": | |
with gr.Row(): | |
with gr.Column(scale=3): | |
with gr.Row(): | |
demonstrate_current_view = gr.Image( | |
label="Current View", | |
width=256, | |
height=256, | |
) | |
demonstrate_video = gr.Video( | |
label="Generated Video", | |
width=256, | |
height=256, | |
autoplay=True, | |
loop=True, | |
show_share_button=True, | |
show_download_button=True, | |
) | |
demonstrate_generated_gallery = gr.Gallery( | |
value=[], | |
label="Generated Frames", | |
columns=[6], | |
) | |
# Initialize the current view with the selected image if available | |
if idx is not None: | |
try: | |
selected_path = IMAGE_PATHS[idx] | |
result = load_image_for_navigation(selected_path) | |
demonstrate_current_view.value = result["image"] | |
except Exception as e: | |
print(f"Error initializing current view: {e}") | |
with gr.Column(): | |
gr.Markdown("### Navigation Controls ↓") | |
with gr.Accordion("Instructions", open=False): | |
gr.Markdown(""" | |
- **The model will predict the next few frames based on your camera movements. Repeat the process to continue navigating through the scene.** | |
- **Use the navigation controls to move forward/backward and turn left/right.** | |
- **At the end of your navigation, you can save your camera path for later use.** | |
""") | |
                    # with gr.Tab("Basic", elem_id="basic-controls-tab"):
                    with gr.Group():
                        gr.Markdown("_**Select a direction to move:**_")

                        # First row: Turn left/right
                        with gr.Row(elem_id="basic-controls"):
                            gr.Button(
                                "↰20°\nVeer",
                                size="sm",
                                min_width=0,
                                variant="primary",
                            ).click(
                                fn=partial(
                                    navigate_video,
                                    x_angle=0,
                                    y_angle=20,
                                    distance=0,
                                ),
                                inputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                ],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )
                            gr.Button(
                                "↖10°\nTurn",
                                size="sm",
                                min_width=0,
                                variant="primary",
                            ).click(
                                fn=partial(
                                    navigate_video,
                                    x_angle=0,
                                    y_angle=10,
                                    distance=0,
                                ),
                                inputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                ],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )
                            gr.Button(
                                "↗10°\nTurn",
                                size="sm",
                                min_width=0,
                                variant="primary",
                            ).click(
                                fn=partial(
                                    navigate_video,
                                    x_angle=0,
                                    y_angle=-10,
                                    distance=0,
                                ),
                                inputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                ],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )
                            gr.Button(
                                "↱20°\nVeer",
                                size="sm",
                                min_width=0,
                                variant="primary",
                            ).click(
                                fn=partial(
                                    navigate_video,
                                    x_angle=0,
                                    y_angle=-20,
                                    distance=0,
                                ),
                                inputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                ],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )
                        # Second row: Forward/Backward movement
                        with gr.Row(elem_id="forward-backward-controls"):
                            gr.Button(
                                "↓\nBackward",
                                size="sm",
                                min_width=0,
                                variant="secondary",
                            ).click(
                                fn=partial(
                                    navigate_video,
                                    x_angle=0,
                                    y_angle=0,
                                    distance=-10,
                                ),
                                inputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                ],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )
                            gr.Button(
                                "↑\nForward",
                                size="sm",
                                min_width=0,
                                variant="secondary",
                            ).click(
                                fn=partial(
                                    navigate_video,
                                    x_angle=0,
                                    y_angle=0,
                                    distance=10,
                                ),
                                inputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                ],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )
gr.Markdown("---") | |
with gr.Group(): | |
gr.Markdown("_**Navigation controls:**_") | |
with gr.Row(): | |
gr.Button("Undo Last Move", variant="huggingface").click( | |
fn=undo_navigation, | |
inputs=[demonstrate_current_video, demonstrate_current_poses], | |
outputs=[ | |
demonstrate_current_video, | |
demonstrate_current_poses, | |
demonstrate_current_view, | |
demonstrate_video, | |
demonstrate_generated_gallery, | |
], | |
) | |
# Add a function to save camera poses | |
def save_camera_poses(video, poses): | |
if len(NAVIGATORS) > 0: | |
navigator = NAVIGATORS[0] | |
# Create a directory for saved poses | |
os.makedirs("./visualization", exist_ok=True) | |
save_path = f"./visualization/transforms_{len(navigator.frames)}_frames.json" | |
navigator.save_camera_poses(save_path) | |
return gr.Info(f"Camera poses saved to {save_path}") | |
return gr.Warning("No navigation instance found") | |
gr.Button("Save Camera", variant="huggingface").click( | |
fn=save_camera_poses, | |
inputs=[demonstrate_current_video, demonstrate_current_poses], | |
outputs=[] | |
) | |
# Add a button to return to image selection | |
def reset_navigation(): | |
# Clear current navigator | |
global NAVIGATORS | |
NAVIGATORS = [] | |
return "Selection", None, None, None | |
gr.Button("Choose New Image", variant="secondary").click( | |
fn=reset_navigation, | |
inputs=[], | |
outputs=[demonstrate_stage, demonstrate_selected_index, demonstrate_current_video, demonstrate_current_poses] | |
) | |
# Create the Gradio Blocks
with gr.Blocks(theme=gr.themes.Base(primary_hue="blue")) as demo:
    gr.HTML(
        """
        <style>
            [data-tab-id="task-1"], [data-tab-id="task-2"], [data-tab-id="task-3"] {
                font-size: 16px !important;
                font-weight: bold;
            }
            #page-title h1 {
                color: #002147 !important;
            }
            .task-title h2 {
                color: #004080 !important;
            }
            .header-button-row {
                gap: 4px !important;
            }
            .header-button-row div {
                width: 131.0px !important;
            }
            .header-button-column {
                width: 131.0px !important;
                gap: 5px !important;
            }
            .header-button a {
                border: 1px solid #002147;
            }
            .header-button .button-icon {
                margin-right: 8px;
            }
            .demo-button-column .gap {
                gap: 5px !important;
            }
            #basic-controls {
                column-gap: 0px;
            }
            #basic-controls-tab {
                padding: 0px;
            }
            #advanced-controls-tab {
                padding: 0px;
            }
            #forward-backward-controls {
                column-gap: 0px;
                justify-content: center;
                margin-top: 8px;
            }
            #selected-demo-button {
                color: #004080;
                text-decoration: underline;
            }
            .demo-button {
                text-align: left !important;
                display: block !important;
            }
            #navigation-gallery {
                margin-bottom: 15px;
            }
            #navigation-gallery .gallery-item {
                cursor: pointer;
                border-radius: 6px;
                transition: transform 0.2s, box-shadow 0.2s;
            }
            #navigation-gallery .gallery-item:hover {
                transform: scale(1.02);
                box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
            }
            #navigation-gallery .gallery-item.selected {
                border: 3px solid #002147;
            }
            /* Upload image styling */
            #upload-image {
                border-radius: 8px;
                border: 2px dashed #002147;
                padding: 10px;
                transition: all 0.3s ease;
            }
            #upload-image:hover {
                border-color: #002147;
                box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
            }
            /* Box styling */
            .gradio-box {
                border-radius: 10px;
                margin-bottom: 20px;
                padding: 15px;
                background-color: #002147;
                border: 1px solid #002147;
            }
            /* Start Navigation button styling */
            button[data-testid="Start Navigation"] {
                background-color: #004080 !important;
                border-color: #004080 !important;
                color: white !important;
            }
            button[data-testid="Start Navigation"]:hover {
                background-color: #002147 !important;
                border-color: #002147 !important;
            }
            /* Override Gradio's primary button color */
            .gradio-button.primary {
                background-color: #004080 !important;
                border-color: #004080 !important;
                color: white !important;
            }
            .gradio-button.primary:hover {
                background-color: #002147 !important;
                border-color: #002147 !important;
            }
        </style>
        """
    )
    demo_idx = gr.State(value=3)

    with gr.Sidebar():
        gr.Markdown("# VMem: Consistent Video Scene Generation with Surfel-Indexed View Memory", elem_id="page-title")
        gr.Markdown(
            "### Official interactive demo for [_VMem_](https://arxiv.org/abs/2502.06764), which enables consistent, interactive video scene generation."
        )
gr.Markdown("---") | |
gr.Markdown("#### Links ↓") | |
with gr.Row(elem_classes=["header-button-row"]): | |
with gr.Column(elem_classes=["header-button-column"], min_width=0): | |
gr.Button( | |
value="Website", | |
link="https://v-mem.github.io/", | |
icon="https://simpleicons.org/icons/googlechrome.svg", | |
elem_classes=["header-button"], | |
size="md", | |
min_width=0, | |
) | |
gr.Button( | |
value="Paper", | |
link="https://arxiv.org/abs/2502.06764", | |
icon="https://simpleicons.org/icons/arxiv.svg", | |
elem_classes=["header-button"], | |
size="md", | |
min_width=0, | |
) | |
with gr.Column(elem_classes=["header-button-column"], min_width=0): | |
gr.Button( | |
value="Code", | |
link="https://github.com/kwsong0113/diffusion-forcing-transformer", | |
icon="https://simpleicons.org/icons/github.svg", | |
elem_classes=["header-button"], | |
size="md", | |
min_width=0, | |
) | |
gr.Button( | |
value="Weights", | |
link="https://huggingface.co/liguang0115/vmem", | |
icon="https://simpleicons.org/icons/huggingface.svg", | |
elem_classes=["header-button"], | |
size="md", | |
min_width=0, | |
) | |
gr.Markdown("---") | |
gr.Markdown("This demo interface is adapted from the History-Guided Video Diffusion demo template. We thank the authors for their work.") | |
demonstrate_stage = gr.State(value="Selection") | |
demonstrate_selected_index = gr.State(value=None) | |
demonstrate_current_video = gr.State(value=None) | |
demonstrate_current_poses = gr.State(value=None) | |
def render_demo( | |
_demo_idx, _demonstrate_stage, _demonstrate_selected_index | |
): | |
match _demo_idx: | |
case 3: | |
render_demonstrate(_demonstrate_stage, _demonstrate_selected_index, demonstrate_stage, demonstrate_selected_index, demonstrate_current_video, demonstrate_current_poses) | |
if __name__ == "__main__": | |
demo.launch(debug=True, | |
share=True, | |
max_threads=1, # Limit concurrent processing | |
show_error=True, # Show detailed error messages | |
) | |