# Hugging Face Spaces status: Running
import os | |
import cv2 | |
import torch | |
import numpy as np | |
import gdown | |
import gradio as gr | |
import ffmpeg | |
from PIL import Image as PILImage | |
import yt_dlp | |
import shutil | |
import subprocess | |
import time | |
# Placeholder for RealWaifuUpScaler | |
from upcunet_v3 import RealWaifuUpScaler | |
# Global upscaler cache
# Maps "<model_file>_<scale_factor>_<device>" keys to loaded RealWaifuUpScaler
# instances so a model is only constructed once per process.
upscaler_cache = {}
# Constants
# Directory containing this script; every other path below is derived from it.
PATH = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(PATH, "weights_v3")
# NOTE: OUTPUT_PATH and TEMP_PATH resolve to the same "tmp" directory.
OUTPUT_PATH = os.path.join(PATH, "tmp")
TEMP_PATH = os.path.join(PATH, "tmp")
TEMP_AUDIO_PATH = os.path.join(TEMP_PATH, "audio.m4a")
TEMP_DOWNSCALED_PATH = os.path.join(TEMP_PATH, "downscaled_video.mp4")
TEMP_FPS_ADJUSTED_PATH = os.path.join(TEMP_PATH, "fps_adjusted_video.mp4")
FRAMES_PATH = os.path.join(TEMP_PATH, "frames")
# UI label -> (width, height) bounding box used when downscaling input videos.
INPUT_RESOLUTIONS = {"Low (640x360)": (640, 360), "Medium (854x480)": (854, 480), "High (1280x720)": (1280, 720)}
# NOTE(review): USE_JPEG_FRAMES and JPEG_QUALITY are not referenced in the
# visible part of this file - confirm whether they are still needed.
USE_JPEG_FRAMES = False
JPEG_QUALITY = 100
# Encoder tried first; the FFmpeg helpers retry with FALLBACK_CODEC on failure.
PREFERRED_CODEC = "h264_nvenc" if torch.cuda.is_available() else "libx264"
FALLBACK_CODEC = "libx264"
NVENC_PRESET = "p1"
LIBX264_PRESET = "ultrafast"
# upscale_video() sleeps briefly after every BATCH_SIZE processed frames.
BATCH_SIZE = 10
# Create directories
for path in [OUTPUT_PATH, TEMP_PATH, FRAMES_PATH]:
    os.makedirs(path, exist_ok=True)
# Load model files
# Only "*.pth" files directly inside MODEL_PATH are considered; the app refuses
# to start when none are present.
model_files = [f for f in os.listdir(MODEL_PATH) if f.endswith('.pth')]
if not model_files:
    raise ValueError(f"No model files found in {MODEL_PATH}.")
# Check FFmpeg availability
def check_ffmpeg():
    """Verify that the ``ffmpeg`` binary and the libx264 encoder are usable.

    Only prints a notice when CUDA is available but ``h264_nvenc`` is not
    listed; the module-level codec constants are not modified here.

    Raises:
        FileNotFoundError: the ``ffmpeg`` executable could not be started.
        ValueError: ``ffmpeg -codecs`` does not mention libx264.
        subprocess.CalledProcessError: ``ffmpeg -version`` exited non-zero.
    """
    try:
        # check=True turns a non-zero exit status into CalledProcessError.
        subprocess.run(['ffmpeg', '-version'], capture_output=True, text=True, check=True)
        codec_listing = subprocess.run(['ffmpeg', '-codecs'], capture_output=True, text=True).stdout
        if 'libx264' not in codec_listing:
            raise ValueError("Required codec libx264 not available.")
        nvenc_listed = 'h264_nvenc' in codec_listing
        if torch.cuda.is_available() and not nvenc_listed:
            print("h264_nvenc not available, using libx264.")
    except FileNotFoundError:
        raise FileNotFoundError("FFmpeg not found.")


check_ffmpeg()
# Filter models by scale factor
def filter_models_by_scale(scale_factor, model_files):
    """Return the model file names that belong to ``scale_factor``.

    A file matches when its name contains ``up<scale_factor>x``
    (case-insensitive). When nothing matches, the unfiltered ``model_files``
    is returned so callers always have something to offer.
    """
    needle = f"up{scale_factor}x".lower()
    matching = [name for name in model_files if needle in name.lower()]
    if matching:
        return matching
    return model_files
# Update model dropdown
def update_model_dropdown(scale_factor, model_files):
    """Build the Gradio update for the model dropdown after a scale change.

    The choices become the models matching ``scale_factor``; the first of
    them is selected, or ``None`` when the list is empty.
    """
    candidates = filter_models_by_scale(scale_factor, model_files)
    selected = candidates[0] if candidates else None
    return gr.update(choices=candidates, value=selected)
# Clean previous media files
def clean_previous_media(clear_frames=False):
    """Delete leftover media from TEMP_PATH and OUTPUT_PATH.

    Removes every ``.mp4``, ``.m4a`` and ``.txt`` file found directly in the
    two folders; files that cannot be removed are skipped. When
    ``clear_frames`` is true, FRAMES_PATH is wiped and recreated empty.
    """
    stale_suffixes = ('.mp4', '.m4a', '.txt')
    for directory in (TEMP_PATH, OUTPUT_PATH):
        for entry in os.listdir(directory):
            if not entry.endswith(stale_suffixes):
                continue
            try:
                os.remove(os.path.join(directory, entry))
            except OSError:
                # Best effort: a file that cannot be removed is left behind.
                pass
    if not clear_frames:
        return
    shutil.rmtree(FRAMES_PATH, ignore_errors=True)
    os.makedirs(FRAMES_PATH, exist_ok=True)
# Validate media streams
def validate_streams(file_path, expected_type):
    """Return True when ``file_path`` has a stream of ``expected_type``.

    ``expected_type`` is compared against each probed stream's
    ``codec_type``. A probe failure (``ffmpeg.Error``) yields False.
    """
    try:
        streams = ffmpeg.probe(file_path)['streams']
    except ffmpeg.Error:
        return False
    for stream in streams:
        if stream['codec_type'] == expected_type:
            return True
    return False
# Validate input video
def validate_input_video(video_path):
    """Probe ``video_path`` and summarise its first video stream.

    Returns:
        tuple[bool, str]: ``(True, summary)`` with resolution, duration, FPS
        and frame count when the file is usable, otherwise
        ``(False, reason)``.
    """
    # Local import keeps this block self-contained; Fraction parses the
    # "num/den" frame-rate string without resorting to eval().
    from fractions import Fraction

    try:
        probe = ffmpeg.probe(video_path)
    except ffmpeg.Error as e:
        return False, f"Video probe error: {e.stderr.decode('utf-8') if e.stderr else 'Unknown error'}"

    video_stream = next((s for s in probe.get('streams', []) if s.get('codec_type') == 'video'), None)
    if not video_stream:
        return False, "No video stream found."

    try:
        width, height = int(video_stream['width']), int(video_stream['height'])
        duration = float(probe['format']['duration'])
        # A rate such as "0/0" raises ZeroDivisionError, handled below.
        fps = float(Fraction(video_stream['r_frame_rate']))
    except (KeyError, TypeError, ValueError, ZeroDivisionError) as e:
        return False, f"Video metadata error: {e!r}"

    # Prefer the container's frame count; estimate it when absent or not numeric.
    try:
        frame_count = int(video_stream['nb_frames'])
    except (KeyError, TypeError, ValueError):
        frame_count = int(duration * fps)
    return True, f"Resolution: {width}x{height}, Duration: {duration:.2f}s, FPS: {fps:.1f}, Frames: {frame_count}"
# Helper function for FFmpeg command with codec fallback
def run_ffmpeg_command(cmd, codec, preset):
    """Run a prepared ffmpeg-python command, overwriting existing output.

    Args:
        cmd: object exposing ``compile()`` and ``run(overwrite_output=...)``.
        codec: codec name, used only in the error log line.
        preset: accepted for call-site symmetry; not used by this function.

    Returns:
        bool: True when the command ran, False when it raised
        ``ffmpeg.Error`` (the decoded stderr is printed).
    """
    try:
        print(f"FFmpeg command: {cmd.compile()}")
        cmd.run(overwrite_output=True)
    except ffmpeg.Error as e:
        if e.stderr:
            error_msg = e.stderr.decode('utf-8')
        else:
            error_msg = "Unknown FFmpeg error"
        print(f"FFmpeg error with {codec}: {error_msg}")
        return False
    return True
# Extract audio
def extract_audio(video_path, audio_output_path):
    """Re-encode the audio track of ``video_path`` to AAC at 96 kbit/s.

    Returns ``audio_output_path`` when the file exists afterwards, or None
    when the video has no audio stream or FFmpeg fails.
    """
    if not validate_streams(video_path, 'audio'):
        print("No audio stream found in video, skipping audio extraction.")
        return None

    encode_options = {'c:a': 'aac', 'b:a': '96k', 'loglevel': 'error'}
    try:
        cmd = ffmpeg.input(video_path).audio.output(audio_output_path, **encode_options)
        print(f"FFmpeg extract_audio: {cmd.compile()}")
        cmd.run(overwrite_output=True)
    except ffmpeg.Error as e:
        error_msg = e.stderr.decode('utf-8') if e.stderr else "Unknown FFmpeg error"
        print(f"extract_audio error: {error_msg}")
        return None

    if os.path.exists(audio_output_path):
        return audio_output_path
    return None
# Downscale video
def downscale_video(video_path, output_path, scale_factor, input_resolution=None):
    """Downscale ``video_path`` to fit one of INPUT_RESOLUTIONS at 24 fps.

    The target box comes from ``input_resolution``. "High (1280x720)" is
    reduced to the "Medium" box when ``scale_factor`` is 3 or more. The source
    aspect ratio is kept and both dimensions are rounded up to even numbers.

    Args:
        video_path: source video file.
        output_path: where the downscaled video is written.
        scale_factor: upscale factor that will be applied afterwards.
        input_resolution: key of INPUT_RESOLUTIONS; any other value skips
            downscaling.

    Returns:
        str: ``output_path`` when it was produced, otherwise ``video_path``.
    """
    if input_resolution not in INPUT_RESOLUTIONS:
        return video_path

    cap = cv2.VideoCapture(video_path)
    width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()
    if width <= 0 or height <= 0:
        # An unreadable video reports 0x0; avoid dividing by zero below.
        print("downscale_video: could not read video dimensions, skipping downscale")
        return video_path

    target_width, target_height = INPUT_RESOLUTIONS[input_resolution]
    if scale_factor >= 3 and input_resolution == "High (1280x720)":
        target_width, target_height = INPUT_RESOLUTIONS["Medium (854x480)"]

    # Shrink whichever side would otherwise break the source aspect ratio.
    aspect_ratio = width / height
    if aspect_ratio > target_width / target_height:
        target_height = int(target_width / aspect_ratio)
    else:
        target_width = int(target_height * aspect_ratio)
    # Round odd dimensions up to the next even number.
    target_width += target_width % 2
    target_height += target_height % 2

    # Preferred codec first, then the software fallback (skipped if identical).
    attempts = [(PREFERRED_CODEC, NVENC_PRESET if PREFERRED_CODEC == 'h264_nvenc' else LIBX264_PRESET)]
    if FALLBACK_CODEC != PREFERRED_CODEC:
        attempts.append((FALLBACK_CODEC, LIBX264_PRESET))

    for attempt, (codec, preset) in enumerate(attempts):
        if attempt:
            print(f"Falling back to {codec}")
        cmd = ffmpeg.input(video_path).output(
            output_path,
            **{
                'c:v': codec,
                'c:a': 'copy',
                's': f'{target_width}x{target_height}',
                'r': '24',
                'preset': preset,
                'crf': '18',
                'b:v': '2M',
                'loglevel': 'error',
            }
        )
        if run_ffmpeg_command(cmd, codec, preset):
            return output_path if os.path.exists(output_path) else video_path

    print("downscale_video failed with both codecs")
    return video_path
# Adjust FPS
def adjust_fps(video_path, output_path):
    """Re-encode ``video_path`` at 24 fps into ``output_path``.

    Tries PREFERRED_CODEC first and retries with FALLBACK_CODEC when that
    fails. The audio stream is copied unchanged.

    Returns:
        str: ``output_path`` when it was produced, otherwise ``video_path``.
    """
    # Preferred codec first, then the software fallback (skipped if identical).
    attempts = [(PREFERRED_CODEC, NVENC_PRESET if PREFERRED_CODEC == 'h264_nvenc' else LIBX264_PRESET)]
    if FALLBACK_CODEC != PREFERRED_CODEC:
        attempts.append((FALLBACK_CODEC, LIBX264_PRESET))

    for attempt, (codec, preset) in enumerate(attempts):
        if attempt:
            print(f"Falling back to {codec}")
        cmd = ffmpeg.input(video_path).output(
            output_path,
            **{
                'c:v': codec,
                'c:a': 'copy',
                'r': '24',
                'preset': preset,
                'crf': '18',
                'b:v': '2M',
                'loglevel': 'error',
            }
        )
        if run_ffmpeg_command(cmd, codec, preset):
            return output_path if os.path.exists(output_path) else video_path

    print("adjust_fps failed with both codecs")
    return video_path
# Post-process video
def post_process_video(input_path, output_path):
    """Re-encode ``input_path`` into an MP4 at ``output_path``.

    Nothing is done when both paths are equal, because FFmpeg cannot safely
    overwrite the file it is reading. Otherwise PREFERRED_CODEC is tried
    first and FALLBACK_CODEC second.

    Returns:
        str: ``output_path`` on success, otherwise ``input_path``.
    """
    if input_path == output_path:
        return input_path

    # Preferred codec first, then the software fallback (skipped if identical).
    attempts = [(PREFERRED_CODEC, NVENC_PRESET if PREFERRED_CODEC == 'h264_nvenc' else LIBX264_PRESET)]
    if FALLBACK_CODEC != PREFERRED_CODEC:
        attempts.append((FALLBACK_CODEC, LIBX264_PRESET))

    for attempt, (codec, preset) in enumerate(attempts):
        if attempt:
            print(f"Falling back to {codec}")
        cmd = ffmpeg.input(input_path).output(
            output_path,
            **{
                'c:v': codec,
                'f': 'mp4',
                'b:v': '2M',
                'crf': '18',
                'preset': preset,
                'loglevel': 'error',
            }
        )
        if run_ffmpeg_command(cmd, codec, preset):
            return output_path

    print("post_process_video failed with both codecs")
    return input_path
# Add audio to video
def add_audio_to_video(video_path, audio_path, output_path):
    """Mux the video of ``video_path`` with the audio of ``audio_path``.

    The video is re-encoded (PREFERRED_CODEC, then FALLBACK_CODEC) and the
    audio is encoded as AAC at 96 kbit/s into an MP4 at ``output_path``.

    Returns:
        str: ``output_path`` on success; ``video_path`` when an input is
        missing, lacks the expected stream, or both encodes fail.
    """
    inputs_ok = (
        os.path.exists(video_path)
        and os.path.exists(audio_path)
        and validate_streams(video_path, 'video')
        and validate_streams(audio_path, 'audio')
    )
    if not inputs_ok:
        print(f"Invalid inputs: video={video_path}, audio={audio_path}")
        return video_path

    video_stream = ffmpeg.input(video_path).video
    audio_stream = ffmpeg.input(audio_path).audio

    # Preferred codec first, then the software fallback (skipped if identical).
    attempts = [(PREFERRED_CODEC, NVENC_PRESET if PREFERRED_CODEC == 'h264_nvenc' else LIBX264_PRESET)]
    if FALLBACK_CODEC != PREFERRED_CODEC:
        attempts.append((FALLBACK_CODEC, LIBX264_PRESET))

    for attempt, (codec, preset) in enumerate(attempts):
        if attempt:
            print(f"Falling back to {codec}")
        # No explicit 'map' option: ffmpeg-python emits -map for the selected
        # video/audio streams itself. The previous dict listed 'map' twice, so
        # the first entry was silently dropped and the audio was mapped again.
        cmd = ffmpeg.output(
            video_stream, audio_stream, output_path,
            **{
                'c:v': codec,
                'c:a': 'aac',
                'b:a': '96k',
                'b:v': '2M',
                'crf': '18',
                'f': 'mp4',
                'preset': preset,
                'loglevel': 'error',
            }
        )
        if run_ffmpeg_command(cmd, codec, preset):
            return output_path

    print("add_audio_to_video failed with both codecs")
    return video_path
# Download media
def download_media(url, media_type, cookies_file=None, scale_factor=2, input_resolution=None):
    """Download ``url`` with yt-dlp and, for videos, normalise the result.

    Videos are downloaded to a temporary file, optionally downscaled
    (see ``downscale_video``) and re-encoded at 24 fps. Any other
    ``media_type`` is downloaded straight to its final ``.png`` path.

    Args:
        url: media URL understood by yt-dlp.
        media_type: ``'Video'`` selects the video pipeline.
        cookies_file: optional cookies file passed to yt-dlp when it exists.
        scale_factor: forwarded to ``downscale_video``.
        input_resolution: forwarded to ``downscale_video``.

    Returns:
        str | None: path of the resulting file, or None on any failure.
    """
    is_video = media_type == 'Video'
    output_path = os.path.join(TEMP_PATH, f"media_{media_type.lower()}.{'mp4' if is_video else 'png'}")
    temp_output = os.path.join(TEMP_PATH, "temp_download.mp4") if is_video else output_path
    # NOTE(review): 'download_sections' mirrors the CLI flag name; confirm that
    # the YoutubeDL Python API honours it (it may expect 'download_ranges').
    ydl_opts = {
        'outtmpl': temp_output,
        'format': 'bestvideo[vcodec^=avc1]+bestaudio/best' if is_video else 'best',
        'merge_output_format': 'mp4' if is_video else None,
        'download_sections': '*0-60',
        'quiet': True,
    }
    if cookies_file and os.path.exists(cookies_file):
        ydl_opts['cookiefile'] = cookies_file
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])
        if is_video and os.path.exists(temp_output):
            temp_downscaled_path = downscale_video(temp_output, TEMP_DOWNSCALED_PATH, scale_factor, input_resolution)
            # adjust_fps returns its input path on failure, so output_path may
            # end up pointing at the downscaled file or the raw download.
            output_path = adjust_fps(temp_downscaled_path, output_path)
            # Drop the intermediate downscaled file unless it is the result.
            # The raw download is removed in `finally`, and only when it is not
            # the file being returned.
            if temp_downscaled_path not in (output_path, temp_output) and os.path.exists(temp_downscaled_path):
                os.remove(temp_downscaled_path)
        return output_path if os.path.exists(output_path) else None
    except Exception as e:
        print(f"download_media error: {str(e)}")
        return None
    finally:
        if temp_output != output_path and os.path.exists(temp_output):
            try:
                os.remove(temp_output)
            except OSError:
                pass
# Upscale image
def upscale_image(image_path, scale_factor, selected_model, preserve_original_size=True):
    """Upscale one image with the selected model and save it as PNG.

    Args:
        image_path: source image file.
        scale_factor: factor passed to the upscaler and used in the file name.
        selected_model: model file name inside MODEL_PATH.
        preserve_original_size: when true the result is resized back to the
            source dimensions; otherwise to ``scale_factor`` times the source.

    Returns:
        tuple[str, str | None]: status message and output path (None on
        failure).
    """
    if not os.path.exists(image_path):
        return f"Image not found: {image_path}", None
    try:
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # OpenCV could not decode the file; fall back to Pillow.
            pil_rgb = PILImage.open(image_path).convert('RGB')
            image = cv2.cvtColor(np.array(pil_rgb), cv2.COLOR_RGB2BGR)

        height, width = image.shape[:2]
        if preserve_original_size:
            output_width, output_height = width, height
        else:
            output_width, output_height = width * scale_factor, height * scale_factor

        model_file = os.path.join(MODEL_PATH, selected_model)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        cache_key = f"{model_file}_{scale_factor}_{device}"
        if cache_key not in upscaler_cache:
            # NOTE(review): half=True is also used on CPU - confirm the model
            # supports half precision there.
            upscaler_cache[cache_key] = RealWaifuUpScaler(scale_factor, model_file, half=True, device=device)
        upscaler = upscaler_cache[cache_key]

        # Channel order is reversed before the model call and restored after.
        rgb_input = image[:, :, [2, 1, 0]]
        result = upscaler(rgb_input, tile_mode=3, cache_mode=2, alpha=1.0)[:, :, ::-1]
        size_matches = result.shape[1] == output_width and result.shape[0] == output_height
        if not size_matches:
            result = cv2.resize(result, (output_width, output_height), interpolation=cv2.INTER_LANCZOS4)

        stem = os.path.splitext(os.path.basename(image_path))[0]
        output_path = os.path.join(OUTPUT_PATH, f"upscaled_{scale_factor}x_{stem}.png")
        cv2.imwrite(output_path, result, [cv2.IMWRITE_PNG_COMPRESSION, 0])
        return "Image upscaling completed!", output_path
    except Exception as e:
        return f"Upscale failed: {str(e)}", None
# Image processing
def process_image(drive_link, uploaded_file, scale_factor, selected_model, cookies_file=None, preserve_original_size=True):
    """Resolve the input image (upload or URL) and upscale it.

    An uploaded file takes precedence over ``drive_link``. Google Drive links
    are fetched with gdown; other URLs go through ``download_media``. Only
    ``.png``, ``.jpeg`` and ``.jpg`` files are accepted.

    Returns:
        tuple[str, str | None]: status message and output image path.
    """
    clean_previous_media(clear_frames=True)

    if uploaded_file:
        file_path = uploaded_file
    elif not drive_link:
        return "No valid image provided.", None
    else:
        if "drive.google.com" in drive_link:
            # The file id is taken as the second-to-last path segment.
            file_id = drive_link.split('/')[-2]
            file_path = os.path.join(TEMP_PATH, "anime_image.png")
            gdown.download(f"https://drive.google.com/uc?id={file_id}", file_path, quiet=False)
        else:
            file_path = download_media(drive_link, "Image", cookies_file, scale_factor)
        if not file_path or not os.path.exists(file_path):
            return "Image could not be downloaded.", None

    if not file_path.lower().endswith(('.png', '.jpeg', '.jpg')):
        return "Unsupported file format.", None

    return upscale_image(file_path, scale_factor, selected_model, preserve_original_size)
# Upscale video
def upscale_video(video_path, scale_factor, selected_model, original_video_name, keep_audio=True, progress=gr.Progress()):
    """Upscale a video frame by frame, yielding progress updates.

    Each frame is upscaled with the selected model, saved as a PNG in
    FRAMES_PATH and appended to an ``mp4v`` video at 24 fps. When
    ``keep_audio`` is true and the source has audio, the audio is muxed back
    into the final file.

    Args:
        video_path: source video file.
        scale_factor: upscale factor (output dimensions are rounded to even).
        selected_model: model file name inside MODEL_PATH.
        original_video_name: file name used to build the output names.
        keep_audio: whether to carry the source audio over.
        progress: progress callback, called with a 0-1 fraction and a
            description.

    Yields:
        tuple[str, str | None]: status text and, on the final update only,
        the path of the produced video.
    """
    if not video_path or not os.path.exists(video_path):
        yield f"Video not found: {video_path or 'No video provided.'}", None
        return
    is_valid, video_info = validate_input_video(video_path)
    if not is_valid:
        yield f"Invalid video: {video_info}", None
        return
    print(f"Video info: {video_info}")

    temp_output_path = os.path.join(OUTPUT_PATH, f"upscaled_{scale_factor}x_{original_video_name}")
    final_output_path = os.path.join(OUTPUT_PATH, f"final_upscaled_{scale_factor}x_{original_video_name}")

    cap = cv2.VideoCapture(video_path)
    video_writer = None
    audio_path = None
    error_message = None
    frame_idx = 0
    # The capture and writer are released in `finally` on every exit path,
    # including the early return below and an abandoned generator. No `yield`
    # lives inside `finally`, so closing the generator early stays legal.
    try:
        width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = 24
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Total frames: {total_frames}")
        shutil.rmtree(FRAMES_PATH, ignore_errors=True)
        os.makedirs(FRAMES_PATH, exist_ok=True)
        output_width, output_height = width * scale_factor, height * scale_factor
        # Round odd dimensions up to the next even number.
        output_width += output_width % 2
        output_height += output_height % 2

        if keep_audio:
            audio_path = extract_audio(video_path, TEMP_AUDIO_PATH)
            if audio_path:
                yield "Extracted audio, starting frame processing...", None
            else:
                yield "No audio extracted, proceeding without audio...", None
        else:
            yield "Skipping audio extraction...", None

        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        print(f"Starting at frame: {frame_idx}")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(temp_output_path, fourcc, fps, (output_width, output_height))
        if not video_writer.isOpened():
            yield "Failed to initialize video writer.", None
            return

        model_file = os.path.join(MODEL_PATH, selected_model)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        cache_key = f"{model_file}_{scale_factor}_{device}"
        if cache_key not in upscaler_cache:
            # NOTE(review): half=True is also used on CPU - confirm the model
            # supports half precision there.
            upscaler_cache[cache_key] = RealWaifuUpScaler(scale_factor, model_file, half=True, device=device)
        upscaler = upscaler_cache[cache_key]

        start_time = time.time()
        # Some containers report no frame count; then read until exhausted.
        total_known = total_frames > 0
        try:
            while cap.isOpened() and (not total_known or frame_idx < total_frames):
                ret, frame = cap.read()
                if not ret:
                    print("No more frames to read")
                    break
                print(f"Processing frame {frame_idx}")
                try:
                    # Channel order is reversed for the model and restored after.
                    result = upscaler(frame[:, :, [2, 1, 0]], tile_mode=3, cache_mode=2, alpha=1.0)[:, :, ::-1]
                    upscaled_frame = cv2.resize(result, (output_width, output_height), interpolation=cv2.INTER_LANCZOS4)
                except Exception as e:
                    error_message = f"Error upscaling frame {frame_idx}: {str(e)}"
                    yield error_message, None
                    break
                frame_path = os.path.join(FRAMES_PATH, f"frame_{frame_idx:06d}.png")
                cv2.imwrite(frame_path, upscaled_frame)
                video_writer.write(upscaled_frame)
                frame_idx += 1
                if total_known:
                    percentage = (frame_idx / total_frames) * 100
                    progress(percentage / 100, desc=f"Processing frame {frame_idx}/{total_frames} ({percentage:.1f}%)")
                    yield f"Processed frame {frame_idx}/{total_frames}, {percentage:.1f}% complete", None
                else:
                    progress(0, desc=f"Processing frame {frame_idx}")
                    yield f"Processed frame {frame_idx}", None
                if frame_idx % BATCH_SIZE == 0:
                    time.sleep(0.1)
        except Exception as e:
            error_message = f"Error during processing: {str(e)}"
            yield error_message, None
    finally:
        cap.release()
        if video_writer is not None:
            video_writer.release()

    # NOTE(review): input and output paths are identical, so
    # post_process_video() returns immediately without re-encoding - confirm
    # whether a real re-encode was intended here.
    temp_output_path = post_process_video(temp_output_path, temp_output_path)
    if audio_path and keep_audio:
        final_output_path = add_audio_to_video(temp_output_path, audio_path, final_output_path)
        if os.path.exists(temp_output_path) and final_output_path != temp_output_path:
            os.remove(temp_output_path)
        if os.path.exists(audio_path):
            os.remove(audio_path)
    else:
        final_output_path = temp_output_path

    total_time = time.time() - start_time
    if error_message:
        # Do not report success when processing stopped on an error; the
        # partial video is still returned.
        yield (
            f"Video '{original_video_name}' processing stopped early ({error_message}). "
            f"Total frames: {frame_idx}, Time: {total_time:.2f}s",
            final_output_path,
        )
    else:
        yield f"Video '{original_video_name}' processed successfully. Total frames: {frame_idx}, Time: {total_time:.2f}s", final_output_path
# Process video upscale
def process_video_upscale(drive_link, uploaded_file, scale_factor, selected_model, cookies_file, input_resolution=None):
    """Resolve the input video (upload or URL) and stream upscale progress.

    An uploaded file takes precedence over ``drive_link``. Uploads are
    optionally downscaled and then re-encoded at 24 fps. Google Drive links
    are fetched with gdown; other URLs go through ``download_media``.

    Yields:
        tuple[str, str | None]: status text and output video path, as
        produced by ``upscale_video``.
    """
    clean_previous_media(clear_frames=True)

    if uploaded_file:
        original_video_name = os.path.basename(uploaded_file)
        file_path = uploaded_file
        if input_resolution in INPUT_RESOLUTIONS:
            file_path = downscale_video(file_path, TEMP_DOWNSCALED_PATH, scale_factor, input_resolution)
        # The frame rate is normalised whether or not a downscale happened.
        file_path = adjust_fps(file_path, TEMP_FPS_ADJUSTED_PATH)
    elif drive_link:
        original_video_name = "downloaded_video.mp4"
        if "drive.google.com" in drive_link:
            # The file id is taken as the second-to-last path segment.
            file_id = drive_link.split('/')[-2]
            file_path = os.path.join(TEMP_PATH, "downloaded_video.mp4")
            gdown.download(f"https://drive.google.com/uc?id={file_id}", file_path, quiet=False)
        else:
            file_path = download_media(drive_link, "Video", cookies_file, scale_factor, input_resolution)
    else:
        yield "No valid video provided.", None
        return

    if not file_path or not os.path.exists(file_path):
        yield "Video not found.", None
        return

    is_valid, video_info = validate_input_video(file_path)
    if not is_valid:
        yield f"Failed video invalid: {video_info}", None
        return

    updates = upscale_video(
        file_path,
        scale_factor,
        selected_model,
        original_video_name,
        keep_audio=True,
    )
    for status, video_output in updates:
        yield status, video_output
# Toggle resolution visibility
def toggle_resolution(checkbox):
    """Return a Gradio update whose ``visible`` flag follows ``checkbox``."""
    return gr.update(visible=checkbox)
# Custom CSS
# Stylesheet handed to gr.Blocks(css=...): dark gradient background, pink
# accent colour, a narrow-screen layout, and rules hiding footer/settings links.
custom_css = """
body {
    background: linear-gradient(135deg, #1e1e2f, #2a2a4a);
    font-family: sans-serif;
    color: #e0e0e0;
}
.gradio-container {
    width: 100%;
    max-width: 100%;
    margin: 20px auto;
    background: rgba(30, 30, 47, 0.95);
    border-radius: 12px;
    padding: 30px;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.2);
}
h1 {
    color: #ff3f81;
    text-align: center;
    font-size: 2.5rem;
    margin-bottom: 20px;
}
.gr-button {
    background: #ff3f81;
    color: white;
    border-radius: 8px;
    padding: 12px 20px;
    font-size: 14px;
    border: none;
    cursor: pointer;
    margin: 5px 0;
}
.gr-button:hover {
    background: #c55;
}
.output-textbox, .gr-input, .gr-textbox, .gr-dropdown {
    background: #3a3a5a;
    color: #e0e0e0;
    border-radius: 8px;
    padding: 12px;
    border: 1px solid #555;
    margin: 10px 0;
}
.gr-row {
    gap: 20px;
    margin-bottom: 20px;
}
.video-preview, .image-preview {
    border-radius: 8px;
    border: 2px solid #ff3f81;
    max-height: 500px;
    object-fit: contain;
    margin: 10px 0;
}
@media (max-width: 768px) {
    .gradio-container {
        padding: 15px;
        width: 95%;
    }
    h1 {
        font-size: 2rem;
    }
    .gr-row {
        flex-direction: column;
        gap: 15px;
    }
    .gr-column {
        width: 100% !important;
    }
}
.gradio-footer, footer, .footer, a[href*='gradio'], a[href*='settings'], a[data-testid='settings-link'] {
    display: none !important;
}
"""
# Gradio interface
# Two tabs (video / image), each with input, settings and output columns.
# NOTE(review): the title and tab labels contain characters that look like
# mis-encoded emoji; they are left as found - confirm the intended glyphs.
# NOTE(review): the layout nesting below was reconstructed from the component
# order - confirm it matches the intended layout.

# The preferred default model is only preselected when it is actually among
# the offered choices; otherwise the first available 2x model is used, so the
# dropdowns never start with a value that is missing from their choices.
_default_choices = filter_models_by_scale(2, model_files)
_default_model = (
    "up2x-latest-denoise3x.pth"
    if "up2x-latest-denoise3x.pth" in _default_choices
    else _default_choices[0]
)

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown("# π Anime Media Video Upscaler")
    with gr.Tabs():
        with gr.TabItem("π¬ Video Upscale"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Video Input")
                    uploaded_video = gr.Video(label="Upload Video", interactive=True)
                    video_drive_link = gr.Textbox(label="Video URL")
                    with gr.Accordion("Advanced Options", open=False):
                        video_cookies = gr.File(label="Cookies File", file_types=[".txt"])
                        video_resolution_checkbox = gr.Checkbox(label="Select Resolution", value=False)
                        # Hidden until the checkbox above is ticked.
                        video_input_resolution = gr.Dropdown(
                            choices=list(INPUT_RESOLUTIONS.keys()),
                            label="Select Resolution",
                            value=None,
                            visible=False
                        )
                with gr.Column():
                    gr.Markdown("### Upscale Settings")
                    video_scale_factor = gr.Dropdown(choices=[2, 3, 4], label="Scale Factor", value=2)
                    video_model = gr.Dropdown(
                        choices=_default_choices,
                        label="Select Model",
                        value=_default_model
                    )
                    video_process_button = gr.Button("Upscale Video", variant="primary")
                with gr.Column():
                    gr.Markdown("### Output")
                    video_output_text = gr.Textbox(label="Status")
                    video_output = gr.Video(label="Upscaled Video")
            # Changing the scale factor re-filters the model dropdown.
            video_scale_factor.change(
                fn=update_model_dropdown,
                inputs=[video_scale_factor, gr.State(model_files)],
                outputs=[video_model]
            )
            video_resolution_checkbox.change(
                fn=toggle_resolution,
                inputs=video_resolution_checkbox,
                outputs=video_input_resolution
            )
        with gr.TabItem("πΈ Image Upscale"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Image Input")
                    uploaded_image = gr.Image(label="Upload Image", type="filepath", interactive=True)
                    image_drive_link = gr.Textbox(label="Image URL")
                    with gr.Accordion("Advanced Options", open=False):
                        image_cookies_file = gr.File(label="Cookies File", file_types=[".txt"])
                with gr.Column():
                    gr.Markdown("### Upscale Settings")
                    image_scale_factor = gr.Dropdown(choices=[2, 3, 4], label="Scale Factor", value=2)
                    image_preserve_size = gr.Checkbox(label="Preserve Original Size", value=True)
                    image_model = gr.Dropdown(
                        choices=_default_choices,
                        label="Select Model",
                        value=_default_model
                    )
                    image_process_button = gr.Button("Upscale Image", variant="primary")
                with gr.Column():
                    gr.Markdown("### Output")
                    image_output_text = gr.Textbox(label="Status")
                    image_output = gr.Image(label="Upscaled Image", type="filepath")
            image_scale_factor.change(
                fn=update_model_dropdown,
                inputs=[image_scale_factor, gr.State(model_files)],
                outputs=[image_model]
            )
    video_process_button.click(
        fn=process_video_upscale,
        inputs=[video_drive_link, uploaded_video, video_scale_factor, video_model, video_cookies, video_input_resolution],
        outputs=[video_output_text, video_output]
    )
    image_process_button.click(
        fn=process_image,
        inputs=[image_drive_link, uploaded_image, image_scale_factor, image_model, image_cookies_file, image_preserve_size],
        outputs=[image_output_text, image_output]
    )
demo.launch(
    share=True,
    show_api=False,
)