Borcherding committed on
Commit 6b4ae3a · verified · 1 parent: 661b808

Upload 3 files

src/inference/cycleGANtest.py ADDED
@@ -0,0 +1,226 @@
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import os
import numpy as np

# Generator architecture (simplified ResNet)
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(  # named 'conv_block' so the CycleGAN checkpoint keys line up
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_channels=3, output_channels=3, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]

        # Downsampling
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2

        # Output layer
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_channels, 7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

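# Sanity check (illustrative, not executed by the app): the generator is fully
# convolutional, so a normalized 256x256 RGB tensor maps back to a tensor of
# the same shape, with values in [-1, 1] from the final Tanh:
#     g = Generator()
#     y = g(torch.randn(1, 3, 256, 256))  # -> torch.Size([1, 3, 256, 256])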
# Image preprocessing
def preprocess_image(image_path):
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    return transform(image).unsqueeze(0)

# Image postprocessing
def postprocess_image(tensor):
    tensor = tensor.squeeze(0).cpu()
    tensor = (tensor + 1) / 2  # map [-1, 1] back to [0, 1]
    tensor = tensor.clamp(0, 1)
    tensor = tensor.permute(1, 2, 0).numpy()
    return (tensor * 255).astype(np.uint8)

# Model loading
def load_model(model_path):
    model = Generator()
    if os.path.exists(model_path):
        print(f"Loading model from {model_path}")
        state_dict = torch.load(model_path, map_location='cpu')
        try:
            model.load_state_dict(state_dict)
        except Exception as e:
            print(f"Warning: {e}")
            # Fall back to a non-strict load so renamed or missing keys are skipped
            model.load_state_dict(state_dict, strict=False)
            print("Loaded model with strict=False")
    else:
        print(f"Error: Model file not found at {model_path}")
        return None
    model.eval()
    return model

# Inference function; handles the numpy arrays delivered by Gradio
def transform_image(input_image, direction):
    if input_image is None:
        print("No input image provided")
        return None

    try:
        # Ensure the input image is RGB
        if len(input_image.shape) == 2:  # Grayscale
            input_image = np.stack([input_image] * 3, axis=-1)
        elif input_image.shape[-1] == 4:  # RGBA
            input_image = input_image[..., :3]

        if direction == "Depth to Image":
            model_path = "./checkpoints/depth2image/latest_net_G_A.pth"
        else:
            model_path = "./checkpoints/depth2image/latest_net_G_B.pth"

        # Load the model
        model = load_model(model_path)
        if model is None:
            print(f"Failed to load model from {model_path}")
            return None

        # Convert the numpy array to a PIL Image
        input_pil = Image.fromarray(input_image.astype('uint8'), 'RGB')

        # Create transforms
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        # Process the image
        input_tensor = transform(input_pil).unsqueeze(0)

        # Generate the output
        with torch.no_grad():
            output_tensor = model(input_tensor)

        # Convert to an image
        output_image = postprocess_image(output_tensor)

        return output_image

    except Exception as e:
        print(f"Error in transform_image: {e}")
        import traceback
        traceback.print_exc()
        return None

# Gradio interface
with gr.Blocks(title="CycleGAN Depth2Image Test", analytics_enabled=False) as demo:
    gr.Markdown("## Test CycleGAN Depth2Image Model")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type="numpy",
                height=256,
                width=256
            )
            direction = gr.Radio(
                choices=["Depth to Image", "Image to Depth"],
                value="Depth to Image",
                label="Conversion Direction"
            )
            transform_btn = gr.Button("Transform", variant="primary")

        with gr.Column():
            output_image = gr.Image(
                label="Generated Output",
                height=256,
                width=256
            )
            error_output = gr.Textbox(  # status placeholder; not currently wired to the transform handler
                label="Status",
                interactive=False
            )

    # Connect components
    transform_btn.click(
        fn=transform_image,
        inputs=[input_image, direction],
        outputs=output_image
    )

    gr.Markdown("""
    ### Instructions:
    1. Upload an image
    2. Select the conversion direction:
       - "Depth to Image" converts depth maps to realistic images
       - "Image to Depth" converts realistic images to depth maps
    3. Click "Transform" to generate the output

    Note: input images are resized so their shorter side is 256 pixels.
    """)

if __name__ == "__main__":
    # Make sure the checkpoints directory exists
    os.makedirs("checkpoints/depth2image", exist_ok=True)

    # Launch with a custom server configuration
    demo.queue(max_size=5).launch(
        server_name="0.0.0.0",  # Allow external connections
        server_port=7860,       # Set a specific port
        show_error=True,        # Show detailed errors
        debug=True              # Enable debug mode
    )
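The inference path can also be smoke-tested without the UI. A minimal sketch, assuming the script is importable as `cycleGANtest` (i.e. run from `src/inference/`) and the checkpoint files above are in place; importing builds the Blocks UI but does not launch it:

```python
import numpy as np
from cycleGANtest import transform_image

# Synthetic stand-in for an uploaded depth map: random RGB noise.
fake_depth = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)

result = transform_image(fake_depth, "Depth to Image")
if result is not None:
    print(result.shape, result.dtype)  # expected: (256, 256, 3) uint8
```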
src/inference/merged-discord-app.py ADDED
@@ -0,0 +1,1194 @@
import gradio as gr
import cv2
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as transforms
import sys
import os
import threading
import pyvirtualcam
from pyvirtualcam import PixelFormat
from huggingface_hub import hf_hub_download, login, upload_file
import torch.nn as nn
import time
import mss
import traceback

# Ensure required environment variables
depth_anything_path = os.getenv('DEPTH_ANYTHING_V2_PATH')
if depth_anything_path is None:
    raise ValueError("Environment variable DEPTH_ANYTHING_V2_PATH is not set. Please set it to the path of Depth-Anything-V2")
sys.path.append(depth_anything_path)
from depth_anything_v2.dpt import DepthAnythingV2

# --- Global variables and constants ---
# Colormaps matching the dataset generator
DEPTH_COLORMAPS = {
    "TURBO": cv2.COLORMAP_TURBO,
    "JET": cv2.COLORMAP_JET,
    "PARULA": cv2.COLORMAP_PARULA,
    "HOT": cv2.COLORMAP_HOT,
    "WINTER": cv2.COLORMAP_WINTER,
    "RAINBOW": cv2.COLORMAP_RAINBOW,
    "OCEAN": cv2.COLORMAP_OCEAN,
    "SUMMER": cv2.COLORMAP_SUMMER,
    "SPRING": cv2.COLORMAP_SPRING,
    "COOL": cv2.COLORMAP_COOL,
    "HSV": cv2.COLORMAP_HSV,
    "PINK": cv2.COLORMAP_PINK,
    "BONE": cv2.COLORMAP_BONE,
    "VIRIDIS": cv2.COLORMAP_VIRIDIS,
    "PLASMA": cv2.COLORMAP_PLASMA,
    "INFERNO": cv2.COLORMAP_INFERNO,
    "MAGMA": cv2.COLORMAP_MAGMA  # kept from the webcam app
}

# Globals that store the current settings
current_colormap = "TURBO"
current_mode = "Depth to Robot"
current_model_name = "Small"
current_webcam_id = 0
current_invert_depth = False
current_input_source = "Webcam"  # or "Desktop"
current_bypass_depth = False
current_blend_opacity = 0.1      # default opacity for blending
current_blend_enabled = False    # option to enable/disable blending
DEPTH2ROBOT_LOCAL_PATH = './checkpoints/depth2image/latest_net_G_A.pth'
current_gan_source = "Local"     # or "HuggingFace"
current_gan_input = None         # stores the current GAN input for display
current_direction = "Depth to Image"  # or "Image to Depth"

# --- Device selection ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

# Global variables for thread management
webcam_thread = None
stop_signal = False

# --- Depth-Anything-V2 Model Configurations ---
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
}

encoder2name = {
    'vits': 'Small',
    'vitb': 'Base',
    'vitl': 'Large'
}

# Model IDs and filenames for HuggingFace Hub
DEPTH_MODEL_INFO = {
    'vits': {
        'repo_id': 'depth-anything/Depth-Anything-V2-Small',
        'filename': 'depth_anything_v2_vits.pth'
    },
    'vitb': {
        'repo_id': 'depth-anything/Depth-Anything-V2-Base',
        'filename': 'depth_anything_v2_vitb.pth'
    },
    'vitl': {
        'repo_id': 'depth-anything/Depth-Anything-V2-Large',
        'filename': 'depth_anything_v2_vitl.pth'
    }
}

# --- CycleGAN Network Architecture ---
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type='reflect', norm_layer=nn.InstanceNorm2d, use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError(f'padding {padding_type} is not implemented')

        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p),
            norm_layer(dim),
            nn.ReLU(True)
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError(f'padding {padding_type} is not implemented')

        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p),
            norm_layer(dim)
        ]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9, norm_layer=nn.InstanceNorm2d):
        super(Generator, self).__init__()

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True)
        ]

        # Downsampling
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True)
            ]

        # Resnet blocks
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer)]

        # Upsampling
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                norm_layer(int(ngf * mult / 2)),
                nn.ReLU(True)
            ]

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


# --- Global variables for model management ---
current_depth_model = None
current_encoder = None
current_gan_model = None

# --- Model paths and HuggingFace configuration ---
DEPTH2ROBOT_MODEL_PATH = './checkpoints/depth2image/latest_net_G_A.pth'
DEPTH2ROBOT_HF_REPO = 'Borcherding/depth2AnythingCycleGAN_RobotsV2'  # replace with your HF username

def download_depth_model(encoder):
    """Download the specified depth model from HuggingFace Hub"""
    model_info = DEPTH_MODEL_INFO[encoder]
    model_path = hf_hub_download(
        repo_id=model_info['repo_id'],
        filename=model_info['filename'],
        local_dir='checkpoints'
    )
    return model_path

def load_depth_model(encoder):
    """Load the specified depth model, caching it between calls"""
    global current_depth_model, current_encoder
    if current_encoder != encoder:
        model_path = download_depth_model(encoder)
        current_depth_model = DepthAnythingV2(**model_configs[encoder])
        current_depth_model.load_state_dict(torch.load(model_path, map_location='cpu'))
        current_depth_model = current_depth_model.to(DEVICE).eval()
        current_encoder = encoder
    return current_depth_model

def apply_colormap(depth, colormap=cv2.COLORMAP_TURBO):
    """Apply a colormap to the depth image.

    COLORMAP_TURBO is the default because it has a wider color spectrum and
    better perceptual properties than COLORMAP_JET.
    """
    return cv2.applyColorMap(depth, colormap)

def load_gan_model():
    """Load the CycleGAN generator for the current direction.

    Returns the model on success and None on failure; the loaded model is
    also cached in the global current_gan_model so callers can invoke it
    directly.
    """
    global current_gan_model, current_direction

    try:
        print(f"\nLoading GAN model for direction: {current_direction}")

        # Select the correct model file
        if current_direction == "Depth to Image":
            model_path = './checkpoints/depth2image/latest_net_G_A.pth'
        else:
            model_path = './checkpoints/depth2image/latest_net_G_B.pth'

        print(f"Loading from: {os.path.abspath(model_path)}")

        if not os.path.exists(model_path):
            print(f"Model file not found: {model_path}")
            return None

        # Initialize the model
        current_gan_model = Generator().to(DEVICE)
        state_dict = torch.load(model_path, map_location=DEVICE)

        try:
            current_gan_model.load_state_dict(state_dict, strict=False)
            print("Model loaded successfully")
        except Exception as e:
            print(f"Error loading state dict: {e}")
            current_gan_model = None
            return None

        current_gan_model.eval()
        return current_gan_model

    except Exception as e:
        print(f"Error loading GAN: {e}")
        current_gan_model = None
        return None

def update_gan_source(source, path):
    """Update the GAN model source and path"""
    global current_gan_source, DEPTH2ROBOT_HF_REPO, current_gan_model, DEPTH2ROBOT_MODEL_PATH

    current_gan_source = source
    if source == "HuggingFace":
        DEPTH2ROBOT_HF_REPO = path
    else:  # Local
        DEPTH2ROBOT_MODEL_PATH = path  # update the model path globally

    # Force a reload of the GAN model
    current_gan_model = None

    # Test loading
    model = load_gan_model()
    if model is not None:
        return f"✅ Successfully updated GAN source to {source} using path: {path}"
    else:
        return "❌ Failed to load GAN model with new settings"

def toggle_invert_depth():
    """Toggle depth inversion without restarting the webcam"""
    global current_invert_depth

    if webcam_thread and webcam_thread.is_alive():
        current_invert_depth = not current_invert_depth
        orientation = "light=near, dark=far" if current_invert_depth else "dark=near, light=far"
        return f"✅ Depth colors swapped: {orientation}"
    else:
        return "⚠️ Webcam is not running. Please start it first."

def reverse_depth_colormap():
    """Reverse the depth colormap colors without restarting the webcam"""
    global current_invert_depth

    if webcam_thread and webcam_thread.is_alive():
        current_invert_depth = not current_invert_depth
        orientation = "dark=near, light=far" if current_invert_depth else "light=near, dark=far"
        return f"✅ Depth colors reversed: {orientation}"
    else:
        return "⚠️ Webcam is not running. Please start it first."

def blend_images(original, depth, opacity=0.1):
    """
    Blend the original image with the depth map.

    original: top layer (webcam/desktop)
    depth: bottom layer (depth map)
    opacity: 0.0 = depth only, 1.0 = original only
    """
    # Convert inputs to numpy arrays if needed
    if not isinstance(original, np.ndarray):
        original = np.array(original)
    if not isinstance(depth, np.ndarray):
        depth = np.array(depth)

    # Ensure both images are float32 for blending
    original = original.astype(np.float32)
    depth = depth.astype(np.float32)

    # Opacity is interpreted to match the UI slider
    # (0 = depth only, 1 = original/webcam only)
    blended = depth * (1 - opacity) + original * opacity

    # Clip values and convert back to uint8
    blended = np.clip(blended, 0, 255).astype(np.uint8)

    return blended

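# Worked example of the blend above (illustrative numbers, not app code):
# with opacity = 0.1, a depth pixel of 200 and an original pixel of 100 give
# 200 * (1 - 0.1) + 100 * 0.1 = 190, so the depth map dominates at low opacity.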
def toggle_blend_enabled():
    """Toggle blending without restarting the webcam"""
    global current_blend_enabled

    if webcam_thread and webcam_thread.is_alive():
        current_blend_enabled = not current_blend_enabled
        status = "enabled" if current_blend_enabled else "disabled"
        return f"✅ Image blending {status}"
    else:
        return "⚠️ Webcam is not running. Please start it first."

def update_blend_opacity(opacity):
    """Update the blend opacity without restarting the webcam"""
    global current_blend_opacity

    if webcam_thread and webcam_thread.is_alive():
        current_blend_opacity = opacity
        return f"✅ Updated blend opacity to {opacity:.1f}"
    else:
        return "⚠️ Webcam is not running. Please start it first."

@torch.inference_mode()
def predict_depth(image, encoder, invert_depth=None):
    """Predict depth using the selected model with pure output"""
    model = load_depth_model(encoder)
    depth = model.infer_image(image)

    # Linear normalization to the 0-255 range without enhancing contrast
    depth = depth - depth.min()
    max_val = depth.max()
    if max_val > 0:  # avoid division by zero
        depth = depth / max_val * 255.0

    # Convert to uint8 without any additional processing
    depth = depth.astype(np.uint8)

    # Simple inversion if requested
    if invert_depth:
        depth = 255 - depth

    return depth

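# Worked example of the normalization above (illustrative numbers, not app
# code): raw depths spanning [2.0, 10.0] map to (d - 2.0) / 8.0 * 255, so
# 2.0 -> 0, 6.0 -> 127 (after the uint8 cast), and 10.0 -> 255; the optional
# inversion then flips near and far.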
@torch.inference_mode()
def depth_to_robot(depth_image):
    """Convert a depth image to a robot image using CycleGAN"""
    try:
        model = load_gan_model()
        if model is None:
            print("No GAN model loaded!")
            return depth_image

        print(f"Input shape: {depth_image.shape}, dtype: {depth_image.dtype}")

        # Ensure the input is in the correct format
        if depth_image.dtype != np.uint8:
            depth_image = depth_image.astype(np.uint8)

        # Normalize to the [-1, 1] range for the GAN
        depth_tensor = torch.from_numpy(depth_image).float().permute(2, 0, 1).unsqueeze(0)
        depth_tensor = (depth_tensor / 127.5) - 1.0

        print(f"Tensor shape: {depth_tensor.shape}, device: {depth_tensor.device}")

        # Process through the GAN
        depth_tensor = depth_tensor.to(DEVICE)
        with torch.no_grad():
            robot_tensor = model(depth_tensor)

        print(f"Output tensor shape: {robot_tensor.shape}")

        # Convert back to an image (0-255 range)
        robot_tensor = (robot_tensor + 1.0) * 127.5
        robot_image = robot_tensor[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)

        return robot_image
    except Exception as e:
        print(f"Error in depth_to_robot: {e}")
        traceback.print_exc()
        return depth_image

def toggle_depth_bypass():
    """Toggle the depth map bypass"""
    global current_bypass_depth

    if webcam_thread and webcam_thread.is_alive():
        current_bypass_depth = not current_bypass_depth
        status = "enabled" if current_bypass_depth else "disabled"
        return f"✅ Depth bypass {status}"
    else:
        return "⚠️ Webcam is not running. Please start it first."

def process_frame(frame, encoder, use_gan=True, colormap="WINTER"):
    """Process a single frame, matching the test app's pattern"""
    global current_invert_depth, current_bypass_depth, current_blend_enabled
    global current_blend_opacity, current_gan_input, current_direction, current_gan_model

    try:
        # Convert the frame to RGB
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Load the GAN model if it is not loaded yet
        if use_gan and current_gan_model is None:
            if load_gan_model() is None:
                print("Failed to load GAN model, falling back to depth only")
                use_gan = False

        if current_direction == "Depth to Image":
            # First get the depth map
            depth = predict_depth(frame_rgb, encoder, invert_depth=current_invert_depth)

            # Apply a colormap to the depth
            selected_colormap = DEPTH_COLORMAPS.get(colormap, cv2.COLORMAP_WINTER)
            depth_colored = cv2.applyColorMap(depth, selected_colormap)

            # Apply blending if enabled
            if current_blend_enabled:
                depth_colored = blend_images(frame_rgb, depth_colored, current_blend_opacity)

            # Store the input we're sending to the GAN
            current_gan_input = depth_colored.copy()

            if use_gan and current_gan_model is not None:
                try:
                    # Convert to PIL and process like in the test app
                    input_pil = Image.fromarray(depth_colored)
                    transform = transforms.Compose([
                        transforms.Resize(256),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
                    ])
                    input_tensor = transform(input_pil).unsqueeze(0)

                    # Process through the GAN
                    input_tensor = input_tensor.to(DEVICE)
                    with torch.no_grad():
                        output_tensor = current_gan_model(input_tensor)

                    # Convert back using the same post-processing as the test app
                    output_tensor = output_tensor.squeeze(0).cpu()
                    output_tensor = (output_tensor + 1) / 2
                    output_tensor = output_tensor.clamp(0, 1)
                    output_tensor = output_tensor.permute(1, 2, 0).numpy()
                    processed = (output_tensor * 255).astype(np.uint8)
                except Exception as e:
                    print(f"Error processing through GAN: {e}")
                    processed = depth_colored
            else:
                processed = depth_colored

        else:  # Image to Depth
            # Store the original frame as the GAN input
            current_gan_input = frame_rgb.copy()

            if use_gan:
                # Process like the test app
                input_pil = Image.fromarray(frame_rgb)
                transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
                ])
                input_tensor = transform(input_pil).unsqueeze(0)

                input_tensor = input_tensor.to(DEVICE)
                with torch.no_grad():
                    output_tensor = current_gan_model(input_tensor)

                output_tensor = output_tensor.squeeze(0).cpu()
                output_tensor = (output_tensor + 1) / 2
                output_tensor = output_tensor.clamp(0, 1)
                output_tensor = output_tensor.permute(1, 2, 0).numpy()
                processed = (output_tensor * 255).astype(np.uint8)
            else:
                processed = frame_rgb

    except Exception as e:
        print(f"Error in processing: {e}")
        traceback.print_exc()
        processed = frame

    return processed

def virtual_webcam_stream(encoder, use_gan=True, webcam_id=0):
    """Stream the depth display or robot conversion to a virtual webcam with dynamic colormap changes"""
    global current_colormap, current_input_source, current_gan_input, current_gan_model

    try:
        # First, ensure the GAN model is loaded if needed
        if use_gan and current_gan_model is None:
            if load_gan_model() is None:
                print("Failed to load GAN model, falling back to depth only")
                use_gan = False

        # Initialize the webcam
        cap = None
        if current_input_source == "Webcam":
            cap = cv2.VideoCapture(int(webcam_id))
            if not cap.isOpened():
                raise RuntimeError(f"Failed to open webcam {webcam_id}")
            print(f"Successfully opened webcam {webcam_id}")

        # Initialize screen capture if using the desktop
        sct = None
        if current_input_source == "Desktop":
            sct = mss.mss()
            monitor = sct.monitors[1]

        # Try different virtual camera backends in order of preference
        cam = None
        backends = ['droidcam', 'unity', 'obs']  # droidcam is tried first
        errors = []

        for backend in backends:
            try:
                cam = pyvirtualcam.Camera(
                    width=640,
                    height=480,
                    fps=30,
                    fmt=PixelFormat.BGR,
                    backend=backend,
                    device='/dev/video2' if backend == 'v4l2' else None  # for Linux
                )
                print(f'Successfully initialized virtual camera using {backend} backend')
                break
            except Exception as e:
                errors.append(f'{backend} error: {str(e)}')
                continue

        if cam is None:
            raise RuntimeError("Failed to initialize any virtual camera backend:\n" +
                               "\n".join(errors) +
                               "\nPlease install OBS Virtual Camera or another compatible virtual camera.")

        print(f'Using virtual camera: {cam.device}')
        print(f'Mode: {"Depth to Robot" if use_gan else "Depth Only"}')
        print(f'Input Source: {current_input_source}')

        frame_count = 0
        last_time = time.time()
        fps = 0

        while not stop_signal:
            try:
                # Get a frame based on the input source
                if current_input_source == "Webcam":
                    ret, frame = cap.read()
                    if not ret:
                        print("Failed to get frame from webcam")
                        time.sleep(0.1)  # small delay before retrying
                        continue
                else:  # Desktop
                    frame = get_screen_capture()

                # Calculate FPS
                frame_count += 1
                if frame_count % 30 == 0:
                    current_time = time.time()
                    fps = 30 / (current_time - last_time)
                    last_time = current_time
                    print(f"FPS: {fps:.1f}")

                # Resize the frame to match the virtual camera resolution
                frame = cv2.resize(frame, (640, 480))

                # Process the frame
                processed = process_frame(frame, encoder, use_gan, current_colormap)

                # Add a GAN input preview if available
                if current_gan_input is not None and use_gan:
                    preview_width = 160
                    preview_height = 120
                    preview = cv2.resize(current_gan_input, (preview_width, preview_height))

                    y_offset = 10
                    x_offset = processed.shape[1] - preview_width - 10

                    # Create a copy for modification
                    output = processed.copy()

                    # Add a semi-transparent black background behind the preview
                    overlay = np.zeros((preview_height + 2, preview_width + 2, 3), dtype=np.uint8)
                    alpha = 0.7
                    output[y_offset-1:y_offset+preview_height+1,
                           x_offset-1:x_offset+preview_width+1] = cv2.addWeighted(
                        output[y_offset-1:y_offset+preview_height+1,
                               x_offset-1:x_offset+preview_width+1],
                        1 - alpha,
                        overlay,
                        alpha,
                        0
                    )

                    # Add the preview
                    output[y_offset:y_offset+preview_height,
                           x_offset:x_offset+preview_width] = preview

                    processed = output

                # Send to the virtual camera
                cam.send(processed)
                cam.sleep_until_next_frame()

            except Exception as e:
                print(f"Error processing frame: {e}")
                traceback.print_exc()
                time.sleep(0.1)  # delay before retrying
                continue

        # Cleanup
        if cap is not None:
            cap.release()
        if sct is not None:
            sct.close()

    except Exception as e:
        print(f"Critical error in virtual_webcam_stream: {e}")
        traceback.print_exc()
        return False

    return True

def toggle_input_source():
    """Toggle between webcam and desktop capture"""
    global current_input_source, webcam_thread, stop_signal

    # Stop the current stream
    if webcam_thread and webcam_thread.is_alive():
        stop_signal = True
        webcam_thread.join(timeout=1.0)

    # Toggle the source
    current_input_source = "Desktop" if current_input_source == "Webcam" else "Webcam"

    # Restart the stream if it was running
    if webcam_thread:
        return start_webcam_thread(
            current_model_name,
            current_mode,
            current_webcam_id,
            current_colormap
        )

    return f"✅ Switched to {current_input_source} input"

def get_screen_capture():
    """Capture the desktop screen"""
    sct = mss.mss()
    monitor = sct.monitors[1]  # primary monitor
    screenshot = np.array(sct.grab(monitor))
    return cv2.cvtColor(screenshot, cv2.COLOR_BGRA2BGR)

def verify_model_path():
    """Verify that the GAN model file exists"""
    model_path = './checkpoints/depth2image/latest_net_G_A.pth'
    if not os.path.exists(model_path):
        print(f"Model file not found at: {model_path}")
        print("Current working directory:", os.getcwd())
        return False
    return True

def start_webcam_thread(model_name, mode, webcam_id=0, colormap="TURBO"):
    """Start (or restart) the virtual webcam thread; checks the model file first."""
    global webcam_thread, stop_signal, current_colormap, current_mode
    global current_model_name, current_webcam_id, current_direction

    # Verify the model exists if using the GAN
    if mode != "Depth Only" and not verify_model_path():
        return "❌ GAN model file not found! Please check the model path."

    # Update the current settings
    current_colormap = colormap
    current_mode = mode
    current_model_name = model_name
    current_webcam_id = webcam_id

    # Set the direction based on the mode
    if mode == "Depth to Image":
        current_direction = "Depth to Image"
    elif mode == "Image to Depth":
        current_direction = "Image to Depth"

    # If a thread is already running, stop it
    if webcam_thread and webcam_thread.is_alive():
        stop_signal = True
        webcam_thread.join(timeout=1.0)

    # Reset the stop signal
    stop_signal = False

    # Start a new thread
    encoder = {v: k for k, v in encoder2name.items()}[model_name]
    use_gan = (mode != "Depth Only")

    webcam_thread = threading.Thread(
        target=virtual_webcam_stream,
        args=(encoder, use_gan, int(webcam_id)),
        daemon=True
    )
    webcam_thread.start()

    return f"✅ Started virtual webcam: {mode} with {model_name} model using {colormap} colormap"

def update_colormap(colormap):
    """Update the colormap without restarting the webcam"""
    global current_colormap

    if webcam_thread and webcam_thread.is_alive():
        current_colormap = colormap
        return f"✅ Updated colormap to {colormap}"
    else:
        return "⚠️ Webcam is not running. Please start it first."

def stop_webcam():
    """Stop the webcam thread"""
    global webcam_thread, stop_signal

    if webcam_thread and webcam_thread.is_alive():
        stop_signal = True
        webcam_thread.join(timeout=1.0)
        return "✅ Webcam stopped"
    else:
        return "No webcam is running"

def set_device_mode(choice):
    """Set the device to use for model inference"""
    global DEVICE
    if choice == "Auto":
        DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    elif choice == "CUDA":
        DEVICE = 'cuda'
    else:
        DEVICE = 'cpu'

    # Reset the loaded models to ensure they're recreated on the correct device
    global current_depth_model, current_gan_model
    current_depth_model = None
    current_gan_model = None

    return f"Device set to: {DEVICE}"

def test_gan_model():
    """Test whether the GAN model loads and runs correctly"""
    try:
        # Try loading the model to verify it works
        model = load_gan_model()
        if model is None:
            return "❌ Failed to load GAN model. Check console for errors."

        # Create a simple test tensor
        test_input = torch.zeros(1, 3, 64, 64).to(DEVICE)

        # Try running inference
        with torch.no_grad():
            output = model(test_input)

        return f"✅ GAN model loaded and tested successfully on {DEVICE}!"
    except Exception as e:
        return f"❌ Error testing GAN model: {str(e)}"

def upload_to_huggingface(hf_token, repo_name=None):
    """Upload the GAN model to HuggingFace"""
    if not repo_name:
        repo_name = DEPTH2ROBOT_HF_REPO

    if not os.path.exists(DEPTH2ROBOT_MODEL_PATH):
        return f"❌ Model file not found. Please make sure it exists at: {DEPTH2ROBOT_MODEL_PATH}"

    try:
        # Log in to HuggingFace
        login(token=hf_token)

        # Upload the model file
        upload_info = upload_file(
            path_or_fileobj=DEPTH2ROBOT_MODEL_PATH,
            path_in_repo="latest_net_G.pth",
            repo_id=repo_name,
            repo_type="model",
            create_pr=False
        )

        # Create a simple model card if it doesn't exist
        model_card = """---
tags:
- depth-to-robot
- image-to-image
- cyclegan
---

# Depth2Robot GAN Model

This model transforms depth maps into robot-style images using CycleGAN.

## Model Description

- This model was trained on depth maps and robot images.
- It converts grayscale depth maps to colorful robot-style imagery.
- Trained using the CycleGAN architecture.

## Usage

```python
import torch
from huggingface_hub import hf_hub_download

# Download the model
model_path = hf_hub_download(repo_id="{repo_name}", filename="latest_net_G.pth")

# Load the model (you need to define the Generator class)
model = Generator()
model.load_state_dict(torch.load(model_path), strict=False)
model.eval()

# Use the model for inference
# ...
```
""".format(repo_name=repo_name)

        # Create a temporary model card file
        with open("./README.md", "w") as f:
            f.write(model_card)

        # Upload the model card
        upload_file(
            path_or_fileobj="./README.md",
            path_in_repo="README.md",
            repo_id=repo_name,
            repo_type="model",
            create_pr=False
        )

        # Clean up
        os.remove("./README.md")

        return f"✅ Successfully uploaded model to HuggingFace!\n\nYou can view it at: https://huggingface.co/{repo_name}"
    except Exception as e:
        return f"❌ Error uploading to HuggingFace: {e}"

def toggle_mode():
    """Quick toggle between the Depth Only and Depth to Robot modes"""
    global current_mode
    if webcam_thread and webcam_thread.is_alive():
        current_mode = "Depth Only" if current_mode == "Depth to Robot" else "Depth to Robot"
        return start_webcam_thread(
            current_model_name,
            current_mode,
            current_webcam_id,
            current_colormap
        )
    return "⚠️ Webcam is not running. Please start it first."

def update_gan_preview():
    """Update the GAN input preview"""
    global current_gan_input
    if current_gan_input is not None:
        return current_gan_input
    return None

def test_webcams():
    """Scan for available webcams"""
    available_cams = []
    for i in range(10):  # test the first 10 indices
        cap = cv2.VideoCapture(i)
        if cap.isOpened():
            ret, _ = cap.read()
            if ret:
                available_cams.append(i)
            cap.release()
    return f"Available webcam IDs: {available_cams}"

def stop_gan():
    """Stop the GAN processing"""
    global current_gan_model
    if current_gan_model is not None:
        current_gan_model = None
        return "✅ GAN processing stopped"
    return "GAN was not running"

# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")) as demo:
    gr.Markdown("# 🤖 Depth Anything V2 to Robot Virtual Webcam for Discord")

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("### 📹 Webcam Settings")

                # Define the status box used by all the connections below
                webcam_status = gr.Textbox(
                    label="Status",
                    placeholder="Not started",
                    interactive=False
                )

                with gr.Row():
                    with gr.Column(scale=1):
                        model_dropdown = gr.Dropdown(
                            choices=list(encoder2name.values()),
                            value="Small",
                            label="Depth Model Size",
                            info="Smaller = faster, larger = more detailed"
                        )

                    with gr.Column(scale=1):
                        mode_dropdown = gr.Dropdown(
                            choices=["Depth Only", "Depth to Image", "Image to Depth"],
                            value="Depth to Image",
                            label="Output Mode",
                            info="Select conversion direction or depth visualization"
                        )

                    with gr.Column(scale=1):
                        gan_source_radio = gr.Radio(
                            choices=["Local", "HuggingFace"],
                            value="Local",
                            label="GAN Model Source",
                            info="Choose between a local model or a download from HuggingFace"
                        )
                        gan_path = gr.Textbox(
                            value=DEPTH2ROBOT_LOCAL_PATH,
                            label="Local GAN Path/HF Repo",
                            info="Local path or HuggingFace repo name"
                        )

                    with gr.Column(scale=1):
                        colormap_dropdown = gr.Dropdown(
                            choices=list(DEPTH_COLORMAPS.keys()),
                            value="TURBO",
                            label="Depth Colormap",
                            info="Color scheme for depth visualization"
                        )

                with gr.Row():
                    update_colormap_button = gr.Button("Update Colormap", variant="secondary")
                    reverse_depth_button = gr.Button("Reverse Depth Colors", variant="secondary")
                    bypass_depth_button = gr.Button("Toggle Depth Bypass", variant="secondary")
                    toggle_invert_button = gr.Button("Toggle Depth Inversion", variant="secondary")

                webcam_id = gr.Number(
                    value=0,
                    label="Webcam ID",
                    info="Usually 0 for a built-in webcam; try 1 or 2 for external cameras",
                    precision=0
                )

                with gr.Row():
                    start_button = gr.Button("▶️ Start Webcam", variant="primary", scale=2)
                    stop_button = gr.Button("⏹️ Stop Webcam", variant="stop", scale=1)

                with gr.Row():
                    quick_mode_toggle = gr.Button("🔄 Toggle Depth/Robot Mode", variant="primary")
                    input_source_button = gr.Button("🔄 Toggle Webcam/Desktop", variant="secondary")

            with gr.Group():
                gr.Markdown("### 🎨 Blending Settings")

                with gr.Row():
                    blend_enabled_toggle = gr.Checkbox(
                        label="Enable Blending",
                        value=False,
                        info="Blend the original video on top of the depth map"
                    )

                    blend_opacity_slider = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.1,
                        step=0.1,
                        label="Blend Opacity",
                        info="0 = Depth only, 1 = Camera only"
                    )

                with gr.Row():
                    update_blend_button = gr.Button("Update Blend Settings", variant="secondary")

            # Testing section
            with gr.Group():
                gr.Markdown("### 🧪 Test Your GAN Model")
                test_button = gr.Button("🧪 Test GAN Model", variant="secondary")
                test_output = gr.Textbox(label="Test Results", interactive=False)

        # Right column
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### ⚙️ Advanced Settings")

                device_radio = gr.Radio(
                    choices=["Auto", "CUDA", "CPU"],
                    value="Auto",
                    label="Device Selection",
                    info="Use CPU if you experience GPU errors"
                )

                device_output = gr.Textbox(
                    label="Device Status",
                    value=f"Current device: {DEVICE}",
                    interactive=False
                )

                device_radio.change(fn=set_device_mode, inputs=device_radio, outputs=device_output)

            with gr.Group():
                gr.Markdown("### 🚀 Upload to HuggingFace")

                hf_token = gr.Textbox(
                    label="HuggingFace API Token",
                    placeholder="hf_...",
                    type="password",
                    info="Get your token from huggingface.co/settings/tokens"
                )

                repo_name = gr.Textbox(
                    label="Repository Name",
                    placeholder="username/depth2robot-model",
                    info="Format: username/repo-name"
                )

                upload_button = gr.Button("📤 Upload Model", variant="secondary")
                upload_result = gr.Textbox(label="Upload Result", interactive=False)

            with gr.Group():
                gr.Markdown("### 🎥 Test Webcams")
                test_webcams_button = gr.Button("Scan for Webcams")
                webcams_output = gr.Textbox(label="Available Webcams", interactive=False)

    # Connect the UI elements to functions (all connections are grouped here)
    start_button.click(
        fn=start_webcam_thread,
        inputs=[model_dropdown, mode_dropdown, webcam_id, colormap_dropdown],
        outputs=webcam_status
    )

    stop_button.click(fn=stop_webcam, inputs=[], outputs=webcam_status)

    update_colormap_button.click(
        fn=update_colormap,
        inputs=colormap_dropdown,
        outputs=webcam_status
    )

    input_source_button.click(
        fn=toggle_input_source,
        inputs=[],
        outputs=webcam_status
    )

    blend_enabled_toggle.change(
        fn=toggle_blend_enabled,
        inputs=[],
        outputs=webcam_status
    )

    update_blend_button.click(
        fn=update_blend_opacity,
        inputs=blend_opacity_slider,
        outputs=webcam_status
    )

    reverse_depth_button.click(
        fn=reverse_depth_colormap,
        inputs=[],
        outputs=webcam_status
    )

    toggle_invert_button.click(
        fn=toggle_invert_depth,
        inputs=[],
        outputs=webcam_status
    )

    gan_source_radio.change(
        fn=update_gan_source,
        inputs=[gan_source_radio, gan_path],
        outputs=webcam_status
    )

    test_button.click(fn=test_gan_model, inputs=[], outputs=test_output)

    upload_button.click(
        fn=upload_to_huggingface,
        inputs=[hf_token, repo_name],
        outputs=upload_result
    )

    quick_mode_toggle.click(
        fn=toggle_mode,
        inputs=[],
        outputs=webcam_status
    )

    test_webcams_button.click(
        fn=test_webcams,
        inputs=[],
        outputs=webcams_output
    )

    bypass_depth_button.click(
        fn=toggle_depth_bypass,
        inputs=[],
        outputs=webcam_status
    )

    with gr.Row():
        stop_gan_button = gr.Button("⏹️ Stop GAN", variant="stop")

    stop_gan_button.click(
        fn=stop_gan,
        inputs=[],
        outputs=webcam_status
    )

    with gr.Row():
        gan_status = gr.Textbox(
            label="GAN Status",
            value="Not loaded",
            interactive=False
        )

    # Help section
    with gr.Accordion("Help & Troubleshooting", open=False):
        gr.Markdown("""
        ## Common Issues

        ### Model not loading
        - Make sure your model file is at `./checkpoints/depth2image/latest_net_G_A.pth`
        - Try the "Test GAN Model" button to check whether it loads correctly
        - If you see errors about missing keys, the model structure is different; this script uses `strict=False` to load it anyway

        ### Virtual camera not showing in Discord
        - Make sure OBS Virtual Camera is installed
        - Try stopping and starting the webcam
        - Restart Discord after starting the virtual camera

        ### Performance issues
        - Use the "Small" depth model for better performance
        - Try the "CPU" device option if you're having GPU memory issues
        """)

if __name__ == "__main__":
    # Make sure the checkpoints directory exists
    os.makedirs("checkpoints/depth2image", exist_ok=True)

    # Launch the Gradio interface
    demo.launch()
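Before launching, the app needs the Depth-Anything-V2 repository on the import path (via `DEPTH_ANYTHING_V2_PATH`) and the CycleGAN generator weights under `./checkpoints/depth2image/`. A minimal launcher sketch; all paths here are illustrative placeholders, not values from this commit:

```python
import os
import runpy

# Hypothetical checkout location of the Depth-Anything-V2 repository.
os.environ["DEPTH_ANYTHING_V2_PATH"] = "/path/to/Depth-Anything-V2"

# Expected weights, relative to the working directory:
#   checkpoints/depth2image/latest_net_G_A.pth  (Depth -> Image generator)
#   checkpoints/depth2image/latest_net_G_B.pth  (Image -> Depth generator)

# The filename contains dashes, so run it as a script rather than importing it.
runpy.run_path("src/inference/merged-discord-app.py", run_name="__main__")
```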
src/training/trainDepth2AnythingGAN.ipynb ADDED
The diff for this file is too large to render. See raw diff