import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
import warnings
from transformers import AutoProcessor, CLIPModel
import cv2
import re
from huggingface_hub import hf_hub_download
import io

warnings.filterwarnings("ignore", category=UserWarning)

class ImageDataset(Dataset):
    def __init__(self, image, transform=None, face_only=True, dataset_name=None):
        # Modified to accept a single PIL image instead of a list of paths
        self.image = image
        self.transform = transform
        self.face_only = face_only
        self.dataset_name = dataset_name
        # Load the OpenCV Haar-cascade face detector
        self.face_detector = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

    def __len__(self):
        return 1  # Only one image

    def detect_face(self, image_np):
        """Detect a face in the image and return its bounding box and cropped region."""
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
        faces = self.face_detector.detectMultiScale(gray, 1.1, 5)

        # If no face is detected, use the whole image
        if len(faces) == 0:
            print("No face detected, using whole image")
            h, w = image_np.shape[:2]
            return (0, 0, w, h), image_np

        # If multiple faces are detected, keep the largest one by area
        if len(faces) > 1:
            areas = [w * h for (x, y, w, h) in faces]
            largest_idx = np.argmax(areas)
            x, y, w, h = faces[largest_idx]
        else:
            x, y, w, h = faces[0]

        # Add padding around the face (5% on each side)
        padding_x = int(w * 0.05)
        padding_y = int(h * 0.05)

        # Clamp the padded box to the image bounds
        x1 = max(0, x - padding_x)
        y1 = max(0, y - padding_y)
        x2 = min(image_np.shape[1], x + w + padding_x)
        y2 = min(image_np.shape[0], y + h + padding_y)

        # Extract the face region
        face_img = image_np[y1:y2, x1:x2]
        return (x1, y1, x2 - x1, y2 - y1), face_img

    def __getitem__(self, idx):
        # Use the single image provided
        image_np = np.array(self.image)
        label = 0  # Default label; overridden by the prediction in app.py

        # Keep the original image for visualization
        original_image = self.image.copy()

        if self.face_only:
            # Detect the face and transform only the cropped region
            face_box, face_img_np = self.detect_face(image_np)
            face_img = Image.fromarray(face_img_np)
            if self.transform:
                face_tensor = self.transform(face_img)
            else:
                face_tensor = transforms.ToTensor()(face_img)
            return face_tensor, label, "uploaded_image", original_image, face_box, self.dataset_name
        else:
            # Process the whole image
            if self.transform:
                image_tensor = self.transform(self.image)
            else:
                image_tensor = transforms.ToTensor()(self.image)
            return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name

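# Note (assumption, since app.py is not shown in this file): the `transform` passed to
# ImageDataset is expected to produce tensors in the layout the CLIP vision tower expects,
# e.g. a torchvision Compose such as:
#     transforms.Compose([
#         transforms.Resize((224, 224)),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
#                              std=[0.26862954, 0.26130258, 0.27577711]),
#     ])
# Because __getitem__ also returns a PIL image and a plain tuple, a DataLoader over this
# dataset needs a collate_fn that leaves non-tensor fields as Python lists
# (see the sketch at the end of this file).
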
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, input, output):
            if isinstance(output, tuple):
                self.activations = output[0]
            else:
                self.activations = output

        def backward_hook(module, grad_in, grad_out):
            if isinstance(grad_out, tuple):
                self.gradients = grad_out[0]
            else:
                self.gradients = grad_out

        layer = dict([*self.model.named_modules()])[self.target_layer]
        layer.register_forward_hook(forward_hook)
        # register_backward_hook is deprecated in recent PyTorch releases but is kept here to
        # preserve the original behavior; register_full_backward_hook is the modern equivalent.
        layer.register_backward_hook(backward_hook)

    def generate(self, input_tensor, class_idx):
        self.model.zero_grad()
        try:
            # Run only the vision tower so the hooks on the target layer fire
            vision_outputs = self.model.vision_model(pixel_values=input_tensor)
            features = vision_outputs.pooler_output

            # Build a one-hot gradient for the chosen class index and backpropagate it manually
            one_hot = torch.zeros_like(features)
            one_hot[0, class_idx] = 1
            features.backward(gradient=one_hot)

            # Fall back to a flat CAM if the hooks did not capture anything
            if self.gradients is None or self.activations is None:
                print("Warning: Gradients or activations are None. Using fallback CAM.")
                return np.ones((14, 14), dtype=np.float32) * 0.5

            # Process gradients and activations
            if len(self.gradients.shape) == 4:  # Convolutional layout: [batch, channels, H, W]
                gradients = self.gradients.cpu().detach().numpy()
                activations = self.activations.cpu().detach().numpy()
                weights = np.mean(gradients, axis=(2, 3))
                cam = np.zeros(activations.shape[2:], dtype=np.float32)
                for i, w in enumerate(weights[0]):
                    cam += w * activations[0, i, :, :]
            else:
                # Transformer layout
                gradients = self.gradients.cpu().detach().numpy()
                activations = self.activations.cpu().detach().numpy()
                if len(activations.shape) == 3:  # [batch, sequence_length, hidden_dim]
                    seq_len = activations.shape[1]
                    # CLIP ViT-B/16 has 196 patch tokens (14x14) + 1 class token = 197
                    if seq_len == 197:
                        # Drop the class token and reshape the patch tokens into a square grid
                        patch_tokens = activations[0, 1:, :]
                        # Mean absolute activation across the hidden dimension
                        token_importance = np.mean(np.abs(patch_tokens), axis=1)
                        cam = token_importance.reshape(14, 14)
                    else:
                        # Other sequence lengths, e.g. 257 (16x16 patch tokens + class token) for
                        # the CLIP ViT-L/14 model loaded in load_clip_model(): approximate a square grid
                        side_len = int(np.sqrt(seq_len))
                        # Use the mean absolute activation per token as its importance
                        token_importance = np.mean(np.abs(activations[0]), axis=1)
                        cam = np.zeros((side_len, side_len))
                        flat_cam = cam.flatten()
                        n = min(len(token_importance), len(flat_cam))
                        flat_cam[:n] = token_importance[:n]
                        cam = flat_cam.reshape(side_len, side_len)
                else:
                    # Fallback for unexpected activation shapes
                    print("Using fallback CAM shape (14x14)")
                    cam = np.ones((14, 14), dtype=np.float32) * 0.5

            # Guard against an empty result
            if cam is None or cam.size == 0:
                print("Warning: Generated CAM is empty. Using fallback.")
                cam = np.ones((14, 14), dtype=np.float32) * 0.5

            # ReLU and normalize to [0, 1]
            cam = np.maximum(cam, 0)
            if np.max(cam) > 0:
                cam = cam / np.max(cam)
            return cam
        except Exception as e:
            print(f"Error in GradCAM.generate: {str(e)}")
            return np.ones((14, 14), dtype=np.float32) * 0.5

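# Typical wiring (assumption; the actual call sites live in app.py, which is not shown here):
#     cam_extractor = GradCAM(model, get_target_layer_clip(model))
#     cam = cam_extractor.generate(input_tensor, class_idx)   # input_tensor: [1, 3, 224, 224]
# generate() returns a CAM normalized to [0, 1] on a small grid (e.g. 14x14 or 16x16) that the
# helpers below resize and overlay onto the original image.
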
def overlay_cam_on_image(image, cam, face_box=None, alpha=0.5):
    # Ensure the base image is RGB so Image.blend receives two images with matching modes
    image = image.convert("RGB")

    if face_box is not None:
        x, y, w, h = face_box
        # Start from an all-zero CAM covering the full image
        img_np = np.array(image)
        full_h, full_w = img_np.shape[:2]
        full_cam = np.zeros((full_h, full_w), dtype=np.float32)

        # Resize the CAM to the face region and paste it at the face position
        face_cam = cv2.resize(cam, (w, h))
        full_cam[y:y+h, x:x+w] = face_cam

        # Apply the jet colormap to the full-image CAM
        cam_resized = Image.fromarray((full_cam * 255).astype(np.uint8))
        cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
        cam_colormap = (cam_colormap * 255).astype(np.uint8)
    else:
        # Resize the CAM to the whole image and apply the colormap
        cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(image.size, Image.BILINEAR)
        cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
        cam_colormap = (cam_colormap * 255).astype(np.uint8)

    blended = Image.blend(image, Image.fromarray(cam_colormap), alpha=alpha)
    return blended

def save_comparison(image, cam, overlay, face_box=None):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image, with the detected face box drawn if available
    axes[0].imshow(image)
    axes[0].set_title("Original")
    if face_box is not None:
        x, y, w, h = face_box
        rect = plt.Rectangle((x, y), w, h, edgecolor='lime', linewidth=2, fill=False)
        axes[0].add_patch(rect)
    axes[0].axis("off")

    # CAM
    if face_box is not None:
        # Build a full-image CAM that highlights only the face region
        img_np = np.array(image)
        h, w = img_np.shape[:2]
        full_cam = np.zeros((h, w))
        x, y, fw, fh = face_box
        face_cam = cv2.resize(cam, (fw, fh))
        full_cam[y:y+fh, x:x+fw] = face_cam
        axes[1].imshow(full_cam, cmap="jet")
    else:
        axes[1].imshow(cam, cmap="jet")
    axes[1].set_title("CAM")
    axes[1].axis("off")

    # Overlay
    axes[2].imshow(overlay)
    axes[2].set_title("Overlay")
    axes[2].axis("off")

    plt.tight_layout()

    # Render the figure to an in-memory PNG and return it as a PIL image for Streamlit display
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close()
    buf.seek(0)
    return Image.open(buf)

def load_clip_model():
    # Load the base CLIP ViT-L/14 weights, then overlay the fine-tuned checkpoint from the Hub
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

    checkpoint_path = hf_hub_download(repo_id="drg31/model", filename="model.pth")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Only copy checkpoint tensors whose names and shapes match the base model
    model_dict = model.state_dict()
    checkpoint = {k: v for k, v in checkpoint.items() if k in model_dict and model_dict[k].shape == v.shape}
    model_dict.update(checkpoint)
    model.load_state_dict(model_dict)

    model.eval()
    return model, processor

def get_target_layer_clip(model):
    # Hook the last encoder block (index 23 of 24) of the CLIP ViT-L/14 vision tower; its output
    # hidden states have shape [batch, 257, 1024], which GradCAM reshapes into a square grid
    return "vision_model.encoder.layers.23"

def process_images(dataloader, model, cam_extractor, device, pred_class):
    # Modified to process a single uploaded image and return results for Streamlit
    for batch in dataloader:
        input_tensor, label, img_paths, original_images, face_boxes, dataset_names = batch
        original_image = original_images[0]
        face_box = face_boxes[0]
        print("Processing uploaded image...")

        # Move the input tensor and model to the target device
        input_tensor = input_tensor.to(device)
        model = model.to(device)

        try:
            # Forward pass (the Grad-CAM extractor re-runs the vision tower internally)
            output = model.vision_model(pixel_values=input_tensor).pooler_output
            class_idx = pred_class  # Use the predicted class from app.py
            cam = cam_extractor.generate(input_tensor, class_idx)

            # Generate the CAM image
            if face_box is not None:
                x, y, w, h = face_box
                img_np = np.array(original_image)
                h_full, w_full = img_np.shape[:2]
                full_cam = np.zeros((h_full, w_full))
                face_cam = cv2.resize(cam, (w, h))
                full_cam[y:y+h, x:x+w] = face_cam
                cam_img = Image.fromarray((plt.cm.jet(full_cam)[:, :, :3] * 255).astype(np.uint8))
            else:
                cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(original_image.size, Image.BILINEAR)
                cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
                cam_colormap = (cam_colormap * 255).astype(np.uint8)
                cam_img = Image.fromarray(cam_colormap)

            # Generate the overlay and the side-by-side comparison figure
            overlay = overlay_cam_on_image(original_image, cam, face_box)
            comparison = save_comparison(original_image, cam, overlay, face_box)
            return cam, cam_img, overlay, comparison

        except Exception as e:
            print(f"Error processing image: {str(e)}")
            import traceback
            traceback.print_exc()

            # Return neutral fallback results so the app can still render something
            default_cam = np.ones((14, 14), dtype=np.float32) * 0.5
            cam_resized = Image.fromarray((default_cam * 255).astype(np.uint8)).resize(original_image.size, Image.BILINEAR)
            cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
            cam_colormap = (cam_colormap * 255).astype(np.uint8)
            cam_img = Image.fromarray(cam_colormap)
            overlay = overlay_cam_on_image(original_image, default_cam, face_box)
            comparison = save_comparison(original_image, default_cam, overlay, face_box)
            return default_cam, cam_img, overlay, comparison
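
# Hedged end-to-end sketch (not part of the Space's app.py, which is not shown in this file).
# It assumes a local file "example.jpg" and a CLIP-style preprocessing pipeline; the collate_fn
# is an assumption needed because the dataset yields PIL images and tuples that the default
# collate cannot batch, and pred_class=0 is a placeholder for the classifier output in app.py.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = load_clip_model()

    # CLIP ViT-L/14 preprocessing (224x224, CLIP normalization constants)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])

    image = Image.open("example.jpg").convert("RGB")  # hypothetical input path
    dataset = ImageDataset(image, transform=transform, face_only=True)

    # Stack tensors, keep everything else (PIL images, face-box tuples) as plain Python lists
    def collate(batch):
        return tuple(
            torch.stack(list(col)) if isinstance(col[0], torch.Tensor) else list(col)
            for col in zip(*batch)
        )

    dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate)
    cam_extractor = GradCAM(model, get_target_layer_clip(model))

    cam, cam_img, overlay, comparison = process_images(
        dataloader, model, cam_extractor, device, pred_class=0
    )
    comparison.save("gradcam_comparison.png")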