Upload generateDataDepthCycleGAN.py
src/training/generateDataDepthCycleGAN.py
ADDED (977 lines)
import gradio as gr
import cv2
import numpy as np
import torch
import os
import sys
import shutil
import tempfile
from PIL import Image
from huggingface_hub import HfApi, HfFolder, hf_hub_download, create_repo
import time
import random
import threading
from datetime import datetime, timedelta
from tqdm import tqdm
import concurrent.futures
import traceback

# Print the DEPTH-related environment variables for debugging
print("DEPTH-related environment variables:")
for key, value in os.environ.items():
    if "DEPTH" in key:
        print(f"{key}: {value}")

# Check the specific variable pointing at the Depth-Anything-V2 checkout
depth_path = os.getenv('DEPTH_ANYTHING_V2_PATH')
print(f"DEPTH_ANYTHING_V2_PATH value: {depth_path}")

# Fall back to the script's own directory if the variable is not set
if depth_path is None:
    depth_anything_path = os.path.dirname(os.path.abspath(__file__))
    print(f"Environment variable not set. Using current directory: {depth_anything_path}")
else:
    depth_anything_path = depth_path
    print(f"Using environment variable path: {depth_anything_path}")

sys.path.append(depth_anything_path)
try:
    from depth_anything_v2.dpt import DepthAnythingV2
    print("Successfully imported DepthAnythingV2")
except ImportError as e:
    print(f"Import error: {e}")
    print(f"Contents of directory: {os.listdir(depth_anything_path)}")
    if os.path.exists(os.path.join(depth_anything_path, 'depth_anything_v2')):
        print(f"Contents of depth_anything_v2: {os.listdir(os.path.join(depth_anything_path, 'depth_anything_v2'))}")

# Device selection
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

# Model configurations
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
}

encoder2name = {
    'vits': 'Small',
    'vitb': 'Base',
    'vitl': 'Large'
}

name2encoder = {v: k for k, v in encoder2name.items()}

# Model IDs and filenames for HuggingFace Hub
MODEL_INFO = {
    'vits': {
        'repo_id': 'depth-anything/Depth-Anything-V2-Small',
        'filename': 'depth_anything_v2_vits.pth'
    },
    'vitb': {
        'repo_id': 'depth-anything/Depth-Anything-V2-Base',
        'filename': 'depth_anything_v2_vitb.pth'
    },
    'vitl': {
        'repo_id': 'depth-anything/Depth-Anything-V2-Large',
        'filename': 'depth_anything_v2_vitl.pth'
    }
}

# Global variables for model management
current_model = None
current_encoder = None

# Global variables for the live preview
live_preview_queue = []
live_preview_lock = threading.Lock()
def download_model(encoder):
    """Download the specified model from HuggingFace Hub"""
    model_info = MODEL_INFO[encoder]

    # Check whether the file already exists in the checkpoints directory of DEPTH_ANYTHING_V2_PATH
    depth_path = os.getenv('DEPTH_ANYTHING_V2_PATH')
    if depth_path:
        checkpoint_dir = os.path.join(depth_path, 'checkpoints')
        local_file = os.path.join(checkpoint_dir, model_info['filename'])
        if os.path.exists(local_file):
            print(f"Using existing model file: {local_file}")
            return local_file

    # If not found, download it
    model_path = hf_hub_download(
        repo_id=model_info['repo_id'],
        filename=model_info['filename'],
        local_dir='checkpoints'
    )
    return model_path

def load_model(encoder):
    """Load the specified model, reusing the cached one if already loaded"""
    global current_model, current_encoder
    if current_encoder != encoder:
        model_path = download_model(encoder)
        current_model = DepthAnythingV2(**model_configs[encoder])
        current_model.load_state_dict(torch.load(model_path, map_location='cpu'))
        current_model = current_model.to(DEVICE).eval()
        current_encoder = encoder
    return current_model
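# Note: load_model caches a single model in module-level globals without a lock.
# Every worker in a given run requests the same encoder, so after the first load
# the cache is effectively read-only; requesting two different encoders from
# concurrent threads would race on these globals.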
def convert_to_bw(image):
    """Convert an image to black and white (kept as 3-channel RGB)"""
    if isinstance(image, Image.Image):
        return image.convert('L').convert('RGB')
    elif isinstance(image, np.ndarray):
        return cv2.cvtColor(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB)
    return image
def blend_images(original, depth_colored, opacity=0.5, make_bw=False, depth_on_top=True):
    """Blend the original image with the depth map at the given opacity.

    opacity: 0.0 = original image only, 1.0 = depth map only
    depth_on_top: if True, the depth map is blended on top of the original image
    """
    # Convert inputs to numpy arrays if needed
    if isinstance(original, Image.Image):
        original = np.array(original)
    if isinstance(depth_colored, Image.Image):
        depth_colored = np.array(depth_colored)

    # Convert the original to black and white if requested
    if make_bw:
        original = cv2.cvtColor(cv2.cvtColor(original, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB)

    # Ensure both images are float32 for blending
    original = original.astype(np.float32)
    depth_colored = depth_colored.astype(np.float32)

    # Weighted sum based on opacity
    if depth_on_top:
        blended = original * (1 - opacity) + depth_colored * opacity
    else:
        blended = original * opacity + depth_colored * (1 - opacity)

    # Clip values and convert back to uint8
    blended = np.clip(blended, 0, 255).astype(np.uint8)

    return blended  # Return a numpy array rather than a PIL Image
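# Example: opacity=0.3 with depth_on_top=True keeps 70% of the original pixel
# values and mixes in 30% of the colored depth map.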
@torch.inference_mode()
def predict_depth(image, encoder, invert_depth=False):
    """Predict depth using the selected model"""
    model = load_model(encoder)
    if model is None:
        raise ValueError(f"Model for encoder {encoder} could not be loaded.")

    # Convert to a numpy array if given a PIL Image
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Get the depth prediction
    depth = model.infer_image(image)

    # Ensure we have valid depth values (no NaNs or infs)
    depth = np.nan_to_num(depth)

    # Normalize to the 0-255 range for visualization
    depth_min = depth.min()
    depth_max = depth.max()

    if depth_max > depth_min:
        # Linear normalization
        depth_normalized = (depth - depth_min) / (depth_max - depth_min)
        # Apply slight gamma correction to enhance visibility
        depth_normalized = np.power(depth_normalized, 0.8)
        # Scale to the 0-255 range
        depth_map = (depth_normalized * 255).astype(np.uint8)
    else:
        depth_map = np.zeros_like(depth, dtype=np.uint8)

    # Invert if requested (after normalization)
    if invert_depth:
        depth_map = 255 - depth_map

    return depth_map
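# The gamma of 0.8 used above lifts mid-range values: for normalized depths in
# [0, 1], x ** 0.8 >= x, so mid tones brighten while 0 and 1 stay fixed.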
def apply_colormap(depth, colormap=cv2.COLORMAP_TURBO, reverse_colormap=False):
    """Apply a colormap to the depth image"""
    # Ensure the input is a numpy array
    if not isinstance(depth, np.ndarray):
        depth = np.array(depth)

    # Ensure a single channel
    if len(depth.shape) > 2:
        depth = cv2.cvtColor(depth, cv2.COLOR_RGB2GRAY)

    # Reverse the depth values if requested
    if reverse_colormap:
        depth = 255 - depth

    # Apply the colormap
    colored = cv2.applyColorMap(depth, colormap)

    # Convert BGR to RGB
    colored_rgb = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    return colored_rgb

def resize_image(image, max_size=1200):
    """Resize a PIL image so its longest side does not exceed max_size"""
    if max(image.size) > max_size:
        ratio = max_size / max(image.size)
        new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
        image = image.resize(new_size, Image.LANCZOS)
    return image
def save_image(image, path):
    """Save a PIL Image to the specified path"""
    image.save(path, format="PNG")

def add_to_live_preview(original_image, depth_image):
    """Add a processed image pair to the live preview queue"""
    global live_preview_queue
    with live_preview_lock:
        # Keep only the most recent 10 pairs
        if len(live_preview_queue) >= 10:
            live_preview_queue.pop(0)
        live_preview_queue.append([original_image, depth_image])

def get_live_preview():
    """Get a copy of the current live preview images"""
    global live_preview_queue
    with live_preview_lock:
        return live_preview_queue.copy()

class ProcessProgressTracker:
    """Track progress of image processing"""
    def __init__(self, total_files):
        self.total_files = total_files
        self.processed_files = 0
        self.start_time = time.time()
        self.lock = threading.Lock()

    def update(self):
        with self.lock:
            self.processed_files += 1
            elapsed = time.time() - self.start_time
            files_per_sec = self.processed_files / elapsed if elapsed > 0 else 0
            eta = (self.total_files - self.processed_files) / files_per_sec if files_per_sec > 0 else 0

            # Only print status every 5 files or at completion
            if self.processed_files % 5 == 0 or self.processed_files == self.total_files:
                print(f"Processed {self.processed_files}/{self.total_files} images " +
                      f"({self.processed_files/self.total_files*100:.1f}%) " +
                      f"- {files_per_sec:.2f} imgs/sec - ETA: {timedelta(seconds=int(eta))}")

            return self.processed_files, self.total_files
def process_image(args):
    """Process a single image (arguments packed in a tuple for multi-threading)"""
    (filename, folder_path, temp_dir, output_dir, encoder, progress_tracker,
     invert_depth, colormap, enable_blending, blend_opacity, make_base_bw,
     depth_on_top, use_colormap, reverse_colormap) = args

    try:
        image_path = os.path.join(folder_path, filename)

        # Define output paths
        temp_image_path = os.path.join(temp_dir, filename)
        output_image_path = os.path.join(output_dir, filename) if output_dir else None

        # Load and resize the image
        image = Image.open(image_path).convert('RGB')
        image = resize_image(image)
        image_np = np.array(image)

        # Generate the depth map
        depth_map = predict_depth(image_np, encoder, invert_depth)

        # Handle colormap and depth visualization
        if use_colormap:
            final_output = apply_colormap(depth_map, colormap, reverse_colormap)
        else:
            final_output = cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB)

        # Handle blending if enabled
        if enable_blending:
            final_output = blend_images(
                image_np,
                final_output,
                opacity=blend_opacity,
                make_bw=make_base_bw,
                depth_on_top=depth_on_top
            )

        final_output = Image.fromarray(final_output)

        # Create the depth filename
        base, ext = os.path.splitext(filename)
        depth_filename = f"{base}_depth{ext}"

        # Save to the temp dir
        temp_depth_path = os.path.join(temp_dir, depth_filename)
        save_image(Image.fromarray(image_np), temp_image_path)
        save_image(final_output, temp_depth_path)

        # Save to the output dir if specified
        if output_dir:
            output_depth_path = os.path.join(output_dir, depth_filename)
            save_image(Image.fromarray(image_np), output_image_path)
            save_image(final_output, output_depth_path)

        # Update the live preview
        add_to_live_preview(Image.fromarray(image_np), final_output)

        # Update progress
        progress_tracker.update()

        return temp_image_path, temp_depth_path
    except Exception as e:
        print(f"ERROR processing image {filename}: {e}")
        traceback.print_exc()
        return None, None
def process_images(folder_path, encoder, output_dir=None, max_workers=1, invert_depth=False,
                   colormap=cv2.COLORMAP_TURBO, enable_blending=False, blend_opacity=0.0,
                   make_base_bw=False, depth_on_top=True, use_colormap=True, reverse_colormap=False):
    """Process all images in the folder and generate depth maps"""
    images = []
    depth_maps = []
    temp_dir = tempfile.mkdtemp()

    # Create the output directory if specified
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # Clear the previous live preview
    global live_preview_queue
    with live_preview_lock:
        live_preview_queue = []

    # Validate the folder path
    print(f"Checking folder: {folder_path}")
    if not os.path.exists(folder_path):
        print(f"ERROR: Folder path does not exist: {folder_path}")
        return images, depth_maps, temp_dir

    if not os.path.isdir(folder_path):
        print(f"ERROR: Path is not a directory: {folder_path}")
        return images, depth_maps, temp_dir

    # List files and check for images
    try:
        all_files = os.listdir(folder_path)
        print(f"Found {len(all_files)} items in folder")

        # Collect image files, excluding previously generated depth maps
        image_files = [f for f in all_files
                       if f.lower().endswith(('.png', '.jpg', '.jpeg'))
                       and not f.lower().endswith(('_depth.png', '_depth.jpg', '_depth.jpeg'))]

        print(f"Found {len(image_files)} original image files (excluding depth maps)")

        if len(image_files) == 0:
            print("WARNING: No valid image files found in the specified folder")
            print("Allowed extensions are: .png, .jpg, .jpeg")
            # Print the first 10 files to help debugging
            if all_files:
                print("First 10 files in directory:")
                for f in all_files[:10]:
                    print(f"  - {f}")
            return images, depth_maps, temp_dir

    except Exception as e:
        print(f"ERROR accessing folder: {e}")
        return images, depth_maps, temp_dir

    # Set up progress tracking
    progress_tracker = ProcessProgressTracker(len(image_files))

    # Process images in parallel when a GPU is available and more than one worker is requested
    if DEVICE == 'cuda' and max_workers > 1:
        print(f"Processing images with {max_workers} workers...")

        process_args = [(
            filename, folder_path, temp_dir, output_dir, encoder,
            progress_tracker, invert_depth, colormap, enable_blending,
            blend_opacity, make_base_bw, depth_on_top, use_colormap, reverse_colormap
        ) for filename in image_files]

        # Use a ThreadPoolExecutor for parallel processing
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            results = list(executor.map(process_image, process_args))

        # Filter out any None results from errors
        valid_results = [(img, depth) for img, depth in results if img is not None]

        if valid_results:
            images, depth_maps = zip(*valid_results)
            images = list(images)
            depth_maps = list(depth_maps)
    else:
        # Process sequentially; the argument tuple must match the parallel path exactly
        print("Processing images sequentially...")
        for filename in image_files:
            result = process_image((filename, folder_path, temp_dir, output_dir, encoder,
                                    progress_tracker, invert_depth, colormap, enable_blending,
                                    blend_opacity, make_base_bw, depth_on_top, use_colormap,
                                    reverse_colormap))
            if result[0] is not None:
                images.append(result[0])
                depth_maps.append(result[1])

    print(f"Successfully processed {len(images)} images")
    return images, depth_maps, temp_dir
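# With max_workers > 1 the threads all share the single cached model on the GPU,
# so the speedup comes mainly from overlapping image decoding and file I/O with
# inference rather than from running several models at once.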
def exponential_backoff(retry_count, base_wait=30):
    """Calculate a wait time with exponential backoff and jitter"""
    wait_time = min(base_wait * (2 ** retry_count), 3600)  # Cap at 1 hour
    jitter = random.uniform(0.8, 1.2)  # Add +/-20% jitter
    return wait_time * jitter
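# Note: exponential_backoff is currently unused; safe_upload_file below waits on
# a linear 5/10/15/20/25-minute schedule instead.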
def safe_upload_file(api, path_or_fileobj, path_in_repo, repo_id, token, max_retries=5):
    """Upload a file with retry logic for rate limiting"""
    retry_count = 0

    while retry_count < max_retries:
        try:
            api.upload_file(
                path_or_fileobj=path_or_fileobj,
                path_in_repo=path_in_repo,
                repo_id=repo_id,
                token=token,
                repo_type="dataset"
            )
            return True
        except Exception as e:
            error_str = str(e)
            if "429" in error_str and "rate-limited" in error_str:
                # Progressive backoff strategy: wait longer with each retry
                wait_time = (5 + retry_count * 5) * 60  # 5, 10, 15, 20, 25 minutes

                retry_count += 1
                print(f"Rate limited! Waiting for {wait_time/60:.1f} minutes before retry {retry_count}/{max_retries}")
                time.sleep(wait_time)
            else:
                # For non-rate-limit errors, re-raise the exception
                print(f"Error uploading file: {e}")
                raise

    print(f"Failed to upload after {max_retries} retries: {path_in_repo}")
    return False
def create_resume_file(resume_dir, all_files, start_idx, repo_id):
    """Create a resume file so interrupted uploads can be continued later"""
    os.makedirs(resume_dir, exist_ok=True)
    resume_path = os.path.join(resume_dir, f"resume_{repo_id.replace('/', '_')}.txt")

    with open(resume_path, "w") as f:
        # Header format: current_index,total_files,datetime
        f.write(f"{start_idx},{len(all_files)},{datetime.now().isoformat()}\n")

        # Write the remaining files to upload
        for idx in range(start_idx, len(all_files)):
            file_path, file_name, file_type = all_files[idx]
            f.write(f"{file_path}|{file_name}|{file_type}\n")

    return resume_path
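# Illustrative resume-file contents (paths are hypothetical):
#   42,200,2025-01-01T12:00:00
#   /tmp/tmpabc123/img_001.png|img_001.png|pair_1_image
#   /tmp/tmpabc123/img_001_depth.png|img_001_depth.png|pair_1_depth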
def upload_to_hf(images, depth_maps, repo_id, break_every=100, resume_dir="upload_resume", resume_file=None):
    """Upload images and depth maps to the Hugging Face Hub, pausing at regular intervals"""
    api = HfApi()
    token = HfFolder.get_token()

    # Combined list of files to upload
    all_files = []
    start_idx = 0

    # If resuming from a file, read the list of files still to upload
    if resume_file and os.path.exists(resume_file):
        print(f"Resuming upload from {resume_file}")
        with open(resume_file, "r") as f:
            lines = f.readlines()
            header = lines[0].strip().split(",")
            start_idx = int(header[0])

            # Read the file entries
            for line in lines[1:]:
                parts = line.strip().split("|")
                if len(parts) == 3:
                    all_files.append((parts[0], parts[1], parts[2]))

        print(f"Resuming upload from index {start_idx}, {len(all_files)} files remaining")
    else:
        # Create a new file list
        for i, (image_path, depth_map_path) in enumerate(zip(images, depth_maps)):
            all_files.append((image_path, os.path.basename(image_path), f"pair_{i+1}_image"))
            all_files.append((depth_map_path, os.path.basename(depth_map_path), f"pair_{i+1}_depth"))

    total_files = len(all_files)

    # Validate the break interval
    if break_every <= 0:
        break_every = 100

    # Create the resume file
    resume_path = create_resume_file(resume_dir, all_files, start_idx, repo_id)
    print(f"Created resume file: {resume_path}")
    print("If the upload is interrupted, you can resume using this path in the UI")

    # Ensure the repository exists and is of type 'dataset'
    try:
        api.repo_info(repo_id=repo_id, token=token)
    except Exception:
        try:
            create_repo(repo_id=repo_id, repo_type="dataset", token=token)
        except Exception as create_e:
            if "You already created this dataset repo" not in str(create_e):
                raise create_e

    print(f"Beginning upload of {total_files} files (starting at {start_idx+1})")
    print(f"Will take a 3-minute break after every {break_every} files to avoid rate limiting")

    # Track upload metrics
    upload_start_time = time.time()
    success_count = 0

    # Create the progress bar
    progress_bar = tqdm(total=total_files, initial=start_idx, desc="Uploading",
                        unit="files", dynamic_ncols=True)

    try:
        # Process files with periodic breaks
        for idx in range(start_idx, total_files):
            file_path, file_name, file_type = all_files[idx]

            # Take a break every break_every files (but not at the start)
            if idx > start_idx and (idx - start_idx) % break_every == 0:
                break_minutes = 3

                # Take a longer break around a known problematic threshold
                if idx >= 125 and idx < 130:
                    break_minutes = 15
                    tqdm.write("===== EXTENDED RATE LIMIT PREVENTION BREAK =====")
                    tqdm.write(f"Approaching critical threshold (files 125-130). Taking a longer {break_minutes}-minute break...")
                else:
                    tqdm.write("===== RATE LIMIT PREVENTION BREAK =====")
                    tqdm.write(f"Uploaded {break_every} files. Taking a {break_minutes}-minute break...")

                create_resume_file(resume_dir, all_files, idx, repo_id)

                # Show a countdown timer for the break
                for remaining in range(break_minutes * 60, 0, -10):
                    mins = remaining // 60
                    secs = remaining % 60
                    tqdm.write(f"Resuming in {mins}m {secs}s...")
                    time.sleep(10)

                tqdm.write("Break finished, continuing uploads...")

            # Upload the file
            tqdm.write(f"Uploading file {idx+1}/{total_files}: {file_name}")
            success = safe_upload_file(api, file_path, file_name, repo_id, token)

            if not success:
                tqdm.write(f"Failed to upload {file_name} after multiple retries.")
                # Update the resume file with the current position
                create_resume_file(resume_dir, all_files, idx, repo_id)
                return False

            # Update tracking
            success_count += 1
            progress_bar.update(1)

            # Update the resume file every 10 uploads
            if (idx + 1) % 10 == 0:
                create_resume_file(resume_dir, all_files, idx + 1, repo_id)

    except KeyboardInterrupt:
        print("\nUpload interrupted! Creating resume file to continue later...")
        create_resume_file(resume_dir, all_files, idx, repo_id)
        return False

    finally:
        progress_bar.close()

    # Calculate stats
    total_time = time.time() - upload_start_time
    files_per_second = success_count / total_time if total_time > 0 else 0

    print(f"\nUpload completed! {success_count} files uploaded in {timedelta(seconds=int(total_time))}")
    print(f"Average upload rate: {files_per_second:.2f} files/sec")

    return True
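# Rough cost of the breaks: with break_every=100, a 1000-file upload pauses 9
# times for 3 minutes each, adding about 27 minutes (more if a break falls in
# the extended 125-130 window).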
def process_and_upload(folder_path, model_name, invert_depth, colormap_name, output_dir,
                       upload_to_hf_toggle, repo_id, break_every=100, parallel_workers=1,
                       resume_file=None, enable_blending=False, blend_opacity=0.0,
                       make_base_bw=False, depth_on_top=True, use_colormap=True, reverse_colormap=False):
    """Process images and upload them to Hugging Face or save them locally"""
    encoder = name2encoder[model_name]
    colormap = get_colormap_by_name(colormap_name)

    # If a resume file is provided, only upload (skip processing)
    if resume_file and os.path.exists(resume_file) and upload_to_hf_toggle:
        print(f"Resuming upload from file: {resume_file}")
        success = upload_to_hf([], [], repo_id, break_every=break_every, resume_file=resume_file)
        return "Resume upload completed successfully" if success else "Resume upload was interrupted or failed"

    # Process the images
    images, depth_maps, temp_dir = process_images(
        folder_path,
        encoder,
        output_dir=output_dir,
        max_workers=parallel_workers,
        invert_depth=invert_depth,
        colormap=colormap,
        enable_blending=enable_blending,
        blend_opacity=blend_opacity,
        make_base_bw=make_base_bw,
        depth_on_top=depth_on_top,
        use_colormap=use_colormap,
        reverse_colormap=reverse_colormap
    )

    if not images:
        return "No images were processed. Check the logs for details."

    # Upload to HF if selected
    if upload_to_hf_toggle and repo_id:
        success = upload_to_hf(images, depth_maps, repo_id, break_every=break_every)
        upload_status = f"Upload {'completed successfully' if success else 'was interrupted or failed'}. "
    else:
        upload_status = ""

    # Local save status
    if output_dir:
        local_status = f"Images and depth maps saved to {output_dir}. "
    else:
        local_status = ""

    # Clean up the temp directory
    try:
        shutil.rmtree(temp_dir)
    except Exception as e:
        print(f"Warning: Could not clean up temp directory: {e}")

    return f"{local_status}{upload_status}Successfully processed {len(images)} images."
def colormap_list():
    """Get the list of supported OpenCV colormaps"""
    return [
        "TURBO", "JET", "PARULA", "HOT", "WINTER", "RAINBOW",
        "OCEAN", "SUMMER", "SPRING", "COOL", "HSV",
        "PINK", "BONE", "VIRIDIS", "PLASMA", "INFERNO"
    ]

def get_colormap_by_name(name):
    """Convert a colormap name to the corresponding OpenCV enum"""
    colormap_mapping = {
        "TURBO": cv2.COLORMAP_TURBO,
        "JET": cv2.COLORMAP_JET,
        "PARULA": cv2.COLORMAP_PARULA,
        "HOT": cv2.COLORMAP_HOT,
        "WINTER": cv2.COLORMAP_WINTER,
        "RAINBOW": cv2.COLORMAP_RAINBOW,
        "OCEAN": cv2.COLORMAP_OCEAN,
        "SUMMER": cv2.COLORMAP_SUMMER,
        "SPRING": cv2.COLORMAP_SPRING,
        "COOL": cv2.COLORMAP_COOL,
        "HSV": cv2.COLORMAP_HSV,
        "PINK": cv2.COLORMAP_PINK,
        "BONE": cv2.COLORMAP_BONE,
        "VIRIDIS": cv2.COLORMAP_VIRIDIS,
        "PLASMA": cv2.COLORMAP_PLASMA,
        "INFERNO": cv2.COLORMAP_INFERNO
    }
    return colormap_mapping.get(name, cv2.COLORMAP_TURBO)
def visualize_process(folder_path, model_name, invert_depth, colormap_name,
                      blend_opacity=0.0, make_base_bw=False, depth_on_top=True, sample_count=10):
    """Process a random sample of images from the folder and visualize them"""
    encoder = name2encoder[model_name]
    colormap = get_colormap_by_name(colormap_name)

    # Validate the folder path
    if not os.path.exists(folder_path) or not os.path.isdir(folder_path):
        return []

    # Get the image files
    image_files = [f for f in os.listdir(folder_path)
                   if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    if not image_files:
        return []

    # Take a sample of images
    if len(image_files) > sample_count:
        image_files = random.sample(image_files, sample_count)

    # Process the sampled images
    temp_dir = tempfile.mkdtemp()
    visualization = []

    for filename in image_files:
        try:
            image_path = os.path.join(folder_path, filename)
            temp_image_path = os.path.join(temp_dir, filename)
            shutil.copy(image_path, temp_image_path)

            image = Image.open(temp_image_path).convert('RGB')
            image = resize_image(image)
            image_np = np.array(image)

            depth_map = predict_depth(image_np, encoder, invert_depth)
            depth_map_colored = apply_colormap(depth_map, colormap)

            # Apply the preview blend settings, if any
            if blend_opacity > 0.0:
                depth_map_colored = blend_images(image_np, depth_map_colored,
                                                 opacity=blend_opacity,
                                                 make_bw=make_base_bw,
                                                 depth_on_top=depth_on_top)

            depth_map_path = os.path.join(temp_dir, f"depth_{filename}")
            save_image(Image.fromarray(depth_map_colored), depth_map_path)

            visualization.append([image, Image.fromarray(depth_map_colored)])
            print(f"Previewed {filename}")
        except Exception as e:
            print(f"Error processing image for preview: {e}")

    # Clean up the temp directory
    try:
        shutil.rmtree(temp_dir)
    except OSError:
        pass

    return visualization

def update_live_preview():
    """Update the live preview gallery"""
    return get_live_preview()
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🩻 Enhanced Depth Map Generation 🩻")

    with gr.Tab("Generate Depth Maps"):
        folder_input = gr.Textbox(label="Folder Path", placeholder="Enter the path to the folder with images")

        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=["Small", "Base", "Large"],
                value="Small",
                label="Model Size (Small=Fastest, Large=Best Quality)"
            )

            parallel_workers = gr.Slider(
                minimum=1,
                maximum=8,
                value=1 if DEVICE == 'cpu' else 2,
                step=1,
                label="Parallel Workers (GPU only)"
            )

        with gr.Row():
            invert_depth = gr.Checkbox(label="Invert Depth Map", value=False)
            use_colormap = gr.Checkbox(label="Use Colormap", value=True)
            reverse_colormap = gr.Checkbox(label="Reverse Colormap", value=False)
            colormap_dropdown = gr.Dropdown(
                choices=colormap_list(),
                value="TURBO",
                label="Colormap Style",
                interactive=True
            )

        use_colormap.change(
            fn=lambda x: gr.update(visible=x),
            inputs=[use_colormap],
            outputs=colormap_dropdown
        )

        with gr.Accordion("Blending Options", open=False):
            with gr.Row():
                enable_blending = gr.Checkbox(
                    label="Enable Blending",
                    value=False,
                    info="Blend depth map with original image"
                )
                make_base_bw = gr.Checkbox(
                    label="Make Original B&W",
                    value=False,
                    visible=False
                )
                depth_on_top = gr.Checkbox(
                    label="Depth on Top",
                    value=True,
                    visible=False
                )

            with gr.Row():
                blend_opacity = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.1,
                    label="Blend Strength",
                    info="0 = Original only, 1 = Depth only",
                    visible=False
                )

            enable_blending.change(
                fn=lambda x: {
                    make_base_bw: gr.update(visible=x),
                    depth_on_top: gr.update(visible=x),
                    blend_opacity: gr.update(visible=x)
                },
                inputs=[enable_blending],
                outputs=[make_base_bw, depth_on_top, blend_opacity]
            )

        with gr.Row():
            output_dir = gr.Textbox(
                label="Local Output Directory (Optional)",
                placeholder="Leave empty to not save locally, or enter a path to save files"
            )

        with gr.Row():
            upload_to_hf_toggle = gr.Checkbox(label="Upload to Hugging Face", value=True)
            repo_id_input = gr.Textbox(
                label="Hugging Face Repo ID",
                placeholder="username/repo-name",
                interactive=True
            )

        with gr.Row():
            break_every_input = gr.Slider(
                minimum=50,
                maximum=200,
                value=100,
                step=10,
                label="Break Interval (for HF upload)"
            )

        resume_file = gr.Textbox(
            label="Resume File (Optional)",
            placeholder="Leave empty for new uploads, or provide a path to a resume file"
        )

        process_button = gr.Button("Process Images", variant="primary")
        output = gr.Textbox(label="Output")

        # Live preview gallery
        gr.Markdown("### Live Processing Preview")
        live_preview = gr.Gallery(label="Processing Progress", columns=2, height=400)
        refresh_button = gr.Button("Refresh Preview")

    with gr.Tab("Preview"):
        with gr.Row():
            preview_folder = gr.Textbox(label="Folder Path", placeholder="Enter the path to preview images from")
            preview_model = gr.Dropdown(
                choices=["Small", "Base", "Large"],
                value="Small",
                label="Model Size"
            )

        with gr.Row():
            preview_invert = gr.Checkbox(label="Invert Depth Map", value=False)
            preview_colormap = gr.Dropdown(
                choices=colormap_list(),
                value="TURBO",
                label="Colormap Style"
            )

        with gr.Row():
            preview_blend_opacity = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.0,
                step=0.1,
                label="Preview Blend Opacity"
            )
            preview_make_bw = gr.Checkbox(
                label="Make Base Image Black & White",
                value=False
            )
            preview_depth_on_top = gr.Checkbox(
                label="Depth Map on Top",
                value=True
            )

        visualize_button = gr.Button("Generate Preview", variant="secondary")
        preview_output = gr.Gallery(label="Sample Depth Maps", columns=2, height=600)
    with gr.Tab("Help"):
        gr.Markdown("""
        ## Usage Instructions

        ### Generate Depth Maps Tab
        1. **Folder Path**: Enter the full path to the folder containing your images (PNG, JPG, JPEG)
        2. **Model Size**:
           - Small: Fastest processing but lowest quality
           - Base: Good balance between speed and quality
           - Large: Best quality but slowest processing
        3. **Parallel Workers**: How many images to process simultaneously (GPU only)
        4. **Invert Depth Map**: Toggle to invert the depth values (far objects bright, near objects dark)
        5. **Colormap Style**: Choose from various color schemes for the depth visualization
        6. **Local Output Directory**: Path where processed images should be saved locally
        7. **Upload to Hugging Face**: Toggle whether to upload to the Hugging Face Hub
        8. **HF Repo ID**: Your Hugging Face username and repository name (e.g., `username/dataset-name`)
        9. **Break Interval**: The script takes a 3-minute break after uploading this many files
        10. **Resume File**: If an upload was interrupted, provide the resume file path here

        ### Live Preview
        - During processing, a live preview shows the most recently processed images
        - Click "Refresh Preview" to update the display

        ### Preview Tab
        Quickly preview what the depth maps will look like without uploading anything.

        ### Important Notes
        - Processing is much faster with a GPU
        - If saving locally, original images are saved unchanged and depth maps get a `_depth` suffix
        - When uploading to Hugging Face, the script pauses periodically to avoid rate limits
        """)
    # Define event handlers
    def toggle_hf_fields(upload_enabled):
        return {
            repo_id_input: gr.update(interactive=upload_enabled),
            break_every_input: gr.update(interactive=upload_enabled),
            resume_file: gr.update(interactive=upload_enabled)
        }

    # Connect interactive elements
    upload_to_hf_toggle.change(
        fn=toggle_hf_fields,
        inputs=upload_to_hf_toggle,
        outputs=[repo_id_input, break_every_input, resume_file]
    )

    # Connect buttons to functions
    process_button.click(
        fn=process_and_upload,
        inputs=[
            folder_input, model_dropdown, invert_depth, colormap_dropdown,
            output_dir, upload_to_hf_toggle, repo_id_input,
            break_every_input, parallel_workers, resume_file,
            enable_blending, blend_opacity, make_base_bw, depth_on_top,
            use_colormap, reverse_colormap
        ],
        outputs=output
    )

    refresh_button.click(
        fn=update_live_preview,
        inputs=[],
        outputs=live_preview
    )

    visualize_button.click(
        fn=visualize_process,
        inputs=[preview_folder, preview_model, preview_invert, preview_colormap,
                preview_blend_opacity, preview_make_bw, preview_depth_on_top],
        outputs=preview_output
    )

    # Initialize the live preview as an empty gallery
    demo.load(lambda: [], None, live_preview)

if __name__ == "__main__":
    demo.launch()
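# Run with: python src/training/generateDataDepthCycleGAN.py
# Gradio serves the UI at http://127.0.0.1:7860 by default; passing
# share=True to demo.launch() creates a temporary public link.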