#!/usr/bin/env python3
import cv2
import os
import subprocess
import argparse
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import numpy as np
# Constants
TEST_MODE_DURATION = 3 # Process only first 3 seconds in test mode
FFMPEG_PRESETS = [
"ultrafast",
"superfast",
"veryfast",
"faster",
"fast",
"medium",
"slow",
"slower",
"veryslow",
]
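# Presets are ordered fastest to slowest; at a fixed CRF, faster presets encode
# quicker but compress less efficiently (larger files), while slower presets
# spend more CPU time for smaller output at similar quality.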
FONT = cv2.FONT_HERSHEY_SIMPLEX # Font for bounding-box-style labels
# Detection parameters
IOU_THRESHOLD = 0.5 # IoU threshold for considering boxes related
# Hitmarker parameters
HITMARKER_SIZE = 20 # Size of the hitmarker in pixels
HITMARKER_GAP = 3 # Size of the empty space in the middle (reduced from 8)
HITMARKER_THICKNESS = 2 # Thickness of hitmarker lines
HITMARKER_COLOR = (255, 255, 255) # White color for hitmarker
HITMARKER_SHADOW_COLOR = (80, 80, 80) # Lighter gray for shadow effect
HITMARKER_SHADOW_OFFSET = 1 # Smaller shadow offset
def load_moondream():
"""Load Moondream model and tokenizer."""
    model = AutoModelForCausalLM.from_pretrained(
        "vikhyatk/moondream2",
        trust_remote_code=True,
        # Fall back to CPU when no CUDA device is available
        device_map={"": "cuda" if torch.cuda.is_available() else "cpu"},
    )
tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
return model, tokenizer
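# Note: the moondream2 remote code exposes model.detect(image, keyword); this
# script expects it to return a dict with an "objects" list whose entries carry
# normalized "x_min"/"y_min"/"x_max"/"y_max" coordinates in [0, 1]
# (see detect_ads_in_frame / detect_ads_in_frame_single below).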
def get_video_properties(video_path):
"""Get basic video properties."""
video = cv2.VideoCapture(video_path)
fps = video.get(cv2.CAP_PROP_FPS)
frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
video.release()
return {"fps": fps, "frame_count": frame_count, "width": width, "height": height}
def is_valid_box(box):
"""Check if box coordinates are reasonable."""
x1, y1, x2, y2 = box
width = x2 - x1
height = y2 - y1
# Reject boxes that are too large (over 90% of frame in both dimensions)
if width > 0.9 and height > 0.9:
return False
# Reject boxes that are too small (less than 1% of frame)
if width < 0.01 or height < 0.01:
return False
return True
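# Sanity examples for is_valid_box (boxes are normalized [x1, y1, x2, y2]):
#   is_valid_box([0.02, 0.05, 0.95, 0.97]) -> False (covers ~93% x 92% of the frame)
#   is_valid_box([0.40, 0.40, 0.45, 0.48]) -> True  (a 5% x 8% box)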
def split_frame_into_tiles(frame, rows, cols):
"""Split a frame into a grid of tiles."""
height, width = frame.shape[:2]
tile_height = height // rows
tile_width = width // cols
tiles = []
tile_positions = []
for i in range(rows):
for j in range(cols):
y1 = i * tile_height
y2 = (i + 1) * tile_height if i < rows - 1 else height
x1 = j * tile_width
x2 = (j + 1) * tile_width if j < cols - 1 else width
tile = frame[y1:y2, x1:x2]
tiles.append(tile)
tile_positions.append((x1, y1, x2, y2))
return tiles, tile_positions
def convert_tile_coords_to_frame(box, tile_pos, frame_shape):
"""Convert coordinates from tile space to frame space."""
frame_height, frame_width = frame_shape[:2]
tile_x1, tile_y1, tile_x2, tile_y2 = tile_pos
tile_width = tile_x2 - tile_x1
tile_height = tile_y2 - tile_y1
x1_tile_abs = box[0] * tile_width
y1_tile_abs = box[1] * tile_height
x2_tile_abs = box[2] * tile_width
y2_tile_abs = box[3] * tile_height
x1_frame_abs = tile_x1 + x1_tile_abs
y1_frame_abs = tile_y1 + y1_tile_abs
x2_frame_abs = tile_x1 + x2_tile_abs
y2_frame_abs = tile_y1 + y2_tile_abs
x1_norm = x1_frame_abs / frame_width
y1_norm = y1_frame_abs / frame_height
x2_norm = x2_frame_abs / frame_width
y2_norm = y2_frame_abs / frame_height
x1_norm = max(0.0, min(1.0, x1_norm))
y1_norm = max(0.0, min(1.0, y1_norm))
x2_norm = max(0.0, min(1.0, x2_norm))
y2_norm = max(0.0, min(1.0, y2_norm))
return [x1_norm, y1_norm, x2_norm, y2_norm]
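# Worked example for convert_tile_coords_to_frame: with a 2x2 grid on a
# 1920x1080 frame, the bottom-right tile occupies pixels (960, 540, 1920, 1080).
# A tile-space box [0.5, 0.5, 0.75, 0.75] maps to frame pixels
# (1440, 810, 1680, 945), i.e. normalized [0.75, 0.75, 0.875, 0.875].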
def merge_tile_detections(tile_detections, iou_threshold=0.5):
"""Merge detections from different tiles using NMS-like approach."""
if not tile_detections:
return []
all_boxes = []
all_keywords = []
# Collect all boxes and their keywords
for detections in tile_detections:
for box, keyword in detections:
all_boxes.append(box)
all_keywords.append(keyword)
if not all_boxes:
return []
# Convert to numpy for easier processing
boxes = np.array(all_boxes)
# Calculate areas
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
areas = (x2 - x1) * (y2 - y1)
    # Sort boxes by area, largest first (larger boxes take precedence in NMS)
    order = areas.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
if order.size == 1:
break
# Calculate IoU with rest of boxes
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1)
h = np.maximum(0.0, yy2 - yy1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
# Get indices of boxes with IoU less than threshold
inds = np.where(ovr <= iou_threshold)[0]
order = order[inds + 1]
return [(all_boxes[i], all_keywords[i]) for i in keep]
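# The merge step is a greedy NMS: boxes are visited largest-first, and any
# remaining box whose IoU with a kept box exceeds the threshold is suppressed,
# where IoU = intersection_area / (area_a + area_b - intersection_area).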
def detect_ads_in_frame(model, tokenizer, image, detect_keyword, rows=1, cols=1):
"""Detect objects in a frame using grid-based detection."""
if rows == 1 and cols == 1:
return detect_ads_in_frame_single(model, tokenizer, image, detect_keyword)
    # Work on an RGB numpy array: frames arrive as BGR from OpenCV, and the
    # individual tiles are wrapped in PIL Images below.
    if isinstance(image, Image.Image):
        image = np.array(image)
    else:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Split frame into tiles
tiles, tile_positions = split_frame_into_tiles(image, rows, cols)
# Process each tile
tile_detections = []
for tile, tile_pos in zip(tiles, tile_positions):
# Convert tile to PIL Image
tile_pil = Image.fromarray(tile)
# Detect objects in tile
response = model.detect(tile_pil, detect_keyword)
if response and "objects" in response and response["objects"]:
objects = response["objects"]
tile_objects = []
for obj in objects:
if all(k in obj for k in ["x_min", "y_min", "x_max", "y_max"]):
box = [obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]]
if is_valid_box(box):
# Convert tile coordinates to frame coordinates
frame_box = convert_tile_coords_to_frame(
box, tile_pos, image.shape
)
tile_objects.append((frame_box, detect_keyword))
if tile_objects: # Only append if we found valid objects
tile_detections.append(tile_objects)
    # Merge detections from all tiles
    merged_detections = merge_tile_detections(tile_detections, iou_threshold=IOU_THRESHOLD)
return merged_detections
def detect_ads_in_frame_single(model, tokenizer, image, detect_keyword):
"""Single-frame detection function."""
detected_objects = []
# Convert numpy array to PIL Image if needed
if not isinstance(image, Image.Image):
image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
# Detect objects
response = model.detect(image, detect_keyword)
# Check if we have valid objects
if response and "objects" in response and response["objects"]:
objects = response["objects"]
for obj in objects:
if all(k in obj for k in ["x_min", "y_min", "x_max", "y_max"]):
box = [obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]]
# If box is valid (not full-frame), add it
if is_valid_box(box):
detected_objects.append((box, detect_keyword))
return detected_objects
def draw_hitmarker(
    frame, center_x, center_y, size=HITMARKER_SIZE, color=HITMARKER_COLOR, shadow=True
):
    """Draw a COD-style hitmarker cross with an open gap at the center."""
    half_size = size // 2
    # Each of the four diagonal arms runs from an outer corner toward the
    # center, stopping HITMARKER_GAP pixels short so the middle stays open.
    # sx/sy select the quadrant (left/right, top/bottom).
    for sx in (-1, 1):
        for sy in (-1, 1):
            outer = (center_x + sx * half_size, center_y + sy * half_size)
            inner = (center_x + sx * HITMARKER_GAP, center_y + sy * HITMARKER_GAP)
            # Draw the shadow first so the main line sits on top of it
            if shadow:
                cv2.line(
                    frame,
                    (outer[0] + HITMARKER_SHADOW_OFFSET, outer[1] + HITMARKER_SHADOW_OFFSET),
                    (inner[0] + HITMARKER_SHADOW_OFFSET, inner[1] + HITMARKER_SHADOW_OFFSET),
                    HITMARKER_SHADOW_COLOR,
                    HITMARKER_THICKNESS,
                )
            cv2.line(frame, outer, inner, color, HITMARKER_THICKNESS)
def draw_ad_boxes(frame, detected_objects, detect_keyword, box_style="censor"):
"""Draw detection visualizations over detected objects.
Args:
frame: The video frame to draw on
detected_objects: List of (box, keyword) tuples
detect_keyword: The detection keyword
box_style: Visualization style ('censor', 'bounding-box', or 'hitmarker')
"""
height, width = frame.shape[:2]
for box, keyword in detected_objects:
try:
# Convert normalized coordinates to pixel coordinates
x1 = int(box[0] * width)
y1 = int(box[1] * height)
x2 = int(box[2] * width)
y2 = int(box[3] * height)
# Ensure coordinates are within frame boundaries
x1 = max(0, min(x1, width - 1))
y1 = max(0, min(y1, height - 1))
x2 = max(0, min(x2, width - 1))
y2 = max(0, min(y2, height - 1))
# Only draw if box has reasonable size
if x2 > x1 and y2 > y1:
if box_style == "censor":
# Draw solid black rectangle
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 0), -1)
elif box_style == "bounding-box":
# Draw red rectangle with thicker line
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 3)
# Add label with background
label = detect_keyword # Use exact capitalization
label_size = cv2.getTextSize(label, FONT, 0.7, 2)[0]
cv2.rectangle(
frame, (x1, y1 - 25), (x1 + label_size[0], y1), (0, 0, 255), -1
)
cv2.putText(
frame,
label,
(x1, y1 - 6),
FONT,
0.7,
(255, 255, 255),
2,
cv2.LINE_AA,
)
elif box_style == "hitmarker":
# Calculate center of the box
center_x = (x1 + x2) // 2
center_y = (y1 + y2) // 2
# Draw hitmarker at the center
draw_hitmarker(frame, center_x, center_y)
# Optional: Add small label above hitmarker
label = detect_keyword # Use exact capitalization
label_size = cv2.getTextSize(label, FONT, 0.5, 1)[0]
cv2.putText(
frame,
label,
(center_x - label_size[0] // 2, center_y - HITMARKER_SIZE - 5),
FONT,
0.5,
HITMARKER_COLOR,
1,
cv2.LINE_AA,
)
except Exception as e:
print(f"Error drawing {box_style} style box: {str(e)}")
return frame
def filter_temporal_outliers(detections_dict):
    """Drop detections that cover 90% or more of the frame area.
    Despite the name, no temporal smoothing is performed; each frame's
    detections are filtered independently by box size.
    Args:
        detections_dict: Dictionary of {frame_number: [(box, keyword), ...]}
    """
filtered_detections = {}
for t, detections in detections_dict.items():
# Only keep detections that aren't too large
valid_detections = []
for box, keyword in detections:
# Calculate box size as percentage of frame
width = box[2] - box[0]
height = box[3] - box[1]
area = width * height
# If box is less than 90% of frame, keep it
if area < 0.9:
valid_detections.append((box, keyword))
if valid_detections:
filtered_detections[t] = valid_detections
return filtered_detections
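# Example: a frame whose only detection is ([0.0, 0.0, 0.99, 0.99], "face")
# has box area ~0.98, so the detection is dropped and that frame gets no entry
# in the filtered dictionary.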
def describe_frames(
video_path, model, tokenizer, detect_keyword, test_mode=False, rows=1, cols=1
):
"""Extract and detect objects in frames."""
props = get_video_properties(video_path)
fps = props["fps"]
# If in test mode, only process first 3 seconds
if test_mode:
frame_count = min(int(fps * TEST_MODE_DURATION), props["frame_count"])
else:
frame_count = props["frame_count"]
ad_detections = {} # Store detection results by frame number
print("Extracting frames and detecting objects...")
video = cv2.VideoCapture(video_path)
# Process every frame
frame_count_processed = 0
with tqdm(total=frame_count) as pbar:
while frame_count_processed < frame_count:
ret, frame = video.read()
if not ret:
break
# Detect objects in the frame
detected_objects = detect_ads_in_frame(
model, tokenizer, frame, detect_keyword, rows=rows, cols=cols
)
# Store results for every frame, even if empty
ad_detections[frame_count_processed] = detected_objects
frame_count_processed += 1
pbar.update(1)
video.release()
if frame_count_processed == 0:
print("No frames could be read from video")
return {}
# Filter out only extremely large detections
ad_detections = filter_temporal_outliers(ad_detections)
return ad_detections
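# describe_frames returns {frame_index: [(box, keyword), ...]} with frames
# numbered in read order; create_detection_video below decodes the video again
# in the same order and looks detections up by that index, so both passes must
# read the same frames.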
def create_detection_video(
video_path,
ad_detections,
detect_keyword,
output_path=None,
ffmpeg_preset="medium",
test_mode=False,
box_style="censor",
):
"""Create video with detection boxes."""
if output_path is None:
# Create outputs directory if it doesn't exist
outputs_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "outputs"
)
os.makedirs(outputs_dir, exist_ok=True)
# Clean the detect_keyword for filename
safe_keyword = "".join(
x for x in detect_keyword if x.isalnum() or x in (" ", "_", "-")
)
safe_keyword = safe_keyword.replace(" ", "_")
# Create output filename
base_name = os.path.splitext(os.path.basename(video_path))[0]
output_path = os.path.join(
outputs_dir, f"{box_style}_{safe_keyword}_{base_name}.mp4"
)
print(f"Will save output to: {output_path}")
props = get_video_properties(video_path)
fps, width, height = props["fps"], props["width"], props["height"]
# If in test mode, only process first few seconds
if test_mode:
frame_count = min(int(fps * TEST_MODE_DURATION), props["frame_count"])
else:
frame_count = props["frame_count"]
video = cv2.VideoCapture(video_path)
# Create temp output path by adding _temp before the extension
base, ext = os.path.splitext(output_path)
temp_output = f"{base}_temp{ext}"
out = cv2.VideoWriter(
temp_output, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
)
print("Creating detection video...")
frame_count_processed = 0
with tqdm(total=frame_count) as pbar:
while frame_count_processed < frame_count:
ret, frame = video.read()
if not ret:
break
# Get detections for this exact frame
if frame_count_processed in ad_detections:
current_detections = ad_detections[frame_count_processed]
if current_detections:
frame = draw_ad_boxes(
frame, current_detections, detect_keyword, box_style=box_style
)
out.write(frame)
frame_count_processed += 1
pbar.update(1)
video.release()
out.release()
    # Re-encode the mp4v intermediate with FFmpeg to H.264 for web-compatible
    # playback. Note that the output has no audio track: cv2.VideoWriter writes
    # video only, and the original source file is not passed to FFmpeg.
try:
subprocess.run(
[
"ffmpeg",
"-y",
"-i",
temp_output,
"-c:v",
"libx264",
"-preset",
ffmpeg_preset,
"-crf",
"23",
"-movflags",
"+faststart", # Better web playback
"-loglevel",
"error",
output_path,
],
check=True,
)
os.remove(temp_output) # Remove the temporary file
if not os.path.exists(output_path):
print(
f"Warning: FFmpeg completed but output file not found at {output_path}"
)
return None
return output_path
except subprocess.CalledProcessError as e:
print(f"Error running FFmpeg: {str(e)}")
if os.path.exists(temp_output):
os.remove(temp_output)
return None
def process_video(
video_path,
detect_keyword,
test_mode=False,
ffmpeg_preset="medium",
rows=1,
cols=1,
box_style="censor",
):
"""Process a single video file."""
print(f"\nProcessing: {video_path}")
print(f"Looking for: {detect_keyword}")
# Load model
print("Loading Moondream model...")
model, tokenizer = load_moondream()
# Process video - detect objects
ad_detections = describe_frames(
video_path, model, tokenizer, detect_keyword, test_mode, rows, cols
)
# Create video with detection boxes
output_path = create_detection_video(
video_path,
ad_detections,
detect_keyword,
ffmpeg_preset=ffmpeg_preset,
test_mode=test_mode,
box_style=box_style,
)
if output_path is None:
print("\nError: Failed to create output video")
return None
print(f"\nOutput saved to: {output_path}")
return output_path
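# Example invocation (assumed typical usage; reads every video in ./inputs and
# writes results to ./outputs):
#   python main.py --detect "face" --box-style censor --preset fast --test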
def main():
"""Process all videos in the inputs directory."""
parser = argparse.ArgumentParser(
description="Detect objects in videos using Moondream2"
)
parser.add_argument(
"--test", action="store_true", help="Process only first 3 seconds of each video"
)
parser.add_argument(
"--preset",
choices=FFMPEG_PRESETS,
default="medium",
help="FFmpeg encoding preset (default: medium). Faster presets = lower quality",
)
parser.add_argument(
"--detect",
type=str,
default="face",
        help='Object to detect in the video (default: "face"); override with --detect "thing to detect"',
)
parser.add_argument(
"--rows",
type=int,
default=1,
help="Number of rows to split each frame into (default: 1)",
)
parser.add_argument(
"--cols",
type=int,
default=1,
help="Number of columns to split each frame into (default: 1)",
)
parser.add_argument(
"--box-style",
choices=["censor", "bounding-box", "hitmarker"],
default="censor",
help="Style of detection visualization (default: censor)",
)
args = parser.parse_args()
input_dir = "inputs"
os.makedirs(input_dir, exist_ok=True)
os.makedirs("outputs", exist_ok=True)
video_files = [
f
for f in os.listdir(input_dir)
if f.lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm"))
]
if not video_files:
print("No video files found in 'inputs' directory")
return
print(f"Found {len(video_files)} videos to process")
print(f"Will detect: {args.detect}")
if args.test:
print("Running in test mode - processing only first 3 seconds of each video")
print(f"Using FFmpeg preset: {args.preset}")
print(f"Grid size: {args.rows}x{args.cols}")
print(f"Box style: {args.box_style}")
success_count = 0
for video_file in video_files:
video_path = os.path.join(input_dir, video_file)
output_path = process_video(
video_path,
args.detect,
test_mode=args.test,
ffmpeg_preset=args.preset,
rows=args.rows,
cols=args.cols,
box_style=args.box_style,
)
if output_path:
success_count += 1
print(
f"\nProcessing complete. Successfully processed {success_count} out of {len(video_files)} videos."
)
if __name__ == "__main__":
main()