import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import os
import numpy as np


# Generator architecture (simplified ResNet)
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        # Named 'conv_block' (rather than 'block') so the parameter names line up
        # with the state_dict keys in CycleGAN checkpoints
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, input_channels=3, output_channels=3, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]

        # Downsampling
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2

        # Output layer
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_channels, 7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


# Image preprocessing (for loading from a file path)
def preprocess_image(image_path):
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    return transform(image).unsqueeze(0)


# Image postprocessing: map the generator output from [-1, 1] back to uint8 RGB
def postprocess_image(tensor):
    tensor = tensor.squeeze(0).cpu()
    tensor = (tensor + 1) / 2
    tensor = tensor.clamp(0, 1)
    tensor = tensor.permute(1, 2, 0).numpy()
    return (tensor * 255).astype(np.uint8)


# Model loading
def load_model(model_path):
    model = Generator()
    if os.path.exists(model_path):
        print(f"Loading model from {model_path}")
        state_dict = torch.load(model_path, map_location='cpu')
        try:
            model.load_state_dict(state_dict)
        except Exception as e:
            print(f"Warning: {e}")
            # Fall back to a partial load if some keys do not match
            model.load_state_dict(state_dict, strict=False)
            print("Loaded model with strict=False")
    else:
        print(f"Error: Model file not found at {model_path}")
        return None
    model.eval()
    return model


# Inference function: handles the numpy arrays Gradio provides and returns
# (output image, status message)
def transform_image(input_image, direction):
    if input_image is None:
        return None, "No input image provided"

    try:
        # Ensure the input image is RGB
        if len(input_image.shape) == 2:  # Grayscale
            input_image = np.stack([input_image] * 3, axis=-1)
        elif input_image.shape[-1] == 4:  # RGBA
            input_image = input_image[..., :3]

        if direction == "Depth to Image":
            model_path = "./checkpoints/depth2image/latest_net_G_A.pth"
        else:
            model_path = "./checkpoints/depth2image/latest_net_G_B.pth"

        # Load the model (reloaded on every call for simplicity)
        model = load_model(model_path)
        if model is None:
            return None, f"Failed to load model from {model_path}"

        # Convert the numpy array to a PIL Image
        input_pil = Image.fromarray(input_image.astype('uint8'), 'RGB')

        # Create transforms
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        # Process image
        input_tensor = transform(input_pil).unsqueeze(0)

        # Generate output
        with torch.no_grad():
            output_tensor = model(input_tensor)

        # Convert to image
        output_image = postprocess_image(output_tensor)
        return output_image, "Done"

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error in transform_image: {e}"


# Gradio interface
with gr.Blocks(title="CycleGAN Depth2Image Test", analytics_enabled=False) as demo:
    gr.Markdown("## Test CycleGAN Depth2Image Model")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type="numpy",
                height=256,
                width=256
            )
            direction = gr.Radio(
                choices=["Depth to Image", "Image to Depth"],
                value="Depth to Image",
                label="Conversion Direction"
            )
            transform_btn = gr.Button("Transform", variant="primary")

        with gr.Column():
            output_image = gr.Image(
                label="Generated Output",
                height=256,
                width=256
            )
            error_output = gr.Textbox(
                label="Status",
                interactive=False
            )

    # Connect components
    transform_btn.click(
        fn=transform_image,
        inputs=[input_image, direction],
        outputs=[output_image, error_output]
    )

    gr.Markdown("""
### Instructions:
1. Upload an image
2. Select conversion direction:
   - "Depth to Image" converts depth maps to realistic images
   - "Image to Depth" converts realistic images to depth maps
3. Click "Transform" to generate the output

Note: Input images will be resized to 256x256 pixels.
""")


if __name__ == "__main__":
    # Make sure the checkpoints directory exists
    os.makedirs("checkpoints/depth2image", exist_ok=True)

    # Launch with custom server configuration
    demo.queue(max_size=5).launch(
        server_name="0.0.0.0",  # Allow external connections
        server_port=7860,       # Set specific port
        show_error=True,        # Show detailed errors
        debug=True              # Enable debug mode
    )
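# --- Optional offline smoke test (sketch) ---
# A minimal sketch for exercising the generator without the Gradio UI, assuming a
# checkpoint is in place; the sample path below is hypothetical and only
# illustrates the call sequence.
#
#   depth = np.array(Image.open("./samples/depth.png").convert("RGB"))  # hypothetical sample file
#   result, status = transform_image(depth, "Depth to Image")
#   print(status)
#   if result is not None:
#       Image.fromarray(result).save("./samples/depth_translated.png")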