Borcherding commited on
Commit
427f697
·
verified ·
1 Parent(s): 2fe8c1c

Upload generateDataDepthCycleGAN.py

Browse files
src/training/generateDataDepthCycleGAN.py ADDED
@@ -0,0 +1,977 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import os
6
+ import sys
7
+ import shutil
8
+ import tempfile
9
+ from PIL import Image
10
+ from huggingface_hub import HfApi, HfFolder, hf_hub_download, create_repo
11
+ import time
12
+ import random
13
+ import threading
14
+ from datetime import datetime, timedelta
15
+ from tqdm import tqdm
16
+ import concurrent.futures
17
+ import traceback
18
+ import os
19
+ import sys
20
+
21
# Diagnostic dump: show every DEPTH-related environment variable so a
# misconfigured DEPTH_ANYTHING_V2_PATH is easy to spot in the logs.
print("All environment variables:")
for key, value in os.environ.items():
    if "DEPTH" in key:
        print(f"{key}: {value}")

# DEPTH_ANYTHING_V2_PATH should point at a checkout of the Depth-Anything-V2
# repo (the directory containing the `depth_anything_v2` package) — TODO confirm.
depth_path = os.getenv('DEPTH_ANYTHING_V2_PATH')
print(f"DEPTH_ANYTHING_V2_PATH value: {depth_path}")

# Fall back to this script's own directory when the variable is unset.
if depth_path is None:
    depth_anything_path = os.path.dirname(os.path.abspath(__file__))
    print(f"Environment variable not set. Using current directory: {depth_anything_path}")
else:
    depth_anything_path = depth_path
    print(f"Using environment variable path: {depth_anything_path}")

# Make the package importable, then import the model class; on failure,
# list the directory contents to aid debugging rather than crashing here
# (a later call to DepthAnythingV2 will still fail if the import failed).
sys.path.append(depth_anything_path)
try:
    from depth_anything_v2.dpt import DepthAnythingV2
    print("Successfully imported DepthAnythingV2")
except ImportError as e:
    print(f"Import error: {e}")
    print(f"Contents of directory: {os.listdir(depth_anything_path)}")
    if os.path.exists(os.path.join(depth_anything_path, 'depth_anything_v2')):
        print(f"Contents of depth_anything_v2: {os.listdir(os.path.join(depth_anything_path, 'depth_anything_v2'))}")
48
+
49
# ---------------------------------------------------------------------------
# Runtime configuration and shared module state
# ---------------------------------------------------------------------------

# Prefer the GPU when torch can see one; everything else runs on CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

# Architecture hyper-parameters for each DepthAnythingV2 encoder variant.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}

# Human-readable names shown in the UI, plus the reverse lookup.
encoder2name = {'vits': 'Small', 'vitb': 'Base', 'vitl': 'Large'}
name2encoder = {name: enc for enc, name in encoder2name.items()}

# HuggingFace Hub locations of the pretrained checkpoints; the repo and
# file names follow directly from the encoder's display name and key.
MODEL_INFO = {
    enc: {
        'repo_id': f'depth-anything/Depth-Anything-V2-{name}',
        'filename': f'depth_anything_v2_{enc}.pth',
    }
    for enc, name in encoder2name.items()
}

# Lazily-loaded model cache: only one model is resident at a time.
current_model = None
current_encoder = None

# Most recent (original, depth) pairs for the live preview gallery,
# guarded by a lock because worker threads append concurrently.
live_preview_queue = []
live_preview_lock = threading.Lock()
91
+
92
def download_model(encoder):
    """Return a local path to the checkpoint for *encoder*.

    A checkpoint already present under ``$DEPTH_ANYTHING_V2_PATH/checkpoints``
    is reused; otherwise the file is fetched from the HuggingFace Hub into
    the local ``checkpoints`` directory.
    """
    info = MODEL_INFO[encoder]

    # Reuse a pre-downloaded checkpoint when the env var points at one.
    depth_root = os.getenv('DEPTH_ANYTHING_V2_PATH')
    if depth_root:
        cached = os.path.join(depth_root, 'checkpoints', info['filename'])
        if os.path.exists(cached):
            print(f"Using existing model file: {cached}")
            return cached

    # Not cached locally: fetch from the Hub.
    return hf_hub_download(
        repo_id=info['repo_id'],
        filename=info['filename'],
        local_dir='checkpoints',
    )
112
+
113
def load_model(encoder):
    """Load (and cache) the DepthAnythingV2 model for *encoder*.

    Only one model is kept resident; requesting a different encoder
    downloads/loads it and replaces the cached one.
    """
    global current_model, current_encoder

    # Cache hit: same encoder as last call.
    if current_encoder == encoder:
        return current_model

    checkpoint = download_model(encoder)
    model = DepthAnythingV2(**model_configs[encoder])
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
    current_model = model.to(DEVICE).eval()
    current_encoder = encoder
    return current_model
123
+
124
def convert_to_bw(image):
    """Return a grayscale version of *image*, kept in 3-channel RGB form.

    Handles PIL images and numpy arrays; any other type is returned
    unchanged.
    """
    if isinstance(image, Image.Image):
        return image.convert('L').convert('RGB')
    if isinstance(image, np.ndarray):
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        return cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    return image
131
+
132
def blend_images(original, depth_colored, opacity=0.5, make_bw=False, depth_on_top=True):
    """Alpha-blend the original image with its colored depth map.

    opacity: 0.0 keeps only the original, 1.0 keeps only the depth map.
    depth_on_top: when True, *opacity* weights the depth map; otherwise it
    weights the original.
    Returns a uint8 numpy array (not a PIL image).
    """
    # Normalize inputs to numpy arrays.
    if isinstance(original, Image.Image):
        original = np.array(original)
    if isinstance(depth_colored, Image.Image):
        depth_colored = np.array(depth_colored)

    # Optionally desaturate the base image (kept as 3-channel RGB).
    if make_bw:
        original = cv2.cvtColor(cv2.cvtColor(original, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB)

    # Blend in float space to avoid uint8 overflow.
    base = original.astype(np.float32)
    overlay = depth_colored.astype(np.float32)

    if depth_on_top:
        mixed = base * (1 - opacity) + overlay * opacity
    else:
        mixed = base * opacity + overlay * (1 - opacity)

    # Clamp back into displayable byte range.
    return np.clip(mixed, 0, 255).astype(np.uint8)
161
+
162
@torch.inference_mode()
def predict_depth(image, encoder, invert_depth=False):
    """Run depth estimation and return a uint8 depth map (0-255).

    The raw prediction is min-max normalized, given a mild gamma lift
    (0.8) for visibility, scaled to bytes, and optionally inverted.
    """
    model = load_model(encoder)
    if model is None:
        raise ValueError(f"Model for encoder {encoder} could not be loaded.")

    # Model expects a numpy array.
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Raw prediction; strip NaN/inf values before normalizing.
    depth = np.nan_to_num(model.infer_image(image))

    lo, hi = depth.min(), depth.max()
    if hi > lo:
        # Min-max normalize, gamma-correct, then scale to 0-255.
        normalized = np.power((depth - lo) / (hi - lo), 0.8)
        depth_map = (normalized * 255).astype(np.uint8)
    else:
        # Degenerate (constant) prediction: emit an all-black map.
        depth_map = np.zeros_like(depth, dtype=np.uint8)

    # Inversion is applied after normalization.
    return 255 - depth_map if invert_depth else depth_map
198
+
199
def apply_colormap(depth, colormap=cv2.COLORMAP_TURBO, reverse_colormap=False):
    """Colorize a single-channel uint8 depth map and return it as RGB."""
    # Coerce to a numpy array without copying when already one.
    depth = np.asarray(depth)

    # Collapse multi-channel input down to a single channel.
    if depth.ndim > 2:
        depth = cv2.cvtColor(depth, cv2.COLOR_RGB2GRAY)

    # Flip the value range before colorizing when requested.
    if reverse_colormap:
        depth = 255 - depth

    # applyColorMap yields BGR; convert for RGB consumers.
    return cv2.cvtColor(cv2.applyColorMap(depth, colormap), cv2.COLOR_BGR2RGB)
220
+
221
def resize_image(image, max_size=1200):
    """Downscale *image* so its longest side is at most *max_size* pixels.

    Aspect ratio is preserved; images already small enough are returned
    untouched.
    """
    longest = max(image.size)
    if longest <= max_size:
        return image
    scale = max_size / longest
    target = (int(image.size[0] * scale), int(image.size[1] * scale))
    return image.resize(target, Image.LANCZOS)
228
+
229
def save_image(image, path):
    """Write a PIL image to *path*, always encoding as PNG."""
    image.save(path, format="PNG")
232
+
233
def add_to_live_preview(original_image, depth_image):
    """Push an (original, depth) pair onto the bounded live-preview queue.

    The queue holds at most the 10 most recent pairs; the oldest entry is
    dropped to make room. Thread-safe via live_preview_lock.
    """
    global live_preview_queue
    with live_preview_lock:
        if len(live_preview_queue) >= 10:
            live_preview_queue.pop(0)
        live_preview_queue.append([original_image, depth_image])
241
+
242
def get_live_preview():
    """Return a snapshot (shallow copy) of the live-preview queue."""
    global live_preview_queue
    with live_preview_lock:
        return list(live_preview_queue)
247
+
248
class ProcessProgressTracker:
    """Thread-safe counter that reports throughput and ETA during processing."""

    def __init__(self, total_files):
        self.total_files = total_files   # number of images expected
        self.processed_files = 0         # images finished so far
        self.start_time = time.time()    # basis for rate/ETA math
        self.lock = threading.Lock()     # update() may run from worker threads

    def update(self):
        """Record one finished image; return (processed, total)."""
        with self.lock:
            self.processed_files += 1
            elapsed = time.time() - self.start_time
            rate = self.processed_files / elapsed if elapsed > 0 else 0
            pending = self.total_files - self.processed_files
            eta = pending / rate if rate > 0 else 0

            # Report every 5th file and on completion to keep output sparse.
            if self.processed_files % 5 == 0 or self.processed_files == self.total_files:
                print(f"Processed {self.processed_files}/{self.total_files} images " +
                      f"({self.processed_files/self.total_files*100:.1f}%) " +
                      f"- {rate:.2f} imgs/sec - ETA: {timedelta(seconds=int(eta))}")

            return self.processed_files, self.total_files
270
+
271
def process_image(args):
    """Process one image for the worker pool: compute and save its depth map.

    *args* is a 14-tuple (packed this way so executor.map can pass it):
        (filename, folder_path, temp_dir, output_dir, encoder,
         progress_tracker, invert_depth, colormap, enable_blending,
         blend_opacity, make_base_bw, depth_on_top, use_colormap,
         reverse_colormap)

    Returns (temp_image_path, temp_depth_path), or (None, None) if
    processing failed (the error is logged, not raised).
    """
    (filename, folder_path, temp_dir, output_dir, encoder, progress_tracker,
     invert_depth, colormap, enable_blending, blend_opacity, make_base_bw,
     depth_on_top, use_colormap, reverse_colormap) = args

    try:
        image_path = os.path.join(folder_path, filename)

        # Output locations (output_dir is optional).
        temp_image_path = os.path.join(temp_dir, filename)
        output_image_path = os.path.join(output_dir, filename) if output_dir else None

        # Load and normalize the source image.
        image = Image.open(image_path).convert('RGB')
        image = resize_image(image)
        image_np = np.array(image)

        # Raw uint8 depth map (optionally inverted).
        depth_map = predict_depth(image_np, encoder, invert_depth)

        # Visualization: colormapped, or plain grayscale expanded to RGB.
        if use_colormap:
            final_output = apply_colormap(depth_map, colormap, reverse_colormap)
        else:
            final_output = cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB)

        # Optionally blend the visualization with the original image.
        if enable_blending:
            final_output = blend_images(
                image_np,
                final_output,
                opacity=blend_opacity,
                make_bw=make_base_bw,
                depth_on_top=depth_on_top
            )

        final_output = Image.fromarray(final_output)

        # Depth maps are stored next to the original as <name>_depth<ext>.
        base, ext = os.path.splitext(filename)
        depth_filename = f"{base}_depth{ext}"

        # Always save to the temp dir (used later for HF upload)...
        temp_depth_path = os.path.join(temp_dir, depth_filename)
        save_image(Image.fromarray(image_np), temp_image_path)
        save_image(final_output, temp_depth_path)

        # ...and mirror to the user-chosen output dir when given.
        if output_dir:
            output_depth_path = os.path.join(output_dir, depth_filename)
            save_image(Image.fromarray(image_np), output_image_path)
            save_image(final_output, output_depth_path)

        # Feed the UI's live preview gallery.
        add_to_live_preview(Image.fromarray(image_np), final_output)

        progress_tracker.update()

        return temp_image_path, temp_depth_path
    except Exception as e:
        # Bug fix: the log message previously printed a literal "(unknown)"
        # instead of the failing file's name.
        print(f"ERROR processing image {filename}: {e}")
        traceback.print_exc()
        return None, None
334
+
335
def process_images(folder_path, encoder, output_dir=None, max_workers=1, invert_depth=False,
                   colormap=cv2.COLORMAP_TURBO, enable_blending=False, blend_opacity=0.0,
                   make_base_bw=False, depth_on_top=True, use_colormap=True, reverse_colormap=False):
    """Process all images in the folder and generate depth maps.

    Returns (images, depth_maps, temp_dir): lists of saved file paths plus
    the scratch directory that holds them (caller is responsible for
    cleaning temp_dir up).
    """
    images = []
    depth_maps = []
    temp_dir = tempfile.mkdtemp()

    # Create output directory if specified.
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # Reset the live preview for this run.
    global live_preview_queue
    with live_preview_lock:
        live_preview_queue = []

    # Validate folder path.
    print(f"Checking folder: {folder_path}")
    if not os.path.exists(folder_path):
        print(f"ERROR: Folder path does not exist: {folder_path}")
        return images, depth_maps, temp_dir

    if not os.path.isdir(folder_path):
        print(f"ERROR: Path is not a directory: {folder_path}")
        return images, depth_maps, temp_dir

    # Collect candidate images, skipping previously generated depth maps.
    try:
        all_files = os.listdir(folder_path)
        print(f"Found {len(all_files)} items in folder")

        image_files = [f for f in all_files
                       if f.lower().endswith(('.png', '.jpg', '.jpeg'))
                       and not f.lower().endswith('_depth.png')
                       and not f.lower().endswith('_depth.jpg')
                       and not f.lower().endswith('_depth.jpeg')]

        print(f"Found {len(image_files)} original image files (excluding depth maps)")

        if len(image_files) == 0:
            print("WARNING: No valid image files found in the specified folder")
            print("Allowed extensions are: .png, .jpg, .jpeg")
            # Print a few entries to help debugging.
            if all_files:
                print("First 10 files in directory:")
                for f in all_files[:10]:
                    print(f" - {f}")
            return images, depth_maps, temp_dir

    except Exception as e:
        print(f"ERROR accessing folder: {e}")
        return images, depth_maps, temp_dir

    progress_tracker = ProcessProgressTracker(len(image_files))

    # One argument tuple per image, in exactly the order process_image
    # unpacks them (14 fields). Built once and shared by both paths so the
    # parallel and sequential branches can never drift apart again.
    process_args = [(
        filename, folder_path, temp_dir, output_dir, encoder,
        progress_tracker, invert_depth, colormap, enable_blending,
        blend_opacity, make_base_bw, depth_on_top, use_colormap, reverse_colormap
    ) for filename in image_files]

    if DEVICE == 'cuda' and max_workers > 1:
        # Parallel path (GPU only).
        print(f"Processing images with {max_workers} workers...")

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            results = list(executor.map(process_image, process_args))

        # Filter out any None results from errors.
        valid_results = [(img, depth) for img, depth in results if img is not None]

        if valid_results:
            images, depth_maps = zip(*valid_results)
            images = list(images)
            depth_maps = list(depth_maps)
    else:
        # Sequential path.
        # Bug fix: this call previously built a 13-item tuple (missing
        # reverse_colormap), so process_image's 14-way unpack raised
        # ValueError for every image processed sequentially.
        print("Processing images sequentially...")
        for args in process_args:
            result = process_image(args)
            if result[0] is not None:
                images.append(result[0])
                depth_maps.append(result[1])

    print(f"Successfully processed {len(images)} images")
    return images, depth_maps, temp_dir
426
+
427
def exponential_backoff(retry_count, base_wait=30):
    """Return a retry delay in seconds.

    The base wait doubles with each retry, is capped at one hour, and is
    multiplied by +/-20% random jitter to avoid thundering-herd retries.
    """
    capped = min(base_wait * (2 ** retry_count), 3600)
    return capped * random.uniform(0.8, 1.2)
432
+
433
def safe_upload_file(api, path_or_fileobj, path_in_repo, repo_id, token, max_retries=5):
    """Upload one file to a HF dataset repo, retrying on 429 rate limits.

    Waits progressively longer between attempts (5, 10, 15, ... minutes).
    Returns True on success, False once *max_retries* rate-limit retries
    are exhausted. Errors other than rate limiting are re-raised.
    """
    for attempt in range(max_retries):
        try:
            api.upload_file(
                path_or_fileobj=path_or_fileobj,
                path_in_repo=path_in_repo,
                repo_id=repo_id,
                token=token,
                repo_type="dataset"
            )
            return True
        except Exception as e:
            message = str(e)
            if "429" in message and "rate-limited" in message:
                # Linear backoff: 5, 10, 15, 20, 25 minutes.
                wait_time = (5 + attempt * 5) * 60
                print(f"Rate limited! Waiting for {wait_time/60:.1f} minutes before retry {attempt + 1}/{max_retries}")
                time.sleep(wait_time)
            else:
                # Non-rate-limit errors are not retried.
                print(f"Error uploading file: {e}")
                raise e

    print(f"Failed to upload after {max_retries} retries: {path_in_repo}")
    return False
463
+
464
def create_resume_file(resume_dir, all_files, start_idx, repo_id):
    """Write a resume manifest so an interrupted upload can be continued.

    The first line holds "start_idx,total,timestamp"; each following line
    is one pending upload as "path|name|type". Returns the manifest path.
    """
    os.makedirs(resume_dir, exist_ok=True)
    resume_path = os.path.join(resume_dir, f"resume_{repo_id.replace('/', '_')}.txt")

    with open(resume_path, "w") as handle:
        # Header: current position, total file count, and when it was written.
        handle.write(f"{start_idx},{len(all_files)},{datetime.now().isoformat()}\n")

        # Body: only the files not yet uploaded.
        for file_path, file_name, file_type in all_files[start_idx:]:
            handle.write(f"{file_path}|{file_name}|{file_type}\n")

    return resume_path
479
+
480
def upload_to_hf(images, depth_maps, repo_id, break_every=100, resume_dir="upload_resume", resume_file=None):
    """Upload images and depth maps to Hugging Face Hub with regular breaks.

    Files are uploaded one at a time with a 3-minute pause every
    *break_every* files to stay under HF rate limits. A resume manifest
    (see create_resume_file) is refreshed periodically so an interrupted
    run can be continued by passing its path as *resume_file*.
    Returns True on full success, False if interrupted or a file
    ultimately failed to upload.
    """
    api = HfApi()
    token = HfFolder.get_token()

    # Flat list of (local_path, name_in_repo, kind) entries to upload.
    all_files = []
    start_idx = 0

    if resume_file and os.path.exists(resume_file):
        print(f"Resuming upload from {resume_file}")
        with open(resume_file, "r") as f:
            lines = f.readlines()
        header = lines[0].strip().split(",")
        resumed_from = int(header[0])

        # Read the pending file entries (malformed lines are skipped).
        for line in lines[1:]:
            parts = line.strip().split("|")
            if len(parts) == 3:
                all_files.append((parts[0], parts[1], parts[2]))

        print(f"Resuming upload from index {resumed_from}, {len(all_files)} files remaining")

        # Bug fix: the manifest body already contains ONLY the files still
        # pending, so iteration must restart at 0 over that list. The old
        # code reused the stored absolute index here, double-skipping
        # already-uploaded files and truncating the tail of the list.
        start_idx = 0
    else:
        # Fresh run: interleave image/depth pairs in upload order.
        for i, (image_path, depth_map_path) in enumerate(zip(images, depth_maps)):
            all_files.append((image_path, os.path.basename(image_path), f"pair_{i+1}_image"))
            all_files.append((depth_map_path, os.path.basename(depth_map_path), f"pair_{i+1}_depth"))

    total_files = len(all_files)

    # Guard against a nonsensical break interval.
    if break_every <= 0:
        break_every = 100

    # Write the initial resume manifest up front.
    resume_path = create_resume_file(resume_dir, all_files, start_idx, repo_id)
    print(f"Created resume file: {resume_path}")
    print(f"If the upload is interrupted, you can resume using this path in the UI")

    # Ensure the repository exists and is of type 'dataset'.
    try:
        api.repo_info(repo_id=repo_id, token=token)
    except Exception:
        try:
            create_repo(repo_id=repo_id, repo_type="dataset", token=token)
        except Exception as create_e:
            if "You already created this dataset repo" not in str(create_e):
                raise create_e

    print(f"Beginning upload of {total_files} files (starting at {start_idx+1})")
    print(f"Will take a 3-minute break after every {break_every} files to avoid rate limiting")

    upload_start_time = time.time()
    success_count = 0

    progress_bar = tqdm(total=total_files, initial=start_idx, desc="Uploading",
                        unit="files", dynamic_ncols=True)

    # Bug fix: bind idx before the loop so the KeyboardInterrupt handler
    # never hits an unbound name if interrupted before the first iteration.
    idx = start_idx
    try:
        for idx in range(start_idx, total_files):
            file_path, file_name, file_type = all_files[idx]

            # Pause every break_every files (never at the very start).
            if idx > start_idx and (idx - start_idx) % break_every == 0:
                break_minutes = 3

                # NOTE(review): files ~125-130 are treated as a known-problem
                # window with a longer pause — presumably from observed HF
                # limiter behavior; confirm whether this is still needed.
                if idx >= 125 and idx < 130:
                    break_minutes = 15
                    tqdm.write(f"===== EXTENDED RATE LIMIT PREVENTION BREAK =====")
                    tqdm.write(f"Approaching critical threshold (files 125-130). Taking a longer {break_minutes}-minute break...")
                else:
                    tqdm.write(f"===== RATE LIMIT PREVENTION BREAK =====")
                    tqdm.write(f"Uploaded {break_every} files. Taking a {break_minutes}-minute break...")

                create_resume_file(resume_dir, all_files, idx, repo_id)

                # Countdown in 10-second steps so the user sees progress.
                for remaining in range(break_minutes * 60, 0, -10):
                    mins = remaining // 60
                    secs = remaining % 60
                    tqdm.write(f"Resuming in {mins}m {secs}s...")
                    time.sleep(10)

                tqdm.write("Break finished, continuing uploads...")

            tqdm.write(f"Uploading file {idx+1}/{total_files}: {file_name}")
            success = safe_upload_file(api, file_path, file_name, repo_id, token)

            if not success:
                tqdm.write(f"Failed to upload {file_name} after multiple retries.")
                # Persist current position so the run can be resumed.
                create_resume_file(resume_dir, all_files, idx, repo_id)
                return False

            success_count += 1
            progress_bar.update(1)

            # Refresh the manifest every 10 uploads.
            if (idx + 1) % 10 == 0:
                create_resume_file(resume_dir, all_files, idx + 1, repo_id)

    except KeyboardInterrupt:
        print("\nUpload interrupted! Creating resume file to continue later...")
        create_resume_file(resume_dir, all_files, idx, repo_id)
        return False

    finally:
        progress_bar.close()

    # Completion stats.
    total_time = time.time() - upload_start_time
    files_per_second = success_count / total_time if total_time > 0 else 0

    print(f"\nUpload completed! {success_count} files uploaded in {timedelta(seconds=int(total_time))}")
    print(f"Average upload rate: {files_per_second:.2f} files/sec")

    return True
606
+
607
def process_and_upload(folder_path, model_name, invert_depth, colormap_name, output_dir,
                       upload_to_hf_toggle, repo_id, break_every=100, parallel_workers=1,
                       resume_file=None, enable_blending=False, blend_opacity=0.0,
                       make_base_bw=False, depth_on_top=True, use_colormap=True, reverse_colormap=False):
    """Top-level pipeline: generate depth maps, then save locally and/or upload.

    Returns a human-readable status string for the UI.
    """
    encoder = name2encoder[model_name]
    colormap = get_colormap_by_name(colormap_name)

    # Resume mode: skip processing entirely and continue a previous upload.
    if resume_file and os.path.exists(resume_file) and upload_to_hf_toggle:
        print(f"Resuming upload from file: {resume_file}")
        success = upload_to_hf([], [], repo_id, break_every=break_every, resume_file=resume_file)
        return "Resume upload completed successfully" if success else "Resume upload was interrupted or failed"

    # Normal mode: run the full processing pass.
    images, depth_maps, temp_dir = process_images(
        folder_path,
        encoder,
        output_dir=output_dir,
        max_workers=parallel_workers,
        invert_depth=invert_depth,
        colormap=colormap,
        enable_blending=enable_blending,
        blend_opacity=blend_opacity,
        make_base_bw=make_base_bw,
        depth_on_top=depth_on_top,
        use_colormap=use_colormap,
        reverse_colormap=reverse_colormap
    )

    if not images:
        return "No images were processed. Check the logs for details."

    # Assemble the status message pieces.
    upload_status = ""
    if upload_to_hf_toggle and repo_id:
        success = upload_to_hf(images, depth_maps, repo_id, break_every=break_every)
        upload_status = f"Upload {'completed successfully' if success else 'was interrupted or failed'}. "

    local_status = f"Images and depth maps saved to {output_dir}. " if output_dir else ""

    # Best-effort removal of the scratch directory.
    try:
        shutil.rmtree(temp_dir)
    except Exception as e:
        print(f"Warning: Could not clean up temp directory: {e}")

    return f"{local_status}{upload_status}Successfully processed {len(images)} images."
660
+
661
def colormap_list():
    """Names of the OpenCV colormaps offered in the UI dropdowns."""
    return [
        "TURBO", "JET", "PARULA", "HOT", "WINTER", "RAINBOW",
        "OCEAN", "SUMMER", "SPRING", "COOL", "HSV",
        "PINK", "BONE", "VIRIDIS", "PLASMA", "INFERNO",
    ]
668
+
669
def get_colormap_by_name(name):
    """Map a colormap name from the UI to the OpenCV enum.

    Unknown names fall back to TURBO.
    """
    mapping = {
        "TURBO": cv2.COLORMAP_TURBO,
        "JET": cv2.COLORMAP_JET,
        "PARULA": cv2.COLORMAP_PARULA,
        "HOT": cv2.COLORMAP_HOT,
        "WINTER": cv2.COLORMAP_WINTER,
        "RAINBOW": cv2.COLORMAP_RAINBOW,
        "OCEAN": cv2.COLORMAP_OCEAN,
        "SUMMER": cv2.COLORMAP_SUMMER,
        "SPRING": cv2.COLORMAP_SPRING,
        "COOL": cv2.COLORMAP_COOL,
        "HSV": cv2.COLORMAP_HSV,
        "PINK": cv2.COLORMAP_PINK,
        "BONE": cv2.COLORMAP_BONE,
        "VIRIDIS": cv2.COLORMAP_VIRIDIS,
        "PLASMA": cv2.COLORMAP_PLASMA,
        "INFERNO": cv2.COLORMAP_INFERNO,
    }
    return mapping.get(name, cv2.COLORMAP_TURBO)
690
+
691
def visualize_process(folder_path, model_name, invert_depth, colormap_name,
                      blend_opacity=0.0, make_base_bw=False, depth_on_top=True, sample_count=10):
    """Process a random sample of images and return [original, depth] pairs.

    Used by the Preview tab. Returns an empty list when the folder is
    invalid or holds no images. blend_opacity/make_base_bw/depth_on_top are
    accepted for UI compatibility but the preview currently shows the
    un-blended colorized depth map.
    """
    encoder = name2encoder[model_name]
    colormap = get_colormap_by_name(colormap_name)

    # Validate folder path.
    if not os.path.exists(folder_path) or not os.path.isdir(folder_path):
        return []

    image_files = [f for f in os.listdir(folder_path)
                   if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    if not image_files:
        return []

    # Preview at most sample_count randomly chosen images.
    if len(image_files) > sample_count:
        image_files = random.sample(image_files, sample_count)

    temp_dir = tempfile.mkdtemp()
    visualization = []

    for filename in image_files:
        try:
            image_path = os.path.join(folder_path, filename)
            temp_image_path = os.path.join(temp_dir, filename)
            shutil.copy(image_path, temp_image_path)

            image = Image.open(temp_image_path).convert('RGB')
            image = resize_image(image)
            image_np = np.array(image)

            # Bug fix: predict_depth takes (image, encoder, invert_depth);
            # the extra blend arguments previously raised TypeError for
            # every previewed image (silently swallowed below).
            depth_map = predict_depth(image_np, encoder, invert_depth)
            depth_map_colored = apply_colormap(depth_map, colormap)

            # Bug fix: restore the filename in the saved name and log line
            # (they previously contained a literal "(unknown)").
            depth_map_path = os.path.join(temp_dir, f"depth_{filename}")
            save_image(Image.fromarray(depth_map_colored), depth_map_path)

            visualization.append([image, Image.fromarray(depth_map_colored)])
            print(f"Previewed {filename}")
        except Exception as e:
            print(f"Error processing image for preview: {e}")

    # Best-effort cleanup of the scratch directory.
    try:
        shutil.rmtree(temp_dir)
    except OSError:
        pass

    return visualization
744
+
745
def update_live_preview():
    """Gradio callback: return the current live-preview image pairs."""
    return get_live_preview()
748
+
749
+ # Create Gradio interface
750
+ with gr.Blocks() as demo:
751
+ gr.Markdown("# 🩻 Enhanced Depth Map Generation 🩻")
752
+
753
+ with gr.Tab("Generate Depth Maps"):
754
+ folder_input = gr.Textbox(label="Folder Path", placeholder="Enter the path to the folder with images")
755
+
756
+ with gr.Row():
757
+ model_dropdown = gr.Dropdown(
758
+ choices=["Small", "Base", "Large"],
759
+ value="Small",
760
+ label="Model Size (Small=Fastest, Large=Best Quality)"
761
+ )
762
+
763
+ parallel_workers = gr.Slider(
764
+ minimum=1,
765
+ maximum=8,
766
+ value=1 if DEVICE == 'cpu' else 2,
767
+ step=1,
768
+ label="Parallel Workers (GPU only)"
769
+ )
770
+
771
+ with gr.Row():
772
+ invert_depth = gr.Checkbox(label="Invert Depth Map", value=False)
773
+ use_colormap = gr.Checkbox(label="Use Colormap", value=True)
774
+ reverse_colormap = gr.Checkbox(label="Reverse Colormap", value=False)
775
+ colormap_dropdown = gr.Dropdown(
776
+ choices=colormap_list(),
777
+ value="TURBO",
778
+ label="Colormap Style",
779
+ interactive=True
780
+ )
781
+
782
+ use_colormap.change(
783
+ fn=lambda x: gr.update(visible=x),
784
+ inputs=[use_colormap],
785
+ outputs=colormap_dropdown
786
+ )
787
+
788
+ with gr.Accordion("Blending Options", open=False):
789
+ with gr.Row():
790
+ enable_blending = gr.Checkbox(
791
+ label="Enable Blending",
792
+ value=False,
793
+ info="Blend depth map with original image"
794
+ )
795
+ make_base_bw = gr.Checkbox(
796
+ label="Make Original B&W",
797
+ value=False,
798
+ visible=False
799
+ )
800
+ depth_on_top = gr.Checkbox(
801
+ label="Depth on Top",
802
+ value=True,
803
+ visible=False
804
+ )
805
+
806
+ with gr.Row():
807
+ blend_opacity = gr.Slider(
808
+ minimum=0.0,
809
+ maximum=1.0,
810
+ value=0.5,
811
+ step=0.1,
812
+ label="Blend Strength",
813
+ info="0 = Original only, 1 = Depth only",
814
+ visible=False
815
+ )
816
+
817
+ enable_blending.change(
818
+ fn=lambda x: {
819
+ make_base_bw: gr.update(visible=x),
820
+ depth_on_top: gr.update(visible=x),
821
+ blend_opacity: gr.update(visible=x)
822
+ },
823
+ inputs=[enable_blending],
824
+ outputs=[make_base_bw, depth_on_top, blend_opacity]
825
+ )
826
+
827
+ with gr.Row():
828
+ output_dir = gr.Textbox(
829
+ label="Local Output Directory (Optional)",
830
+ placeholder="Leave empty to not save locally, or enter path to save files"
831
+ )
832
+
833
+ with gr.Row():
834
+ upload_to_hf_toggle = gr.Checkbox(label="Upload to Hugging Face", value=True)
835
+ repo_id_input = gr.Textbox(
836
+ label="Hugging Face Repo ID",
837
+ placeholder="username/repo-name",
838
+ interactive=True
839
+ )
840
+
841
+ with gr.Row():
842
+ break_every_input = gr.Slider(
843
+ minimum=50,
844
+ maximum=200,
845
+ value=100,
846
+ step=10,
847
+ label="Break Interval (for HF upload)"
848
+ )
849
+
850
+ resume_file = gr.Textbox(
851
+ label="Resume File (Optional)",
852
+ placeholder="Leave empty for new uploads, or provide path to resume file"
853
+ )
854
+
855
+ process_button = gr.Button("Process Images", variant="primary")
856
+ output = gr.Textbox(label="Output")
857
+
858
+ # Live preview gallery
859
+ gr.Markdown("### Live Processing Preview")
860
+ live_preview = gr.Gallery(label="Processing Progress", columns=2, height=400)
861
+ refresh_button = gr.Button("Refresh Preview")
862
+
863
+ with gr.Tab("Preview"):
864
+ with gr.Row():
865
+ preview_folder = gr.Textbox(label="Folder Path", placeholder="Enter the path to preview images from")
866
+ preview_model = gr.Dropdown(
867
+ choices=["Small", "Base", "Large"],
868
+ value="Small",
869
+ label="Model Size"
870
+ )
871
+
872
+ with gr.Row():
873
+ preview_invert = gr.Checkbox(label="Invert Depth Map", value=False)
874
+ preview_colormap = gr.Dropdown(
875
+ choices=colormap_list(),
876
+ value="TURBO",
877
+ label="Colormap Style"
878
+ )
879
+
880
+ with gr.Row():
881
+ preview_blend_opacity = gr.Slider(
882
+ minimum=0.0,
883
+ maximum=1.0,
884
+ value=0.0,
885
+ step=0.1,
886
+ label="Preview Blend Opacity"
887
+ )
888
+ preview_make_bw = gr.Checkbox(
889
+ label="Make Base Image Black & White",
890
+ value=False
891
+ )
892
+ preview_depth_on_top = gr.Checkbox(
893
+ label="Depth Map on Top",
894
+ value=True
895
+ )
896
+
897
+ visualize_button = gr.Button("Generate Preview", variant="secondary")
898
+ preview_output = gr.Gallery(label="Sample Depth Maps", columns=2, height=600)
899
+
900
+ with gr.Tab("Help"):
901
+ gr.Markdown("""
902
+ ## Usage Instructions
903
+
904
+ ### Generate Depth Maps Tab
905
+ 1. **Folder Path**: Enter the full path to the folder containing your images (PNG, JPG, JPEG)
906
+ 2. **Model Size**:
907
+ - Small: Fastest processing but lowest quality
908
+ - Base: Good balance between speed and quality
909
+ - Large: Best quality but slowest processing
910
+ 3. **Parallel Workers**: How many images to process simultaneously (only works with GPU)
911
+ 4. **Invert Depth Map**: Toggle to invert the depth values (far objects bright, near objects dark)
912
+ 5. **Colormap Style**: Choose from various color schemes for the depth visualization
913
+ 6. **Local Output Directory**: Path where you want to save processed images locally
914
+ 7. **Upload to Hugging Face**: Toggle whether to upload to Hugging Face Hub
915
+ 8. **HF Repo ID**: Your Hugging Face username and repository name (e.g., `username/dataset-name`)
916
+ 9. **Break Interval**: The script will take a 3-minute break after uploading this many files
917
+ 10. **Resume File**: If your upload was interrupted, you can provide the resume file path here
918
+
919
+ ### Live Preview
920
+ - During processing, a live preview will show the most recent processed images
921
+ - Click "Refresh Preview" to update the display
922
+
923
+ ### Preview Tab
924
+ Quickly preview what the depth maps will look like without uploading anything.
925
+
926
+ ### Important Notes
927
+ - Processing is much faster with a GPU
928
+ - If saving locally, original images and depth maps will be saved with _depth suffix
929
+ - When uploading to Hugging Face, the script takes breaks to avoid rate limits
930
+ """)
931
+
932
+ # Define event handlers
933
+ def toggle_hf_fields(upload_enabled):
934
+ return {
935
+ repo_id_input: gr.update(interactive=upload_enabled),
936
+ break_every_input: gr.update(interactive=upload_enabled),
937
+ resume_file: gr.update(interactive=upload_enabled)
938
+ }
939
+
940
+ # Connect interactive elements
941
+ upload_to_hf_toggle.change(
942
+ fn=toggle_hf_fields,
943
+ inputs=upload_to_hf_toggle,
944
+ outputs=[repo_id_input, break_every_input, resume_file]
945
+ )
946
+
947
+ # Connect buttons to functions
948
+ process_button.click(
949
+ fn=process_and_upload,
950
+ inputs=[
951
+ folder_input, model_dropdown, invert_depth, colormap_dropdown,
952
+ output_dir, upload_to_hf_toggle, repo_id_input,
953
+ break_every_input, parallel_workers, resume_file,
954
+ enable_blending, blend_opacity, make_base_bw, depth_on_top,
955
+ use_colormap, reverse_colormap # Add reverse_colormap
956
+ ],
957
+ outputs=output
958
+ )
959
+
960
+ refresh_button.click(
961
+ fn=update_live_preview,
962
+ inputs=[],
963
+ outputs=live_preview
964
+ )
965
+
966
+ visualize_button.click(
967
+ fn=visualize_process,
968
+ inputs=[preview_folder, preview_model, preview_invert, preview_colormap,
969
+ preview_blend_opacity, preview_make_bw, preview_depth_on_top],
970
+ outputs=preview_output
971
+ )
972
+
973
+ # Set up the live preview - just initialize it
974
+ demo.load(lambda: [], None, live_preview)
975
+
976
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()