# Captured from the Hugging Face Hub file viewer (page chrome, not program text):
# author: linoyts (HF Staff) - commit ab50362 (verified): "add lora gallery"
# file size: 32.6 kB
import os
import gradio as gr
import json
import logging # Not strictly used from app (2) but good practice
import torch
from PIL import Image
import spaces
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_video, load_image # load_image was also in app (2)
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
import copy # Not strictly used from app (2) but kept if needed later
import random
import numpy as np
import imageio
import time
import re
# --- LoRA related: Load LoRAs from JSON file ---
# Read the LoRA catalogue that backs the gallery. A missing or malformed
# loras.json is not fatal: a warning is printed and the catalogue stays empty.
loras = []
try:
    with open('loras.json', 'r') as lora_file:
        loras = json.load(lora_file)
except FileNotFoundError:
    print("WARNING: loras.json not found. LoRA gallery will be empty or non-functional.")
    print("Please create loras.json with entries like: [{'title': 'My LTX LoRA', 'repo': 'user/repo', 'weights': 'lora.safetensors', 'trigger_word': 'my style', 'image': 'url_to_image.jpg'}]")
except json.JSONDecodeError:
    print("WARNING: loras.json is not valid JSON. LoRA gallery will be empty or non-functional.")
# Initialize the base model
# Weights are loaded in bfloat16; inference runs on CUDA when torch reports a
# GPU and falls back to the CPU otherwise.
dtype = torch.bfloat16 # Assuming LTX uses bfloat16 as per original app (1)
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- Original app (1) pipeline setup ---
# Main conditioned generation pipeline (distilled LTX-Video 0.9.7 checkpoint).
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.7-distilled", torch_dtype=dtype)
# Latent spatial upsampler; it is given the main pipeline's VAE instance
# (vae=pipe.vae) instead of loading a separate one.
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("Lightricks/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", vae=pipe.vae, torch_dtype=dtype)
pipe.to(device)
pipe_upsample.to(device)
# NOTE(review): tiling is presumably enabled to limit VAE memory use - confirm.
pipe.vae.enable_tiling()
# Upper bound for randomly drawn seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max # from app (1)
# MAX_SEED_APP2 = 2**32-1 # from app (2), slightly different, stick to app (1)'s for consistency here.
# Maximum value offered by the height/width sliders.
MAX_IMAGE_SIZE = 1280
# Hard cap on generated frames and on the "frames to use" slider.
MAX_NUM_FRAMES = 257
# Frame rate used both to convert the requested duration into a frame count
# and as the fps of the encoded mp4.
FPS = 30.0
# Minimum slider value; also the floor applied when rescaling uploaded media.
MIN_DIM_SLIDER = 256
# Length given to the longer side when slider dimensions are derived from
# uploaded media.
TARGET_FIXED_SIDE = 768
class calculateDuration:
    """Context manager that prints how long its ``with`` body took.

    On exit the instance exposes ``start_time``, ``end_time`` and
    ``elapsed_time`` (seconds, from ``time.time()``). Exceptions raised inside
    the body are not suppressed.
    """

    def __init__(self, activity_name=""):
        # Optional label included in the printed message.
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        report = (
            f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds"
            if self.activity_name
            else f"Elapsed time: {self.elapsed_time:.6f} seconds"
        )
        print(report)
def update_lora_selection(evt: gr.SelectData):
    """Handle a click on a gallery tile.

    Returns a pair ``(markdown, index)``: the Markdown describing the chosen
    LoRA (title linked to its Hub repo, plus the trigger word when the entry
    has one) and the index to store as the current selection. When the
    catalogue is empty or the index is missing/out of range, the Markdown is
    left untouched (``gr.update()``) and the selection becomes ``None``.
    """
    picked = evt.index
    if not loras or picked is None or picked >= len(loras):
        return gr.update(), None

    entry = loras[picked]
    repo_id = entry["repo"]
    summary = f"### Selected LoRA: [{entry['title']}](https://huggingface.co/{repo_id}) ✨"
    if entry.get('trigger_word'):
        summary = summary + f"\nTrigger word: `{entry['trigger_word']}`"

    # Slider dimensions and prompt placeholders are intentionally not touched here.
    return summary, picked
def get_huggingface_safetensors_for_ltx(link): # Renamed for clarity
    """Inspect a Hub repo and extract what is needed to use it as an LTX LoRA.

    Args:
        link: Repository id in ``username/repository_name`` form.

    Returns:
        ``(repo_name, link, safetensors_name, trigger_word, image_url)`` where
        ``image_url`` may be ``None`` when no preview image is found.

    Raises:
        Exception: if ``link`` is not two ``/``-separated parts, if the model
            card's ``base_model`` is not an accepted LTX checkpoint, or if the
            repo listing fails / contains no ``.safetensors`` file.
    """
    parts = link.split("/")
    if len(parts) != 2:
        raise Exception("Invalid Hugging Face repository link format. Should be 'username/repository_name'.")

    print(f"Repository attempted: {link}")
    model_card = ModelCard.load(link)
    card_data = model_card.data
    base_model = card_data.get("base_model")
    print(f"Base model from card: {base_model}")

    # Only LoRAs trained for this base checkpoint are accepted.
    acceptable_models = {"Lightricks/LTX-Video-0.9.7-distilled"}
    if isinstance(base_model, list):
        models_to_check = base_model
    else:
        models_to_check = [base_model]
    is_compatible = any(str(candidate).strip() in acceptable_models for candidate in models_to_check)
    if not is_compatible:
        raise Exception(f"Not a LoRA for a compatible LTX base model! Expected one of {acceptable_models}, found {models_to_check}")

    # Preview image: first widget's output URL, when the card declares one.
    image_path = None
    widget = card_data.get("widget")
    if widget and isinstance(widget, list) and len(widget) > 0:
        image_path = widget[0].get("output", {}).get("url", None)
    trigger_word = card_data.get("instance_prompt", "")
    image_url = None
    if image_path:
        image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}"

    fs = HfFileSystem()
    try:
        list_of_files = fs.ls(link, detail=False)

        # Well-known LoRA filenames win; otherwise take the first *.safetensors listed.
        well_known = ["lora.safetensors", "pytorch_lora_weights.safetensors"]
        safetensors_name = next(
            (candidate for candidate in well_known if f"{link}/{candidate}" in list_of_files),
            None,
        )
        if not safetensors_name:
            basenames = (path.split("/")[-1] for path in list_of_files)
            safetensors_name = next(
                (name for name in basenames if name.endswith(".safetensors")),
                None,
            )
        if not safetensors_name:
            raise Exception("No valid *.safetensors file found in the repository.")

        # No widget image: fall back to the first image file in the listing.
        if not image_url:
            image_suffixes = (".jpg", ".jpeg", ".png", ".webp")
            for path in list_of_files:
                basename = path.split("/")[-1]
                if basename.lower().endswith(image_suffixes):
                    image_url = f"https://huggingface.co/{link}/resolve/main/{basename}"
                    break
    except Exception as e:
        print(f"Error accessing repository or finding safetensors: {e}")
        raise Exception(f"Could not validate Hugging Face repository '{link}' or find a .safetensors LoRA file.") from e

    # parts[0] is the user, parts[1] the repository name (used as the title).
    return parts[1], link, safetensors_name, trigger_word, image_url
def check_custom_model_for_ltx(link_input): # Renamed for clarity
    """Normalise a user-supplied LoRA reference and validate the repository.

    Accepts either a bare ``username/repo-name`` id or a huggingface.co URL
    (``https://``/``http://``, with or without ``www.``). Query strings,
    fragments and leading/trailing slashes are dropped before validation.

    Args:
        link_input: Text typed by the user.

    Returns:
        Whatever ``get_huggingface_safetensors_for_ltx`` returns for the
        normalised ``username/repo-name`` id.

    Raises:
        Exception: if the input is empty / not a string, or does not reduce to
            exactly two non-empty ``/``-separated parts; also propagates the
            errors raised by ``get_huggingface_safetensors_for_ltx``.
    """
    print(f"Checking a custom model on: {link_input}")
    if not link_input or not isinstance(link_input, str):
        raise Exception("Invalid custom LoRA input. Please provide a Hugging Face repository path (e.g., 'username/repo-name') or URL.")

    link_to_check = link_input.strip()

    # Strip the host only when it is a real prefix (str.replace would also
    # remove it from the middle of the string).
    url_prefixes = (
        "https://huggingface.co/",
        "http://huggingface.co/",
        "https://www.huggingface.co/",
        "http://www.huggingface.co/",
        "www.huggingface.co/",
    )
    for prefix in url_prefixes:
        if link_to_check.startswith(prefix):
            link_to_check = link_to_check[len(prefix):]
            break

    # Pasted URLs frequently carry a query string, a fragment or a trailing slash.
    link_to_check = link_to_check.split("?")[0].split("#")[0].strip("/")

    # Basic check for 'user/repo' format: exactly two non-empty components.
    components = link_to_check.split("/")
    if len(components) != 2 or not all(components):
        raise Exception("Invalid Hugging Face repository path. Use 'username/repo-name' format.")

    return get_huggingface_safetensors_for_ltx(link_to_check)
def add_custom_lora_for_ltx(custom_lora_path_input): # Renamed for clarity
    """Validate a user-supplied LoRA repo and add (or refresh) it in the gallery.

    Args:
        custom_lora_path_input: Repo id or huggingface.co URL typed by the user.

    Returns:
        A 6-tuple matching the event outputs: status HTML update, remove-button
        update, gallery update, selection info text, selected index (always
        ``None``) and the new content of the input textbox. On failure the
        status HTML shows the error and the textbox keeps the user's input.
    """
    import html  # local import: used only to escape model-card text below

    global loras # To modify the global loras list

    if not custom_lora_path_input: # No input
        return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

    fallback_image = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
    try:
        title, repo_id, weights_filename, trigger_word, image_url = check_custom_model_for_ltx(custom_lora_path_input)
        print(f"Loaded custom LoRA: {repo_id}")

        # Title, trigger word and image URL come from a third-party model card,
        # so they are escaped before being placed into HTML.
        safe_title = html.escape(str(title))
        safe_repo = html.escape(str(repo_id))
        safe_weights = html.escape(str(weights_filename))
        safe_image = html.escape(image_url if image_url else fallback_image)
        if trigger_word:
            trigger_html = "Trigger: <code><b>" + html.escape(str(trigger_word)) + "</b></code>"
        else:
            trigger_html = "No trigger word found. If one is needed, include it in your prompt."

        # Create HTML card for display
        card_html = f'''
        <div class="custom_lora_card">
            <span>Loaded custom LoRA:</span>
            <div class="card_internal">
                <img src="{safe_image}" alt="{safe_title}" style="width:80px; height:80px; object-fit:cover;" />
                <div>
                    <h4>{safe_title}</h4>
                    <small>Repo: {safe_repo}<br>Weights: {safe_weights}<br>
                    {trigger_html}
                    </small>
                </div>
            </div>
        </div>
        '''

        new_item_data = {
            "image": image_url,
            "title": title,
            "repo": repo_id,
            "weights": weights_filename,
            "trigger_word": trigger_word,
            "custom": True # Mark as custom
        }
        # Replace an existing entry for the same repo instead of duplicating it.
        existing_item_index = next(
            (index for index, item in enumerate(loras) if item.get("repo") == repo_id),
            None,
        )
        if existing_item_index is not None:
            loras[existing_item_index] = new_item_data
        else:
            loras.append(new_item_data)

        # `or` (not a .get default) so entries whose image is None also get the
        # placeholder picture.
        gallery_items = [(item.get("image") or fallback_image, item["title"]) for item in loras]
        return (
            gr.update(visible=True, value=card_html),
            gr.update(visible=True), # Show remove button
            # Gallery items are its `value`; it has no `choices` property.
            gr.update(value=gallery_items, selected_index=None),
            f"Custom LoRA '{title}' added. Select it from the gallery.", # Selected info text
            None, # Reset selected_index state
            "" # Clear custom LoRA input textbox
        )
    except Exception as e:
        gr.Warning(f"Invalid Custom LoRA: {e}")
        error_html = f"<p style='color:red;'>Error adding LoRA: {html.escape(str(e))}</p>"
        return gr.update(visible=True, value=error_html), gr.update(visible=False), gr.update(), "", None, custom_lora_path_input
def remove_custom_lora_for_ltx(): # Renamed for clarity
    """Remove the most recently listed custom LoRA and refresh the gallery.

    Only the last catalogue entry flagged ``custom`` is removed; the function
    does not track which custom LoRA the status card is currently showing.

    Returns:
        A 6-tuple matching the event outputs: hidden/cleared status HTML,
        hidden remove button, gallery update with the remaining entries and no
        selection, empty info text, ``None`` selected index, empty textbox.
    """
    global loras

    custom_positions = [position for position, item in enumerate(loras) if item.get("custom")]
    if custom_positions:
        loras.pop(custom_positions[-1]) # Remove the last one marked as custom

    fallback_image = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
    # `or` (not a .get default) so entries whose image is None also get the
    # placeholder picture.
    gallery_items = [(item.get("image") or fallback_image, item["title"]) for item in loras]
    return (
        gr.update(visible=False, value=""),
        gr.update(visible=False),
        # Gallery items are its `value`; it has no `choices` property.
        gr.update(value=gallery_items, selected_index=None),
        "",
        None,
        "",
    )
def round_to_nearest_resolution_acceptable_by_vae(height, width):
    """Snap ``height`` and ``width`` down to multiples of the VAE compression ratio.

    Despite the name, values are floored (the remainder is subtracted), never
    rounded up. Returns ``(height, width)``.
    """
    ratio = pipe.vae_spatial_compression_ratio
    snapped_height = height - (height % ratio)
    snapped_width = width - (width % ratio)
    return snapped_height, snapped_width
def calculate_new_dimensions(orig_w, orig_h):
    """Scale a source size so its longer side equals ``TARGET_FIXED_SIDE``.

    The shorter side is scaled proportionally (truncated to an int), both
    sides are raised to at least ``MIN_DIM_SLIDER`` and then snapped down to
    what the VAE accepts. Note the argument order is (width, height) while the
    result is ``(new_h, new_w)``. A zero-sized input yields
    ``(MIN_DIM_SLIDER, MIN_DIM_SLIDER)``.
    """
    # Guard against division by zero for degenerate inputs.
    if orig_w == 0 or orig_h == 0:
        return MIN_DIM_SLIDER, MIN_DIM_SLIDER

    if orig_w > orig_h:
        # Landscape: width is the fixed side.
        new_w, new_h = TARGET_FIXED_SIDE, int(TARGET_FIXED_SIDE * orig_h / orig_w)
    else:
        # Portrait or square: height is the fixed side.
        new_h, new_w = TARGET_FIXED_SIDE, int(TARGET_FIXED_SIDE * orig_w / orig_h)

    floored_w = max(MIN_DIM_SLIDER, new_w)
    floored_h = max(MIN_DIM_SLIDER, new_h)

    final_h, final_w = round_to_nearest_resolution_acceptable_by_vae(floored_h, floored_w)
    return final_h, final_w
def handle_image_upload_for_dims(image_filepath, current_h, current_w):
    """Derive height/width slider values from a newly uploaded image.

    Args:
        image_filepath: Path of the uploaded image (falsy when nothing is set).
        current_h: Current height slider value, kept when nothing can be derived.
        current_w: Current width slider value, kept when nothing can be derived.

    Returns:
        ``(height_update, width_update)`` for the two sliders. Any error while
        reading the image is printed and the current values are returned.
    """
    if not image_filepath:
        return gr.update(value=current_h), gr.update(value=current_w)
    try:
        # Only the size is needed; the context manager closes the file handle
        # that Image.open would otherwise keep open.
        with Image.open(image_filepath) as img:
            orig_w, orig_h = img.size
        new_h, new_w = calculate_new_dimensions(orig_w, orig_h)
        return gr.update(value=new_h), gr.update(value=new_w)
    except Exception as e:
        print(f"Error processing image for dimension update: {e}")
        return gr.update(value=current_h), gr.update(value=current_w)
def handle_video_upload_for_dims(video_filepath, current_h, current_w):
    """Derive height/width slider values from a newly uploaded video.

    The size is taken from the reader's metadata (``size`` key) or, when that
    is absent, from the shape of the first frame. Whenever the size cannot be
    determined (no path, missing file, read error) the current slider values
    are returned unchanged.

    Returns:
        ``(height_update, width_update)`` for the two sliders.
    """
    if not video_filepath:
        return gr.update(value=current_h), gr.update(value=current_w)

    try:
        path_text = str(video_filepath)
        if not os.path.exists(path_text):
            print(f"Video file path does not exist for dimension update: {path_text}")
            return gr.update(value=current_h), gr.update(value=current_w)

        # -1 marks "not determined yet".
        source_w = -1
        source_h = -1
        with imageio.get_reader(path_text) as reader:
            metadata = reader.get_meta_data()
            if 'size' not in metadata:
                try:
                    frame0 = reader.get_data(0)
                    # Frame arrays are indexed (rows, columns, ...).
                    source_h, source_w = frame0.shape[0], frame0.shape[1]
                except Exception as e_frame:
                    print(f"Could not get video size from metadata or first frame: {e_frame}")
                    return gr.update(value=current_h), gr.update(value=current_w)
            else:
                source_w, source_h = metadata['size']

        if source_w == -1 or source_h == -1:
            print(f"Could not determine dimensions for video: {path_text}")
            return gr.update(value=current_h), gr.update(value=current_w)

        target_h, target_w = calculate_new_dimensions(source_w, source_h)
        return gr.update(value=target_h), gr.update(value=target_w)
    except Exception as e:
        print(f"Error processing video for dimension update: {e} (Path: {video_filepath}, Type: {type(video_filepath)})")
        return gr.update(value=current_h), gr.update(value=current_w)
def update_task_image():
    """Return the task-mode value written when the image-to-video tab is selected."""
    task_mode = "image-to-video"
    return task_mode
def update_task_text():
    """Return the task-mode value written when the text-to-video tab is selected."""
    task_mode = "text-to-video"
    return task_mode
def update_task_video():
    """Return the task-mode value written when the video-to-video tab is selected."""
    task_mode = "video-to-video"
    return task_mode
def get_duration(prompt, negative_prompt, image, video, height, width, mode, steps, num_frames,
                 frames_to_use, seed, randomize_seed, guidance_scale, duration_input, improve_texture,
                 # New LoRA params
                 selected_lora_index, lora_scale_value,
                 progress):
    """Return the GPU time budget, in seconds, passed to ``spaces.GPU``.

    The signature mirrors ``generate``; only ``duration_input`` is consulted:
    requests for more than 7 seconds of video get 75, everything else 60.
    """
    return 75 if duration_input > 7 else 60
@spaces.GPU(duration=get_duration) # Needs selected_lora_index and lora_scale_value if get_duration uses them
def generate(prompt,
             negative_prompt,
             image,
             video,
             height,
             width,
             mode,
             steps,
             num_frames_slider_val, # Renamed to avoid conflict with internal num_frames
             frames_to_use,
             seed,
             randomize_seed,
             guidance_scale,
             duration_input,
             improve_texture=False,
             # New LoRA params
             selected_lora_index=None,
             lora_scale_value=0.8, # Default LoRA scale
             progress=gr.Progress(track_tqdm=True)):
    """Generate a video with the LTX pipeline, optionally through a gallery LoRA.

    Args:
        prompt: Text prompt; the selected LoRA's trigger word is added to it.
        negative_prompt: Negative text prompt.
        image: Conditioning image, used when ``mode`` is "image-to-video".
        video: Conditioning video, used when ``mode`` is "video-to-video".
        height, width: Requested output size. In the image/video modes they are
            replaced by the size of the conditioning media.
        mode: "text-to-video", "image-to-video" or "video-to-video".
        steps: Inference steps of the first pass.
        num_frames_slider_val: Unused; the frame count comes from ``duration_input``.
        frames_to_use: Number of leading input-video frames used as condition.
        seed: Seed, replaced by a random one when ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        guidance_scale: Guidance scale passed to the pipeline.
        duration_input: Target duration in seconds, converted to ``N*8+1``
            frames clamped to ``[9, MAX_NUM_FRAMES]``.
        improve_texture: When true, upsample the latents and run a second
            denoising pass; otherwise upsample and decode directly.
        selected_lora_index: Index into ``loras`` or ``None``.
        lora_scale_value: Adapter weight for the selected LoRA.
        progress: Gradio progress tracker.

    Returns:
        ``(output_filename, seed)``: path of the encoded mp4 and the seed used.
    """
    effective_prompt = prompt

    # --- LoRA Handling ---
    # Always start from a clean pipeline so a previous request's adapter cannot leak in.
    with calculateDuration("Unloading previous LoRAs"):
        try:
            pipe.unload_lora_weights()
            print("Previous LoRAs unloaded if any.")
        except Exception as e:
            print(f"Note: Could not unload LoRAs (maybe none were loaded): {e}")

    if selected_lora_index is not None and 0 <= selected_lora_index < len(loras):
        selected_lora_data = loras[selected_lora_index]
        lora_repo_id = selected_lora_data["repo"]
        lora_weights_name = selected_lora_data.get("weights", None)
        lora_trigger = selected_lora_data.get("trigger_word", "")
        print(f"Selected LoRA: {selected_lora_data['title']} from {lora_repo_id}")

        if lora_trigger:
            print(f"Applying trigger word: {lora_trigger}")
            if selected_lora_data.get("trigger_position") == "prepend":
                effective_prompt = f"{lora_trigger} {prompt}"
            else: # Default to append or if not specified
                effective_prompt = f"{prompt} {lora_trigger}"

        with calculateDuration(f"Loading LoRA weights for {selected_lora_data['title']}"):
            try:
                pipe.load_lora_weights(
                    lora_repo_id,
                    weight_name=lora_weights_name,
                    adapter_name="active_lora" # Use a consistent adapter name
                )
                pipe.set_adapters(["active_lora"], adapter_weights=[lora_scale_value])
                print(f"LoRA loaded into main pipe with scale {lora_scale_value}")
            except Exception as e:
                gr.Warning(f"Failed to load LoRA '{selected_lora_data['title']}': {e}. Proceeding without LoRA.")
                print(f"Error loading LoRA: {e}")
                # We really proceed without the LoRA: drop its trigger word too.
                effective_prompt = prompt
                # Best-effort cleanup of a partially loaded adapter; a failure here
                # is reported but must not abort the generation.
                try:
                    pipe.unload_lora_weights()
                except Exception as cleanup_error:
                    print(f"Note: Could not clean up after failed LoRA load: {cleanup_error}")
    else:
        print("No LoRA selected or invalid index.")
    # --- End LoRA Handling ---

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Convert the requested duration to a frame count of the form N*8+1,
    # at least 9 and at most MAX_NUM_FRAMES.
    target_frames_ideal = duration_input * FPS
    target_frames_rounded = round(target_frames_ideal)
    if target_frames_rounded < 1:
        target_frames_rounded = 1
    n_val = round((float(target_frames_rounded) - 1.0) / 8.0)
    actual_num_frames = int(n_val * 8 + 1)
    actual_num_frames = max(9, actual_num_frames)
    num_frames = min(MAX_NUM_FRAMES, actual_num_frames) # This num_frames is used by the pipe

    # Build the conditioning input. For image/video modes the media's own size
    # overrides the slider-provided width/height.
    if mode == "video-to-video" and (video is not None):
        loaded_video_frames = load_video(video)[:frames_to_use]
        condition_input_video = True
        width, height = loaded_video_frames[0].size
    elif mode == "image-to-video" and (image is not None):
        loaded_video_frames = [load_image(image)]
        width, height = loaded_video_frames[0].size
        condition_input_video = True
    else: # text-to-video
        condition_input_video = False
        loaded_video_frames = None # No video frames for pure t2v

    if condition_input_video and loaded_video_frames:
        condition1 = LTXVideoCondition(video=loaded_video_frames, frame_index=0)
    else:
        condition1 = None

    # The first pass runs at 2/3 of the target size, snapped to what the VAE accepts.
    expected_height, expected_width = height, width
    downscale_factor = 2 / 3
    downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
    downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)

    # Pre-defined timestep schedules for the two passes.
    timesteps_first_pass = [1000, 993, 987, 981, 975, 909, 725]
    timesteps_second_pass = [1000, 909, 725, 421]
    if steps == 8:
        timesteps_first_pass = [1000, 993, 987, 981, 975, 909, 725, 0.03]
        timesteps_second_pass = [1000, 909, 725, 421, 0]
    elif 7 < steps < 8: # Only reachable for non-integer steps: let the pipe choose its schedule
        timesteps_first_pass = None
        timesteps_second_pass = None

    with calculateDuration("Main pipe generation"):
        latents = pipe(
            conditions=condition1,
            prompt=effective_prompt, # Use prompt with trigger word
            negative_prompt=negative_prompt,
            width=downscaled_width,
            height=downscaled_height,
            num_frames=num_frames,
            num_inference_steps=steps,
            decode_timestep=0.05,
            decode_noise_scale=0.025,
            timesteps=timesteps_first_pass,
            image_cond_noise_scale=0.0,
            guidance_rescale=0.7,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device=device).manual_seed(seed),
            output_type="latent",
        ).frames

    final_video_frames_np = None # Initialize
    if improve_texture:
        # Internal working size of the second pass, not the user-facing width/height.
        upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
        with calculateDuration("Latent upscaling"):
            upscaled_latents = pipe_upsample(
                latents=latents,
                adain_factor=1.0,
                output_type="latent"
            ).frames

        with calculateDuration("Denoising upscaled video"):
            final_video_frames_np = pipe( # Using main pipe for denoising
                conditions=condition1, # Re-pass condition if applicable
                prompt=effective_prompt,
                negative_prompt=negative_prompt,
                width=upscaled_width, # Use upscaled dimensions for this pass
                height=upscaled_height,
                num_frames=num_frames,
                guidance_scale=guidance_scale,
                denoise_strength=0.999,
                timesteps=timesteps_second_pass,
                num_inference_steps=10, # Or make this configurable
                latents=upscaled_latents,
                decode_timestep=0.05,
                decode_noise_scale=0.025,
                image_cond_noise_scale=0.0,
                guidance_rescale=0.7,
                generator=torch.Generator(device=device).manual_seed(seed),
                output_type="np",
            ).frames[0]
    else: # No texture improvement, just upscale latents and decode
        with calculateDuration("Latent upscaling and decoding (no improve_texture)"):
            final_video_frames_np = pipe_upsample(
                latents=latents,
                output_type="np" # Decode directly
            ).frames[0]

    # Video saving: frames are scaled by 255 and cast to uint8 before encoding.
    video_uint8_frames = [(frame * 255).astype(np.uint8) for frame in final_video_frames_np]
    output_filename = "output.mp4"
    with calculateDuration("Saving video to mp4"):
        with imageio.get_writer(output_filename, fps=FPS, quality=8, macro_block_size=1) as writer:
            for frame_idx, frame_data in enumerate(video_uint8_frames):
                progress((frame_idx + 1) / len(video_uint8_frames), desc="Encoding video frames...")
                writer.append_data(frame_data)

    return output_filename, seed # Return seed for display
# --- Gradio UI ---
# Custom stylesheet handed to gr.Blocks(css=...). The selectors target the
# elem_ids "gallery" and "lora_list_link" assigned in the layout and the
# custom_lora_card / card_internal classes used by the custom-LoRA status card.
css="""
#col-container { margin: 0 auto; max-width: 1000px; } /* Increased max-width for gallery */
#gallery .grid-wrap{height: 20vh !important; max-height: 250px !important;} /* From app (2), adjusted height */
.custom_lora_card { border: 1px solid #e0e0e0; border-radius: 8px; padding: 10px; margin-top: 10px; background-color: #f9f9f9; }
.card_internal { display: flex; align-items: center; }
.card_internal img { margin-right: 1em; border-radius: 4px; }
.card_internal div h4 { margin-bottom: 0.2em; }
.card_internal div small { font-size: 0.9em; color: #555; }
#lora_list_link { font-size: 90%; background: var(--block-background-fill); padding: 0.5em 1em; border-radius: 8px; display:inline-block; margin-top:10px;}
"""
# Layout and event wiring. Three tabs (image/text/video-to-video) share the
# duration, texture and advanced settings; the right column holds the LoRA
# gallery, the custom-LoRA controls and the output video.
with gr.Blocks(css=css, theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# LTX Video 0.9.7 Distilled with LoRA Explorer")
    gr.Markdown("Fast high quality video generation with custom LoRA support. [Model](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors) [GitHub](https://github.com/Lightricks/LTX-Video)")

    # Index into `loras` of the gallery item currently selected (None = no LoRA).
    selected_lora_index_state = gr.State(None)

    with gr.Row():
        with gr.Column(scale=2): # Main controls
            with gr.Tab("image-to-video") as image_tab:
                with gr.Group():
                    # Hidden placeholder filling generate()'s unused `video` slot.
                    video_i_hidden = gr.Textbox(label="video_i", visible=False, value=None)
                    image_i2v = gr.Image(label="Input Image", type="filepath", sources=["upload", "clipboard"]) # Removed webcam
                    i2v_prompt = gr.Textbox(label="Prompt", value="The creature from the image starts to move", lines=3)
                    i2v_button = gr.Button("Generate Image-to-Video", variant="primary")
            with gr.Tab("text-to-video") as text_tab:
                with gr.Group():
                    # Hidden placeholders filling generate()'s unused `image`/`video` slots.
                    image_n_hidden = gr.Textbox(label="image_n", visible=False, value=None)
                    video_n_hidden = gr.Textbox(label="video_n", visible=False, value=None)
                    t2v_prompt = gr.Textbox(label="Prompt", value="A majestic dragon flying over a medieval castle", lines=3)
                    t2v_button = gr.Button("Generate Text-to-Video", variant="primary")
            with gr.Tab("video-to-video") as video_tab:
                with gr.Group():
                    # Hidden placeholder filling generate()'s unused `image` slot.
                    image_v_hidden = gr.Textbox(label="image_v", visible=False, value=None)
                    video_v2v = gr.Video(label="Input Video")
                    frames_to_use_slider = gr.Slider(label="Frames to use from input video", minimum=9, maximum=MAX_NUM_FRAMES, value=9, step=8, info="Number of initial frames for conditioning. Must be N*8+1.")
                    v2v_prompt = gr.Textbox(label="Prompt", value="Change the style to cinematic anime", lines=3)
                    v2v_button = gr.Button("Generate Video-to-Video", variant="primary")

            duration_slider = gr.Slider(
                label="Video Duration (seconds)", minimum=0.3, maximum=8.5, value=2, step=0.1,
                info="Target video duration (0.3s to 8.5s). Actual frames depend on model constraints (multiple of 8 + 1)."
            )
            improve_texture_checkbox = gr.Checkbox(label="Improve Texture (multi-scale)", value=True, info="Uses a two-pass generation for better quality, but is slower.")

        with gr.Column(scale=1): # LoRA Gallery and Output
            selected_lora_info_markdown = gr.Markdown("No LoRA selected.")
            lora_gallery_display = gr.Gallery(
                # Gallery items are (image, caption) tuples built from the catalogue.
                value=[(item.get("image", "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"), item["title"]) for item in loras] if loras else [],
                label="LoRA Gallery",
                allow_preview=True, preview=True,
                columns=2, height="auto", object_fit="contain", # Adjusted for better display
                elem_id="gallery"
            )
            with gr.Group():
                custom_lora_input_path = gr.Textbox(label="Add Custom LoRA from Hugging Face", info="Path like 'username/repo-name'", placeholder="e.g., multimodalart/flux-lora-example (but for LTX!)")
                gr.Markdown("[Find LTX-compatible LoRAs on Hugging Face](https://huggingface.co/models?other=base_model:Lightricks/LTX-Video-0.9.7-distilled&sort=trending)", elem_id="lora_list_link")
            custom_lora_status_html = gr.HTML(visible=False) # For displaying custom LoRA card
            remove_custom_lora_button = gr.Button("Remove Last Added Custom LoRA", visible=False)

            output_video = gr.Video(label="Generated Video", interactive=False)
            gr.DeepLinkButton()

    with gr.Accordion("Advanced settings", open=False):
        # Hidden: the task mode is driven by the selected tab, not by the user.
        mode_dropdown = gr.Dropdown(["text-to-video", "image-to-video", "video-to-video"], label="Task Mode", value="image-to-video", visible=False) # Keep internal
        negative_prompt_textbox = gr.Textbox(label="Negative Prompt", value="worst quality, inconsistent motion, blurry, jittery, distorted", lines=2)
        with gr.Row():
            seed_number_input = gr.Number(label="Seed", value=0, precision=0)
            randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=True)
        with gr.Row():
            guidance_scale_slider = gr.Slider(label="Guidance Scale (CFG)", minimum=0, maximum=10, value=1.0, step=0.1) # LTX uses low CFG
            steps_slider = gr.Slider(label="Inference Steps (Main Pass)", minimum=1, maximum=30, value=7, step=1) # Default steps for LTX
        with gr.Row():
            height_slider = gr.Slider(label="Target Height", value=512, step=pipe.vae_spatial_compression_ratio, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info=f"Must be divisible by {pipe.vae_spatial_compression_ratio}.")
            width_slider = gr.Slider(label="Target Width", value=704, step=pipe.vae_spatial_compression_ratio, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info=f"Must be divisible by {pipe.vae_spatial_compression_ratio}.")
        with gr.Row():
            lora_scale_slider = gr.Slider(label="LoRA Scale", minimum=0.0, maximum=2.0, step=0.05, value=0.8, info="Adjusts the influence of the selected LoRA.")

    # --- Event Handlers ---
    # Uploading media re-derives the height/width sliders from its size;
    # clearing media leaves the sliders as they are.
    image_i2v.upload(fn=handle_image_upload_for_dims, inputs=[image_i2v, height_slider, width_slider], outputs=[height_slider, width_slider])
    video_v2v.upload(fn=handle_video_upload_for_dims, inputs=[video_v2v, height_slider, width_slider], outputs=[height_slider, width_slider])
    video_v2v.clear(lambda cur_h, cur_w: (gr.update(value=cur_h), gr.update(value=cur_w)), inputs=[height_slider, width_slider], outputs=[height_slider, width_slider])
    image_i2v.clear(lambda cur_h, cur_w: (gr.update(value=cur_h), gr.update(value=cur_w)), inputs=[height_slider, width_slider], outputs=[height_slider, width_slider])

    # Selecting a tab records the matching task mode in the hidden dropdown.
    image_tab.select(fn=update_task_image, outputs=[mode_dropdown])
    text_tab.select(fn=update_task_text, outputs=[mode_dropdown])
    video_tab.select(fn=update_task_video, outputs=[mode_dropdown])

    # LoRA Gallery Callbacks
    lora_gallery_display.select(
        update_lora_selection,
        outputs=[selected_lora_info_markdown, selected_lora_index_state]
    )
    custom_lora_input_path.submit(
        add_custom_lora_for_ltx,
        inputs=[custom_lora_input_path],
        outputs=[custom_lora_status_html, remove_custom_lora_button, lora_gallery_display, selected_lora_info_markdown, selected_lora_index_state, custom_lora_input_path]
    )
    remove_custom_lora_button.click(
        remove_custom_lora_for_ltx,
        outputs=[custom_lora_status_html, remove_custom_lora_button, lora_gallery_display, selected_lora_info_markdown, selected_lora_index_state, custom_lora_input_path]
    )

    # Inputs are bound to generate() positionally, so each list below must follow
    # its signature exactly: (prompt, negative_prompt, image, video, <shared tail>).
    # The shared tail starts at `height`.
    gen_inputs = [
        height_slider, width_slider, mode_dropdown, steps_slider,
        gr.Number(value=96, visible=False), # placeholder for num_frames_slider_val, as it's controlled by duration
        frames_to_use_slider,
        seed_number_input, randomize_seed_checkbox, guidance_scale_slider, duration_slider, improve_texture_checkbox,
        selected_lora_index_state, lora_scale_slider
    ]

    t2v_button.click(fn=generate,
                     inputs=[t2v_prompt, negative_prompt_textbox, image_n_hidden, video_n_hidden] + gen_inputs,
                     outputs=[output_video, seed_number_input]) # The seed actually used is written back
    i2v_button.click(fn=generate,
                     inputs=[i2v_prompt, negative_prompt_textbox, image_i2v, video_i_hidden] + gen_inputs,
                     outputs=[output_video, seed_number_input])
    v2v_button.click(fn=generate,
                     inputs=[v2v_prompt, negative_prompt_textbox, image_v_hidden, video_v2v] + gen_inputs,
                     outputs=[output_video, seed_number_input])

demo.queue(max_size=10).launch()