Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
from typing import List | |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus | |
CAM_METHODS = {"GradCAM": GradCAM, "GradCAM++": GradCAMPlusPlus} | |
def generate_grad_cam_overlay( | |
model: nn.Module, | |
target_layer: nn.Module, | |
input_tensor: torch.Tensor, # (B, C, H, W) | |
original_frames: List[Image.Image], | |
target_class_idx: int, | |
cam_method: str = "GradCAM", | |
temporal_aggregation_method: str = "mean", | |
target_frame_idx: int = -1, | |
): | |
""" | |
Generates a Grad-CAM heatmap and overlays it on the original image. | |
Returns a PIL Image of the overlay or None on error. | |
""" | |
if cam_method not in CAM_METHODS: | |
raise ValueError(f"Unsupported CAM method: {cam_method}.") | |
cam_method = CAM_METHODS[cam_method] | |
with cam_method(model=model, target_layers=[target_layer]) as cam: | |
targets = [ClassifierOutputTarget(target_class_idx)] | |
grayscale_cam = cam(input_tensor=input_tensor, targets=targets) | |
if temporal_aggregation_method == "mean": | |
aggregated_heatmap_2d = np.mean(grayscale_cam, axis=0) | |
if target_frame_idx == -1 or not (0 <= target_frame_idx < len(original_frames)): | |
idx_representative_frame = len(original_frames) // 2 | |
else: | |
idx_representative_frame = target_frame_idx | |
representative_frame = original_frames[idx_representative_frame] | |
rgb_image = np.array(representative_frame) / 255.0 | |
target_h, target_w = rgb_image.shape[:2] | |
heatmap_resized = cv2.resize( | |
aggregated_heatmap_2d, (target_w, target_h), interpolation=cv2.INTER_LINEAR | |
) | |
visualization = show_cam_on_image( | |
rgb_image, heatmap_resized, use_rgb=True, image_weight=0.5 | |
) | |
return Image.fromarray(visualization.astype(np.uint8)) | |