ParsaKhaz committed
Commit 2db60a3 · verified · 1 Parent(s): 1162936

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/cig.mp4 filter=lfs diff=lfs merge=lfs -text
+ examples/clip-cig.gif filter=lfs diff=lfs merge=lfs -text
+ examples/clip-conflag.gif filter=lfs diff=lfs merge=lfs -text
+ examples/clip-gu.gif filter=lfs diff=lfs merge=lfs -text
+ examples/conf.mp4 filter=lfs diff=lfs merge=lfs -text
+ examples/gun.mp4 filter=lfs diff=lfs merge=lfs -text
+ examples/homealone.mp4 filter=lfs diff=lfs merge=lfs -text
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
+ name: Run Python script
+
+ on:
+   push:
+     branches:
+       - main
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v2
+
+       - name: Set up Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: '3.9'
+
+       - name: Install Gradio
+         run: python -m pip install gradio
+
+       - name: Log in to Hugging Face
+         run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+       - name: Deploy to Spaces
+         run: gradio deploy
.gitignore ADDED
@@ -0,0 +1,52 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ *.dll
+
+ # Virtual Environment
+ venv/
+ env/
+ ENV/
+ .venv/
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+
+ # Project specific
+ inputs/*
+ outputs/*
+ !inputs/.gitkeep
+ !outputs/.gitkeep
+ inputs/
+ outputs/
+
+ # Model files
+ *.pth
+ *.onnx
+ *.pt
+
+ # Logs
+ *.log
+
+ certificate.pem
README.md CHANGED
@@ -1,12 +1,165 @@
- ---
- title: Promptable Content Moderation
- emoji:
- colorFrom: purple
- colorTo: red
- sdk: gradio
- sdk_version: 5.16.2
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: promptable-content-moderation
+ app_file: app.py
+ sdk: gradio
+ sdk_version: 5.16.1
+ ---
+ # Promptable Content Moderation with Moondream
+
+ Welcome to the future of content moderation with Moondream 2B, a powerful and lightweight vision-language model that enables detection and moderation of video content using natural language prompts.
+
+ [Try it now.](https://huggingface.co/spaces/moondream/content-moderation)
+
+ ## Features
+
+ - Content moderation through natural language prompts
+ - Multiple visualization styles
+ - Intelligent scene detection and tracking:
+   - DeepSORT tracking with scene-aware reset
+   - Persistent moderation across frames
+   - Smart tracker reset at scene boundaries
+ - Optional grid-based detection for improved accuracy on complex scenes
+ - Frame-by-frame processing with IoU-based merging
+ - Web-compatible output format
+ - Test mode (process only the first X seconds)
+ - Advanced moderation analysis with multiple visualization plots
+
+ ## Examples
+
+ | Example Outputs |
+ |------|
+ | ![Demo](./examples/clip-cig.gif) |
+ | ![Demo](./examples/clip-gu.gif) |
+ | ![Demo](./examples/clip-conflag.gif) |
+
+ ## Requirements
+
+ ### Python Dependencies
+
+ For Windows users, before installing other requirements, first install PyTorch with CUDA support:
+
+ ```bash
+ pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https://download.pytorch.org/whl/cu121
+ ```
+
+ Then install the remaining dependencies:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### System Requirements
+
+ - FFmpeg (required for video processing)
+ - libvips (required for image processing)
+
+ Installation by platform:
+
+ - Ubuntu/Debian: `sudo apt-get install ffmpeg libvips`
+ - macOS: `brew install ffmpeg libvips`
+ - Windows:
+   - Download FFmpeg from [ffmpeg.org](https://ffmpeg.org/download.html)
+   - Follow the [libvips Windows installation guide](https://docs.moondream.ai/quick-start)
+
+ ## Installation
+
+ 1. Clone this repository and create a new virtual environment:
+
+ ```bash
+ git clone https://github.com/vikhyat/moondream/blob/main/recipes/promptable-video-redaction
+ python -m venv .venv
+ source .venv/bin/activate # On Windows: .venv\Scripts\activate
+ ```
+
+ 2. Install Python dependencies:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 3. Install ffmpeg and libvips:
+    - On Ubuntu/Debian: `sudo apt-get install ffmpeg libvips`
+    - On macOS: `brew install ffmpeg libvips`
+    - On Windows: Download from [ffmpeg.org](https://ffmpeg.org/download.html)
+
+ > Downloading libvips for Windows requires some additional steps; see [here](https://docs.moondream.ai/quick-start)
+
+ ## Usage
+
+ The easiest way to use this tool is through its web interface, which provides a user-friendly experience for video content moderation.
+
+ ### Web Interface
+
+ 1. Start the web interface:
+
+ ```bash
+ python app.py
+ ```
+
+ 2. Open the provided URL in your browser (typically <http://localhost:7860>)
+
+ 3. Use the interface to:
+    - Upload your video file
+    - Specify content to moderate (e.g., "face", "cigarette", "gun")
+    - Choose redaction style (default: obfuscated-pixel)
+    - OPTIONAL: Configure advanced settings
+      - Processing speed/quality
+      - Grid size for detection
+      - Test mode for quick validation (default: on, 3 seconds)
+    - Process the video and download results
+    - Analyze detection patterns with visualization tools
+
+ ## Output Files
+
+ The tool generates two types of output files in the `outputs` directory:
+
+ 1. Processed Videos:
+    - Format: `[style]_[content_type]_[original_filename].mp4`
+    - Example: `censor_inappropriate_video.mp4`
+
+ 2. Detection Data:
+    - Format: `[style]_[content_type]_[original_filename]_detections.json`
+    - Contains frame-by-frame detection information
+    - Used for visualization and analysis
+
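As a rough illustration of the naming convention above, a minimal sketch; the `style`, `content_type`, and `input_path` values here are made up, and the real paths are assembled inside `app.py`:

```python
import os

# Hypothetical example values; in practice these come from the web interface.
style = "censor"
content_type = "cigarette"
input_path = "inputs/input_video.mp4"

base = os.path.splitext(os.path.basename(input_path))[0]
processed_video = f"outputs/{style}_{content_type}_{base}.mp4"
detections_json = f"outputs/{style}_{content_type}_{base}_detections.json"

print(processed_video)   # outputs/censor_cigarette_input_video.mp4
print(detections_json)   # outputs/censor_cigarette_input_video_detections.json
```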
+ ## Technical Details
+
+ ### Scene Detection and Tracking
+
+ The tool uses advanced scene detection and object tracking:
+
+ 1. Scene Detection:
+    - Powered by PySceneDetect's ContentDetector
+    - Automatically identifies scene changes in videos
+    - Configurable detection threshold (default: 30.0)
+    - Helps maintain tracking accuracy across scene boundaries
+
+ 2. Object Tracking:
+    - DeepSORT tracking for consistent object identification
+    - Automatic tracker reset at scene changes
+    - Maintains object identity within scenes
+    - Prevents tracking errors across scene boundaries
+
+ 3. Integration Benefits:
+    - More accurate object tracking
+    - Better handling of scene transitions
+    - Reduced false positives in tracking
+    - Improved tracking consistency
+
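A minimal sketch of how these two pieces fit together, using PySceneDetect's `detect`/`ContentDetector` API (as imported in `main.py`) and the `DeepSORTTracker` wrapper added in this commit; `frames_with_detections` is a hypothetical iterable of `(frame, detections)` pairs, and the threshold mirrors the default noted above:

```python
from scenedetect import detect, ContentDetector
from deep_sort_integration import DeepSORTTracker

def scene_start_frames(video_path, threshold=30.0):
    """Frame numbers at which a new scene begins."""
    scenes = detect(video_path, ContentDetector(threshold=threshold))
    return {start.get_frames() for start, _ in scenes}

def track_with_scene_resets(video_path, frames_with_detections):
    """Reset DeepSORT at every scene boundary so track IDs never leak across cuts."""
    tracker = DeepSORTTracker()
    boundaries = scene_start_frames(video_path)
    for frame_idx, (frame, detections) in enumerate(frames_with_detections):
        if frame_idx in boundaries:
            tracker.reset()
        yield tracker.update(frame, detections)
```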
+ ## Best Practices
+
+ - Use test mode for initial configuration
+ - Enable grid-based detection for complex scenes
+ - Choose an appropriate redaction style based on content type:
+   - Censor: Complete content blocking
+   - Blur styles: Less intrusive moderation
+   - Bounding Box: Content review and analysis
+ - Monitor system resources during processing
+ - Use appropriate processing quality settings based on your needs
+
+ ## Notes
+
+ - Processing time depends on video length, resolution, GPU availability, and chosen settings
+ - A GPU is strongly recommended for faster processing
+ - Grid-based detection increases accuracy but requires more processing time (each grid cell is processed independently)
+ - Test mode processes only the first X seconds (default: 3 seconds) for quick validation
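The processing-time note above can be made concrete with a small estimate helper; the ~0.12 s per detection figure is the one quoted in the app's help text (measured on an RTX 4090), so treat the result as a rough guide only:

```python
def estimate_processing_seconds(fps, duration_s, grid_rows=1, grid_cols=1,
                                seconds_per_detection=0.12):
    """Rough estimate: one Moondream detect() call per grid cell per frame."""
    frames = fps * duration_s
    return frames * grid_rows * grid_cols * seconds_per_detection

# Example from the app's help text: 3 seconds at 30 fps with a 2x2 grid
print(estimate_processing_seconds(30, 3, 2, 2))  # 43.2
```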
app.py ADDED
@@ -0,0 +1,533 @@
1
+ #!/usr/bin/env python3
2
+ import gradio as gr
3
+ import os
4
+ from main import load_moondream, process_video, load_sam_model
5
+ import shutil
6
+ import torch
7
+ from visualization import visualize_detections
8
+ from persistence import load_detection_data
9
+ import matplotlib.pyplot as plt
10
+ import io
11
+ from PIL import Image
12
+ import pandas as pd
13
+ from video_visualization import create_video_visualization
14
+
15
+ # import spaces
16
+ import spaces
17
+ # Get absolute path to workspace root
18
+ WORKSPACE_ROOT = os.path.dirname(os.path.abspath(__file__))
19
+
20
+ # Check CUDA availability
21
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
22
+ # We want to get True
23
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
24
+ # GPU Name
25
+
26
+ # Initialize Moondream model globally for reuse (will be loaded on first use)
27
+ model, tokenizer = None, None
28
+
29
+ # Uncomment for Hugging Face Spaces
30
+ @spaces.GPU(duration=120)
31
+ def process_video_file(
32
+ video_file, target_object, box_style, ffmpeg_preset, grid_rows, grid_cols, test_mode, test_duration
33
+ ):
34
+ """Process a video file through the Gradio interface."""
35
+ try:
36
+ if not video_file:
37
+ raise gr.Error("Please upload a video file")
38
+
39
+ # Load models if not already loaded
40
+ global model, tokenizer
41
+ if model is None or tokenizer is None:
42
+ model, tokenizer = load_moondream()
43
+
44
+ # Ensure input/output directories exist using absolute paths
45
+ inputs_dir = os.path.join(WORKSPACE_ROOT, "inputs")
46
+ outputs_dir = os.path.join(WORKSPACE_ROOT, "outputs")
47
+ os.makedirs(inputs_dir, exist_ok=True)
48
+ os.makedirs(outputs_dir, exist_ok=True)
49
+
50
+ # Copy uploaded video to inputs directory
51
+ video_filename = f"input_{os.path.basename(video_file)}"
52
+ input_video_path = os.path.join(inputs_dir, video_filename)
53
+ shutil.copy2(video_file, input_video_path)
54
+
55
+ try:
56
+ # Process the video
57
+ output_path = process_video(
58
+ input_video_path,
59
+ target_object,
60
+ test_mode=test_mode,
61
+ test_duration=test_duration,
62
+ ffmpeg_preset=ffmpeg_preset,
63
+ grid_rows=grid_rows,
64
+ grid_cols=grid_cols,
65
+ box_style=box_style,
66
+ )
67
+
68
+ # Get the corresponding JSON path
69
+ base_name = os.path.splitext(os.path.basename(video_filename))[0]
70
+ json_path = os.path.join(outputs_dir, f"{box_style}_{target_object}_{base_name}_detections.json")
71
+
72
+ # Verify output exists and is readable
73
+ if not output_path or not os.path.exists(output_path):
74
+ print(f"Warning: Output path {output_path} does not exist")
75
+ # Try to find the output based on expected naming convention
76
+ expected_output = os.path.join(
77
+ outputs_dir, f"{box_style}_{target_object}_{video_filename}"
78
+ )
79
+ if os.path.exists(expected_output):
80
+ output_path = expected_output
81
+ else:
82
+ # Try searching in outputs directory for any matching file
83
+ matching_files = [
84
+ f
85
+ for f in os.listdir(outputs_dir)
86
+ if f.startswith(f"{box_style}_{target_object}_")
87
+ ]
88
+ if matching_files:
89
+ output_path = os.path.join(outputs_dir, matching_files[0])
90
+ else:
91
+ raise gr.Error("Failed to locate output video")
92
+
93
+ # Convert output path to absolute path if it isn't already
94
+ if not os.path.isabs(output_path):
95
+ output_path = os.path.join(WORKSPACE_ROOT, output_path)
96
+
97
+ print(f"Returning output path: {output_path}")
98
+ return output_path, json_path
99
+
100
+ finally:
101
+ # Clean up input file
102
+ try:
103
+ if os.path.exists(input_video_path):
104
+ os.remove(input_video_path)
105
+ except Exception:
106
+ pass
107
+
108
+ except Exception as e:
109
+ print(f"Error in process_video_file: {str(e)}")
110
+ raise gr.Error(f"Error processing video: {str(e)}")
111
+
112
+ def create_visualization_plots(json_path):
113
+ """Create visualization plots and return them as images."""
114
+ try:
115
+ # Load the data
116
+ data = load_detection_data(json_path)
117
+ if not data:
118
+ return None, None, None, None, None, None, None, None, "No data found"
119
+
120
+ # Convert to DataFrame
121
+ rows = []
122
+ for frame_data in data["frame_detections"]:
123
+ frame = frame_data["frame"]
124
+ timestamp = frame_data["timestamp"]
125
+ for obj in frame_data["objects"]:
126
+ rows.append({
127
+ "frame": frame,
128
+ "timestamp": timestamp,
129
+ "keyword": obj["keyword"],
130
+ "x1": obj["bbox"][0],
131
+ "y1": obj["bbox"][1],
132
+ "x2": obj["bbox"][2],
133
+ "y2": obj["bbox"][3],
134
+ "area": (obj["bbox"][2] - obj["bbox"][0]) * (obj["bbox"][3] - obj["bbox"][1]),
135
+ "center_x": (obj["bbox"][0] + obj["bbox"][2]) / 2,
136
+ "center_y": (obj["bbox"][1] + obj["bbox"][3]) / 2
137
+ })
138
+
139
+ if not rows:
140
+ return None, None, None, None, None, None, None, None, "No detections found in the data"
141
+
142
+ df = pd.DataFrame(rows)
143
+ plots = []
144
+
145
+ # Create each plot and convert to image
146
+ for plot_num in range(8): # Increased to 8 plots
147
+ plt.figure(figsize=(8, 6))
148
+
149
+ if plot_num == 0:
150
+ # Plot 1: Number of detections per frame (Original)
151
+ detections_per_frame = df.groupby("frame").size()
152
+ plt.plot(detections_per_frame.index, detections_per_frame.values)
153
+ plt.xlabel("Frame")
154
+ plt.ylabel("Number of Detections")
155
+ plt.title("Detections Per Frame")
156
+
157
+ elif plot_num == 1:
158
+ # Plot 2: Distribution of detection areas (Original)
159
+ df["area"].hist(bins=30)
160
+ plt.xlabel("Detection Area (normalized)")
161
+ plt.ylabel("Count")
162
+ plt.title("Distribution of Detection Areas")
163
+
164
+ elif plot_num == 2:
165
+ # Plot 3: Average detection area over time (Original)
166
+ avg_area = df.groupby("frame")["area"].mean()
167
+ plt.plot(avg_area.index, avg_area.values)
168
+ plt.xlabel("Frame")
169
+ plt.ylabel("Average Detection Area")
170
+ plt.title("Average Detection Area Over Time")
171
+
172
+ elif plot_num == 3:
173
+ # Plot 4: Heatmap of detection centers (Original)
174
+ plt.hist2d(df["center_x"], df["center_y"], bins=30)
175
+ plt.colorbar()
176
+ plt.xlabel("X Position")
177
+ plt.ylabel("Y Position")
178
+ plt.title("Detection Center Heatmap")
179
+
180
+ elif plot_num == 4:
181
+ # Plot 5: Time-based Detection Density
182
+ # Shows when in the video most detections occur
183
+ df["time_bucket"] = pd.qcut(df["timestamp"], q=20, labels=False)
184
+ time_density = df.groupby("time_bucket").size()
185
+ plt.bar(time_density.index, time_density.values)
186
+ plt.xlabel("Video Timeline (20 segments)")
187
+ plt.ylabel("Number of Detections")
188
+ plt.title("Detection Density Over Video Duration")
189
+
190
+ elif plot_num == 5:
191
+ # Plot 6: Screen Region Analysis
192
+ # Divide screen into 3x3 grid and show detection counts
193
+ try:
194
+ df["grid_x"] = pd.qcut(df["center_x"], q=3, labels=["Left", "Center", "Right"], duplicates='drop')
195
+ df["grid_y"] = pd.qcut(df["center_y"], q=3, labels=["Top", "Middle", "Bottom"], duplicates='drop')
196
+ region_counts = df.groupby(["grid_y", "grid_x"]).size().unstack(fill_value=0)
197
+ plt.imshow(region_counts, cmap="YlOrRd")
198
+ plt.colorbar(label="Detection Count")
199
+ for i in range(3):
200
+ for j in range(3):
201
+ plt.text(j, i, region_counts.iloc[i, j], ha="center", va="center")
202
+ plt.xticks(range(3), ["Left", "Center", "Right"])
203
+ plt.yticks(range(3), ["Top", "Middle", "Bottom"])
204
+ plt.title("Screen Region Analysis")
205
+ except Exception as e:
206
+ plt.text(0.5, 0.5, "Insufficient variation in detection positions",
207
+ ha='center', va='center')
208
+ plt.title("Screen Region Analysis (Not Available)")
209
+
210
+ elif plot_num == 6:
211
+ # Plot 7: Detection Size Categories
212
+ # Categorize detections by size for content moderation
213
+ try:
214
+ size_labels = [
215
+ "Small (likely far/background)",
216
+ "Medium-small",
217
+ "Medium-large",
218
+ "Large (likely foreground/close)"
219
+ ]
220
+
221
+ # Handle cases with limited unique values
222
+ unique_areas = df["area"].nunique()
223
+ if unique_areas >= 4:
224
+ df["size_category"] = pd.qcut(df["area"], q=4, labels=size_labels, duplicates='drop')
225
+ else:
226
+ # Alternative binning for limited unique values
227
+ df["size_category"] = pd.cut(df["area"],
228
+ bins=unique_areas,
229
+ labels=size_labels[:unique_areas])
230
+
231
+ size_dist = df["size_category"].value_counts()
232
+ plt.pie(size_dist.values, labels=size_dist.index, autopct="%1.1f%%")
233
+ plt.title("Detection Size Distribution")
234
+ except Exception as e:
235
+ plt.text(0.5, 0.5, "Insufficient variation in detection sizes",
236
+ ha='center', va='center')
237
+ plt.title("Detection Size Distribution (Not Available)")
238
+
239
+ elif plot_num == 7:
240
+ # Plot 8: Temporal Pattern Analysis
241
+ # Show patterns of when detections occur in sequence
242
+ try:
243
+ detection_gaps = df.sort_values("frame")["frame"].diff()
244
+ if len(detection_gaps.dropna().unique()) > 1:
245
+ plt.hist(detection_gaps.dropna(), bins=min(30, len(detection_gaps.dropna().unique())),
246
+ edgecolor="black")
247
+ plt.xlabel("Frames Between Detections")
248
+ plt.ylabel("Frequency")
249
+ plt.title("Detection Temporal Pattern Analysis")
250
+ else:
251
+ plt.text(0.5, 0.5, "Uniform detection intervals", ha='center', va='center')
252
+ plt.title("Temporal Pattern Analysis (Uniform)")
253
+ except Exception as e:
254
+ plt.text(0.5, 0.5, "Insufficient temporal data", ha='center', va='center')
255
+ plt.title("Temporal Pattern Analysis (Not Available)")
256
+
257
+ # Save plot to bytes
258
+ buf = io.BytesIO()
259
+ plt.savefig(buf, format='png', bbox_inches='tight')
260
+ buf.seek(0)
261
+ plots.append(Image.open(buf))
262
+ plt.close()
263
+
264
+ # Enhanced summary text
265
+ summary = f"""Summary Statistics:
266
+ Total frames analyzed: {len(data['frame_detections'])}
267
+ Total detections: {len(df)}
268
+ Average detections per frame: {len(df) / len(data['frame_detections']):.2f}
269
+
270
+ Detection Patterns:
271
+ - Peak detection count: {df.groupby('frame').size().max()} (in a single frame)
272
+ - Most common screen region: {df.groupby(['grid_y', 'grid_x']).size().idxmax()}
273
+ - Average detection size: {df['area'].mean():.3f}
274
+ - Median frames between detections: {detection_gaps.median():.1f}
275
+
276
+ Video metadata:
277
+ """
278
+ for key, value in data["video_metadata"].items():
279
+ summary += f"{key}: {value}\n"
280
+
281
+ return plots[0], plots[1], plots[2], plots[3], plots[4], plots[5], plots[6], plots[7], summary
282
+
283
+ except Exception as e:
284
+ print(f"Error creating visualization: {str(e)}")
285
+ import traceback
286
+ traceback.print_exc()
287
+ return None, None, None, None, None, None, None, None, f"Error creating visualization: {str(e)}"
288
+
289
+ # Create the Gradio interface
290
+ with gr.Blocks(title="Promptable Content Moderation") as app:
291
+ with gr.Tabs():
292
+ with gr.Tab("Process Video"):
293
+ gr.Markdown("# Promptable Content Moderation with Moondream")
294
+ gr.Markdown(
295
+ """
296
+ Powered by [Moondream 2B](https://github.com/vikhyat/moondream).
297
+
298
+ Upload a video and specify what to moderate. The app will process each frame and moderate any visual content that matches the prompt. For help, join the [Moondream Discord](https://discord.com/invite/tRUdpjDQfH).
299
+ """
300
+ )
301
+
302
+ with gr.Row():
303
+ with gr.Column():
304
+ # Input components
305
+ video_input = gr.Video(label="Upload Video")
306
+
307
+ detect_input = gr.Textbox(
308
+ label="What to Moderate",
309
+ placeholder="e.g. face, cigarette, gun, etc.",
310
+ value="face",
311
+ info="Moondream can moderate anything that you can describe in natural language",
312
+ )
313
+
314
+ gr.Examples(
315
+ examples=[
316
+ ["examples/cig.mp4", "cigarette"],
317
+ ["examples/gun.mp4", "gun"],
318
+ ["examples/homealone.mp4", "face"],
319
+ ["examples/conf.mp4", "confederate flag"],
320
+ ],
321
+ inputs=[video_input, detect_input],
322
+ label="Try these examples",
323
+ )
324
+
325
+ process_btn = gr.Button("Process Video", variant="primary")
326
+
327
+ with gr.Accordion("Advanced Settings", open=False):
328
+ box_style_input = gr.Radio(
329
+ choices=["censor", "bounding-box", "hitmarker", "sam", "sam-fast", "fuzzy-blur", "pixelated-blur", "intense-pixelated-blur", "obfuscated-pixel"],
330
+ value="obfuscated-pixel",
331
+ label="Visualization Style",
332
+ info="Choose how to display moderations: censor (black boxes), bounding-box (red boxes with labels), hitmarker (COD-style markers), sam (precise segmentation), sam-fast (faster but less precise segmentation), fuzzy-blur (Gaussian blur), pixelated-blur (pixelated with blur), obfuscated-pixel (advanced pixelation with neighborhood averaging)",
333
+ )
334
+ preset_input = gr.Dropdown(
335
+ choices=[
336
+ "ultrafast",
337
+ "superfast",
338
+ "veryfast",
339
+ "faster",
340
+ "fast",
341
+ "medium",
342
+ "slow",
343
+ "slower",
344
+ "veryslow",
345
+ ],
346
+ value="medium",
347
+ label="Processing Speed (faster = lower quality)",
348
+ )
349
+ with gr.Row():
350
+ rows_input = gr.Slider(
351
+ minimum=1, maximum=4, value=1, step=1, label="Grid Rows"
352
+ )
353
+ cols_input = gr.Slider(
354
+ minimum=1, maximum=4, value=1, step=1, label="Grid Columns"
355
+ )
356
+
357
+ test_mode_input = gr.Checkbox(
358
+ label="Test Mode (Process first 3 seconds only)",
359
+ value=True,
360
+ info="Enable to quickly test settings on a short clip before processing the full video (recommended). Disable it if you plan to use the data visualizations, which need the full video.",
361
+ )
362
+
363
+ test_duration_input = gr.Slider(
364
+ minimum=1,
365
+ maximum=10,
366
+ value=3,
367
+ step=1,
368
+ label="Test Mode Duration (seconds)",
369
+ info="Number of seconds to process in test mode"
370
+ )
371
+
372
+ gr.Markdown(
373
+ """
374
+ Note: Test mode only processes the first few seconds of the video (3 by default, adjustable above) and is recommended for trying out settings.
375
+ """
376
+ )
377
+
378
+ gr.Markdown(
379
+ """
380
+ We can get a rough estimate of processing time by multiplying the video's framerate * duration in seconds * the number of grid rows * the number of grid columns, assuming about 0.12 seconds of processing time per detection.
381
+ For example, for a 3 second video at 30 fps with a 2x2 grid, the estimated time is 3 * 30 * 2 * 2 * 0.12 = 43.2 seconds (tested on a 4090 GPU).
382
+
383
+ Note: Using the SAM visualization style will increase processing time significantly as it performs additional segmentation for each detection. The sam-fast option uses a smaller model for faster processing at the cost of some accuracy.
384
+ """
385
+ )
386
+
387
+ with gr.Column():
388
+ # Output components
389
+ video_output = gr.Video(label="Processed Video")
390
+ json_output = gr.Text(label="Detection Data Path", visible=False)
391
+
392
+ # About section under the video output
393
+ gr.Markdown(
394
+ """
395
+ ### Links:
396
+ - [GitHub Repository](https://github.com/vikhyat/moondream)
397
+ - [Hugging Face](https://huggingface.co/vikhyatk/moondream2)
398
+ - [Quick Start](https://docs.moondream.ai/quick-start)
399
+ - [Moondream Recipes](https://docs.moondream.ai/recipes)
400
+ """
401
+ )
402
+
403
+ with gr.Tab("Analyze Results"):
404
+ gr.Markdown("# Detection Analysis")
405
+ gr.Markdown(
406
+ """
407
+ Analyze the detection results from processed videos. The analysis includes:
408
+ - Basic detection statistics and patterns
409
+ - Temporal and spatial distribution analysis
410
+ - Size-based categorization
411
+ - Screen region analysis
412
+ - Detection density patterns
413
+ """
414
+ )
415
+
416
+ with gr.Row():
417
+ json_input = gr.File(
418
+ label="Upload Detection Data (JSON)",
419
+ file_types=[".json"],
420
+ )
421
+ analyze_btn = gr.Button("Analyze", variant="primary")
422
+
423
+ with gr.Row():
424
+ with gr.Column():
425
+ plot1 = gr.Image(
426
+ label="Detections Per Frame",
427
+ )
428
+ plot2 = gr.Image(
429
+ label="Detection Areas Distribution",
430
+ )
431
+ plot5 = gr.Image(
432
+ label="Detection Density Timeline",
433
+ )
434
+ plot6 = gr.Image(
435
+ label="Screen Region Analysis",
436
+ )
437
+
438
+ with gr.Column():
439
+ plot3 = gr.Image(
440
+ label="Average Detection Area Over Time",
441
+ )
442
+ plot4 = gr.Image(
443
+ label="Detection Center Heatmap",
444
+ )
445
+ plot7 = gr.Image(
446
+ label="Detection Size Categories",
447
+ )
448
+ plot8 = gr.Image(
449
+ label="Temporal Pattern Analysis",
450
+ )
451
+
452
+ stats_output = gr.Textbox(
453
+ label="Statistics",
454
+ info="Summary of key metrics and patterns found in the detection data.",
455
+ lines=12,
456
+ max_lines=15,
457
+ interactive=False
458
+ )
459
+
460
+ # with gr.Tab("Video Visualizations"):
461
+ # gr.Markdown("# Real-time Detection Visualization")
462
+ # gr.Markdown(
463
+ # """
464
+ # Watch the detection patterns unfold in real-time. Choose from:
465
+ # - Timeline: Shows number of detections over time
466
+ # - Gauge: Simple yes/no indicator for current frame detections
467
+ # """
468
+ # )
469
+
470
+ # with gr.Row():
471
+ # json_input_realtime = gr.File(
472
+ # label="Upload Detection Data (JSON)",
473
+ # file_types=[".json"],
474
+ # )
475
+ # viz_style = gr.Radio(
476
+ # choices=["timeline", "gauge"],
477
+ # value="timeline",
478
+ # label="Visualization Style",
479
+ # info="Choose between timeline view or simple gauge indicator"
480
+ # )
481
+ # visualize_btn = gr.Button("Visualize", variant="primary")
482
+
483
+ # with gr.Row():
484
+ # video_visualization = gr.Video(
485
+ # label="Detection Visualization",
486
+ # interactive=False
487
+ # )
488
+ # stats_realtime = gr.Textbox(
489
+ # label="Video Statistics",
490
+ # lines=6,
491
+ # max_lines=8,
492
+ # interactive=False
493
+ # )
494
+
495
+ # Event handlers
496
+ process_outputs = process_btn.click(
497
+ fn=process_video_file,
498
+ inputs=[
499
+ video_input,
500
+ detect_input,
501
+ box_style_input,
502
+ preset_input,
503
+ rows_input,
504
+ cols_input,
505
+ test_mode_input,
506
+ test_duration_input,
507
+ ],
508
+ outputs=[video_output, json_output],
509
+ )
510
+
511
+ # Auto-analyze after processing
512
+ process_outputs.then(
513
+ fn=create_visualization_plots,
514
+ inputs=[json_output],
515
+ outputs=[plot1, plot2, plot3, plot4, plot5, plot6, plot7, plot8, stats_output],
516
+ )
517
+
518
+ # Manual analysis button
519
+ analyze_btn.click(
520
+ fn=create_visualization_plots,
521
+ inputs=[json_input],
522
+ outputs=[plot1, plot2, plot3, plot4, plot5, plot6, plot7, plot8, stats_output],
523
+ )
524
+
525
+ # Video visualization button
526
+ # visualize_btn.click(
527
+ # fn=lambda json_file, style: create_video_visualization(json_file.name if json_file else None, style),
528
+ # inputs=[json_input_realtime, viz_style],
529
+ # outputs=[video_visualization, stats_realtime],
530
+ # )
531
+
532
+ if __name__ == "__main__":
533
+ app.launch(share=True)
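For readers tracing `create_visualization_plots` above, the detection JSON it loads has roughly the following shape; the field names are taken from the code, the metadata keys mirror `get_video_properties()` in `main.py`, and the concrete values are invented:

```python
# Sketch of the *_detections.json structure consumed by create_visualization_plots().
example_data = {
    "video_metadata": {"fps": 30.0, "frame_count": 90, "width": 1280, "height": 720},
    "frame_detections": [
        {
            "frame": 0,
            "timestamp": 0.0,
            "objects": [
                # bbox is normalized [x1, y1, x2, y2]
                {"keyword": "cigarette", "bbox": [0.25, 0.30, 0.45, 0.60]},
            ],
        },
    ],
}
```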
deep_sort_integration.py ADDED
@@ -0,0 +1,73 @@
+ import numpy as np
+ import torch
+ from deep_sort_realtime.deepsort_tracker import DeepSort
+ from datetime import datetime
+
+ class DeepSORTTracker:
+     def __init__(self, max_age=5):
+         """Initialize DeepSORT tracker."""
+         self.max_age = max_age
+         self.tracker = self._create_tracker()
+
+     def _create_tracker(self):
+         """Create a new instance of DeepSort tracker."""
+         return DeepSort(
+             max_age=self.max_age,
+             embedder='mobilenet',  # Using default MobileNetV2 embedder
+             today=datetime.now().date()  # For track naming and daily ID reset
+         )
+
+     def reset(self):
+         """Reset the tracker state by creating a new instance."""
+         print("Resetting DeepSORT tracker...")
+         self.tracker = self._create_tracker()
+
+     def update(self, frame, detections):
+         """Update tracking with new detections.
+
+         Args:
+             frame: Current video frame (numpy array)
+             detections: List of (box, keyword) tuples where box is [x1, y1, x2, y2] normalized
+
+         Returns:
+             List of (box, keyword, track_id) tuples
+         """
+         if not detections:
+             return []
+
+         height, width = frame.shape[:2]
+
+         # Convert normalized coordinates to absolute and format detections
+         detection_list = []
+         for box, keyword in detections:
+             x1 = int(box[0] * width)
+             y1 = int(box[1] * height)
+             x2 = int(box[2] * width)
+             y2 = int(box[3] * height)
+             w = x2 - x1
+             h = y2 - y1
+
+             # Format: ([left,top,w,h], confidence, detection_class)
+             detection_list.append(([x1, y1, w, h], 1.0, keyword))
+
+         # Update tracker
+         tracks = self.tracker.update_tracks(detection_list, frame=frame)
+
+         # Convert back to normalized coordinates with track IDs
+         tracked_objects = []
+         for track in tracks:
+             if not track.is_confirmed():
+                 continue
+
+             ltrb = track.to_ltrb()  # Get [left,top,right,bottom] format
+             x1, y1, x2, y2 = ltrb
+
+             # Normalize coordinates
+             x1 = max(0.0, min(1.0, x1 / width))
+             y1 = max(0.0, min(1.0, y1 / height))
+             x2 = max(0.0, min(1.0, x2 / width))
+             y2 = max(0.0, min(1.0, y2 / height))
+
+             tracked_objects.append(([x1, y1, x2, y2], track.det_class, track.track_id))
+
+         return tracked_objects
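For reference, a minimal usage sketch of the tracker above; the frame and box values are placeholders:

```python
import numpy as np
from deep_sort_integration import DeepSORTTracker

tracker = DeepSORTTracker(max_age=5)
frame = np.zeros((720, 1280, 3), dtype=np.uint8)        # placeholder BGR frame
detections = [([0.25, 0.30, 0.45, 0.60], "cigarette")]  # normalized [x1, y1, x2, y2]

for box, keyword, track_id in tracker.update(frame, detections):
    print(track_id, keyword, box)

tracker.reset()  # call at a scene boundary to start fresh track IDs
```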
examples/cig.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d4a8bed3b243fca48b8d02ba86be87045d084d9ee8fd739a124b7c16a0e2200c
+ size 9771113
examples/clip-cig.gif ADDED

Git LFS Details

  • SHA256: 79f77ecfabe3b04abc1847535e3794a55f261c08e7cc9ea30cb86560bcbec191
  • Pointer size: 132 Bytes
  • Size of remote file: 7.24 MB
examples/clip-conflag.gif ADDED

Git LFS Details

  • SHA256: ccfcd45c83b7aae3888983b981cce83ffa45727d4a3431fdc4123e628e88e594
  • Pointer size: 133 Bytes
  • Size of remote file: 12.5 MB
examples/clip-gu.gif ADDED

Git LFS Details

  • SHA256: 71dfd0828ede5b196814c4a06714ae248f45a28df9783a59209713660d4b4e87
  • Pointer size: 132 Bytes
  • Size of remote file: 6.8 MB
examples/conf.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ad4ffb858565c3ed6d6a955c3ba802c76f5d9bb95a0054c92c74747b694b253b
+ size 20389258
examples/gun.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:80168314bef9499a80c563980429e3f674a4450333f34d6315b188eb16f8f85b
+ size 7369081
examples/homealone.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec8d4410d2b3937f57b40d6084e4f4fd8538b766a69229feb1320891d3ee78e5
+ size 11023032
main.py ADDED
@@ -0,0 +1,1236 @@
1
+ #!/usr/bin/env python3
2
+ import cv2, os, subprocess, argparse
3
+ from PIL import Image
4
+ import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, SamModel, SamProcessor
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ from datetime import datetime
9
+ from deep_sort_integration import DeepSORTTracker
10
+ from scenedetect import detect, ContentDetector
11
+ from functools import lru_cache
12
+
13
+ # Constants
14
+ DEFAULT_TEST_MODE_DURATION = 3 # Process only first 3 seconds in test mode by default
15
+ FFMPEG_PRESETS = [
16
+ "ultrafast",
17
+ "superfast",
18
+ "veryfast",
19
+ "faster",
20
+ "fast",
21
+ "medium",
22
+ "slow",
23
+ "slower",
24
+ "veryslow",
25
+ ]
26
+ FONT = cv2.FONT_HERSHEY_SIMPLEX # Font for bounding-box-style labels
27
+
28
+ # Detection parameters
29
+ IOU_THRESHOLD = 0.5 # IoU threshold for considering boxes related
30
+
31
+ # Hitmarker parameters
32
+ HITMARKER_SIZE = 20 # Size of the hitmarker in pixels
33
+ HITMARKER_GAP = 3 # Size of the empty space in the middle (reduced from 8)
34
+ HITMARKER_THICKNESS = 2 # Thickness of hitmarker lines
35
+ HITMARKER_COLOR = (255, 255, 255) # White color for hitmarker
36
+ HITMARKER_SHADOW_COLOR = (80, 80, 80) # Lighter gray for shadow effect
37
+ HITMARKER_SHADOW_OFFSET = 1 # Smaller shadow offset
38
+
39
+ # SAM parameters
40
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+
42
+ # Initialize model variables as None
43
+ sam_model = None
44
+ sam_processor = None
45
+ slimsam_model = None
46
+ slimsam_processor = None
47
+
48
+ @lru_cache(maxsize=2) # Cache both regular and slim SAM models
49
+ def get_sam_model(slim=False):
50
+ """Get cached SAM model and processor."""
51
+ global sam_model, sam_processor, slimsam_model, slimsam_processor
52
+
53
+ if slim:
54
+ if slimsam_model is None:
55
+ print("Loading SlimSAM model for the first time...")
56
+ slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to(device)
57
+ slimsam_processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")
58
+ return slimsam_model, slimsam_processor
59
+ else:
60
+ if sam_model is None:
61
+ print("Loading SAM model for the first time...")
62
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
63
+ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
64
+ return sam_model, sam_processor
65
+
66
+ def load_sam_model(slim=False):
67
+ """Load SAM model and processor with caching."""
68
+ return get_sam_model(slim=slim)
69
+
70
+ def generate_color_pair():
71
+ """Generate a generic light blue and dark blue color pair for SAM visualization."""
72
+ dark_rgb = [0, 0, 139] # Dark blue
73
+ light_rgb = [173, 216, 230] # Light blue
74
+ return dark_rgb, light_rgb
75
+
76
+ def create_mask_overlay(image, masks, points=None, labels=None):
77
+ """Create a mask overlay with contours for multiple SAM visualizations.
78
+
79
+ Args:
80
+ image: PIL Image to overlay masks on
81
+ masks: List of binary masks or single mask
82
+ points: Optional list of (x,y) points for labels
83
+ labels: Optional list of label strings for each point
84
+ """
85
+ # Convert single mask to list for uniform processing
86
+ if not isinstance(masks, list):
87
+ masks = [masks]
88
+
89
+ # Create empty overlays
90
+ overlay = np.zeros((*image.size[::-1], 4), dtype=np.uint8)
91
+ outline = np.zeros((*image.size[::-1], 4), dtype=np.uint8)
92
+
93
+ # Process each mask
94
+ for i, mask in enumerate(masks):
95
+ # Convert binary mask to uint8
96
+ mask_uint8 = (mask > 0).astype(np.uint8)
97
+
98
+ # Dilation to fill gaps
99
+ kernel = np.ones((5, 5), np.uint8)
100
+ mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=1)
101
+
102
+ # Find contours of the dilated mask
103
+ contours, _ = cv2.findContours(mask_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
104
+
105
+ # Generate random color pair for this segmentation
106
+ dark_color, light_color = generate_color_pair()
107
+
108
+ # Add to the overlays
109
+ overlay[mask_dilated > 0] = [*light_color, 90] # Light color with 35% opacity
110
+ cv2.drawContours(outline, contours, -1, (*dark_color, 255), 2) # Dark color outline
111
+
112
+ # Convert to PIL images
113
+ mask_overlay = Image.fromarray(overlay, 'RGBA')
114
+ outline_overlay = Image.fromarray(outline, 'RGBA')
115
+
116
+ # Composite the layers
117
+ result = image.convert('RGBA')
118
+ result.paste(mask_overlay, (0, 0), mask_overlay)
119
+ result.paste(outline_overlay, (0, 0), outline_overlay)
120
+
121
+ # Add labels if provided
122
+ if points and labels:
123
+ result_array = np.array(result)
124
+ for (x, y), label in zip(points, labels):
125
+ label_size = cv2.getTextSize(label, FONT, 0.5, 1)[0]
126
+ cv2.putText(
127
+ result_array,
128
+ label,
129
+ (int(x - label_size[0] // 2), int(y - 20)),
130
+ FONT,
131
+ 0.5,
132
+ (255, 255, 255),
133
+ 1,
134
+ cv2.LINE_AA,
135
+ )
136
+ result = Image.fromarray(result_array)
137
+
138
+ return result
139
+
140
+ def process_sam_detection(image, center_x, center_y, slim=False):
141
+ """Process a single detection point with SAM.
142
+
143
+ Returns:
144
+ tuple: (mask, result_pil) where mask is the binary mask and result_pil is the visualization
145
+ """
146
+ if not isinstance(image, Image.Image):
147
+ image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
148
+
149
+ # Get appropriate model from cache
150
+ model, processor = get_sam_model(slim)
151
+
152
+ # Process the image with SAM
153
+ inputs = processor(
154
+ image,
155
+ input_points=[[[center_x, center_y]]],
156
+ return_tensors="pt"
157
+ ).to(device)
158
+
159
+ with torch.no_grad():
160
+ outputs = model(**inputs)
161
+
162
+ mask = processor.post_process_masks(
163
+ outputs.pred_masks.cpu(),
164
+ inputs["original_sizes"].cpu(),
165
+ inputs["reshaped_input_sizes"].cpu()
166
+ )[0][0][0].numpy()
167
+
168
+ # Create the visualization
169
+ result = create_mask_overlay(image, mask)
170
+ return mask, result
171
+
172
+ def load_moondream():
173
+ """Load Moondream model and tokenizer."""
174
+ model = AutoModelForCausalLM.from_pretrained(
175
+ "vikhyatk/moondream2", trust_remote_code=True, device_map={"": "cuda"}
176
+ )
177
+ tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
178
+ return model, tokenizer
179
+
180
+
181
+ def get_video_properties(video_path):
182
+ """Get basic video properties."""
183
+ video = cv2.VideoCapture(video_path)
184
+ fps = video.get(cv2.CAP_PROP_FPS)
185
+ frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
186
+ width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
187
+ height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
188
+ video.release()
189
+ return {"fps": fps, "frame_count": frame_count, "width": width, "height": height}
190
+
191
+
192
+ def is_valid_bounding_box(bounding_box):
193
+ """Check if bounding box coordinates are reasonable."""
194
+ x1, y1, x2, y2 = bounding_box
195
+ width = x2 - x1
196
+ height = y2 - y1
197
+
198
+ # Reject boxes that are too large (over 90% of frame in both dimensions)
199
+ if width > 0.9 and height > 0.9:
200
+ return False
201
+
202
+ # Reject boxes that are too small (less than 1% of frame)
203
+ if width < 0.01 or height < 0.01:
204
+ return False
205
+
206
+ return True
207
+
208
+
209
+ def split_frame_into_grid(frame, grid_rows, grid_cols):
210
+ """Split a frame into a grid of tiles."""
211
+ height, width = frame.shape[:2]
212
+ tile_height = height // grid_rows
213
+ tile_width = width // grid_cols
214
+ tiles = []
215
+ tile_positions = []
216
+
217
+ for i in range(grid_rows):
218
+ for j in range(grid_cols):
219
+ y1 = i * tile_height
220
+ y2 = (i + 1) * tile_height if i < grid_rows - 1 else height
221
+ x1 = j * tile_width
222
+ x2 = (j + 1) * tile_width if j < grid_cols - 1 else width
223
+
224
+ tile = frame[y1:y2, x1:x2]
225
+ tiles.append(tile)
226
+ tile_positions.append((x1, y1, x2, y2))
227
+
228
+ return tiles, tile_positions
229
+
230
+
231
+ def convert_tile_coords_to_frame(box, tile_pos, frame_shape):
232
+ """Convert coordinates from tile space to frame space."""
233
+ frame_height, frame_width = frame_shape[:2]
234
+ tile_x1, tile_y1, tile_x2, tile_y2 = tile_pos
235
+ tile_width = tile_x2 - tile_x1
236
+ tile_height = tile_y2 - tile_y1
237
+
238
+ x1_tile_abs = box[0] * tile_width
239
+ y1_tile_abs = box[1] * tile_height
240
+ x2_tile_abs = box[2] * tile_width
241
+ y2_tile_abs = box[3] * tile_height
242
+
243
+ x1_frame_abs = tile_x1 + x1_tile_abs
244
+ y1_frame_abs = tile_y1 + y1_tile_abs
245
+ x2_frame_abs = tile_x1 + x2_tile_abs
246
+ y2_frame_abs = tile_y1 + y2_tile_abs
247
+
248
+ x1_norm = x1_frame_abs / frame_width
249
+ y1_norm = y1_frame_abs / frame_height
250
+ x2_norm = x2_frame_abs / frame_width
251
+ y2_norm = y2_frame_abs / frame_height
252
+
253
+ x1_norm = max(0.0, min(1.0, x1_norm))
254
+ y1_norm = max(0.0, min(1.0, y1_norm))
255
+ x2_norm = max(0.0, min(1.0, x2_norm))
256
+ y2_norm = max(0.0, min(1.0, y2_norm))
257
+
258
+ return [x1_norm, y1_norm, x2_norm, y2_norm]
259
+
260
+
261
+ def merge_tile_detections(tile_detections, iou_threshold=0.5):
262
+ """Merge detections from different tiles using NMS-like approach."""
263
+ if not tile_detections:
264
+ return []
265
+
266
+ all_boxes = []
267
+ all_keywords = []
268
+
269
+ # Collect all boxes and their keywords
270
+ for detections in tile_detections:
271
+ for box, keyword in detections:
272
+ all_boxes.append(box)
273
+ all_keywords.append(keyword)
274
+
275
+ if not all_boxes:
276
+ return []
277
+
278
+ # Convert to numpy for easier processing
279
+ boxes = np.array(all_boxes)
280
+
281
+ # Calculate areas
282
+ x1 = boxes[:, 0]
283
+ y1 = boxes[:, 1]
284
+ x2 = boxes[:, 2]
285
+ y2 = boxes[:, 3]
286
+ areas = (x2 - x1) * (y2 - y1)
287
+
288
+ # Sort boxes by area
289
+ order = areas.argsort()[::-1]
290
+
291
+ keep = []
292
+ while order.size > 0:
293
+ i = order[0]
294
+ keep.append(i)
295
+
296
+ if order.size == 1:
297
+ break
298
+
299
+ # Calculate IoU with rest of boxes
300
+ xx1 = np.maximum(x1[i], x1[order[1:]])
301
+ yy1 = np.maximum(y1[i], y1[order[1:]])
302
+ xx2 = np.minimum(x2[i], x2[order[1:]])
303
+ yy2 = np.minimum(y2[i], y2[order[1:]])
304
+
305
+ w = np.maximum(0.0, xx2 - xx1)
306
+ h = np.maximum(0.0, yy2 - yy1)
307
+ inter = w * h
308
+
309
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
310
+
311
+ # Get indices of boxes with IoU less than threshold
312
+ inds = np.where(ovr <= iou_threshold)[0]
313
+ order = order[inds + 1]
314
+
315
+ return [(all_boxes[i], all_keywords[i]) for i in keep]
316
+
317
+
318
+ def detect_objects_in_frame(model, tokenizer, image, target_object, grid_rows=1, grid_cols=1):
319
+ """Detect specified objects in a frame using grid-based analysis."""
320
+ if grid_rows == 1 and grid_cols == 1:
321
+ return detect_objects_in_frame_single(model, tokenizer, image, target_object)
322
+
323
+ # Convert numpy array to PIL Image if needed
324
+ if not isinstance(image, Image.Image):
325
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
326
+
327
+ # Split frame into tiles
328
+ tiles, tile_positions = split_frame_into_grid(image, grid_rows, grid_cols)
329
+
330
+ # Process each tile
331
+ tile_detections = []
332
+ for tile, tile_pos in zip(tiles, tile_positions):
333
+ # Convert tile to PIL Image
334
+ tile_pil = Image.fromarray(tile)
335
+
336
+ # Detect objects in tile
337
+ response = model.detect(tile_pil, target_object)
338
+
339
+ if response and "objects" in response and response["objects"]:
340
+ objects = response["objects"]
341
+ tile_objects = []
342
+
343
+ for obj in objects:
344
+ if all(k in obj for k in ["x_min", "y_min", "x_max", "y_max"]):
345
+ box = [obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]]
346
+
347
+ if is_valid_bounding_box(box):
348
+ # Convert tile coordinates to frame coordinates
349
+ frame_box = convert_tile_coords_to_frame(
350
+ box, tile_pos, image.shape
351
+ )
352
+ tile_objects.append((frame_box, target_object))
353
+
354
+ if tile_objects: # Only append if we found valid objects
355
+ tile_detections.append(tile_objects)
356
+
357
+ # Merge detections from all tiles
358
+ merged_detections = merge_tile_detections(tile_detections)
359
+ return merged_detections
360
+
361
+
362
+ def detect_objects_in_frame_single(model, tokenizer, image, target_object):
363
+ """Single-frame detection function."""
364
+ detected_objects = []
365
+
366
+ # Convert numpy array to PIL Image if needed
367
+ if not isinstance(image, Image.Image):
368
+ image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
369
+
370
+ # Detect objects
371
+ response = model.detect(image, target_object)
372
+
373
+ # Check if we have valid objects
374
+ if response and "objects" in response and response["objects"]:
375
+ objects = response["objects"]
376
+
377
+ for obj in objects:
378
+ if all(k in obj for k in ["x_min", "y_min", "x_max", "y_max"]):
379
+ box = [obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]]
380
+ # If box is valid (not full-frame), add it
381
+ if is_valid_bounding_box(box):
382
+ detected_objects.append((box, target_object))
383
+
384
+ return detected_objects
385
+
386
+
387
+ def draw_hitmarker(
388
+ frame, center_x, center_y, size=HITMARKER_SIZE, color=HITMARKER_COLOR, shadow=True
389
+ ):
390
+ """Draw a COD-style hitmarker cross with more space in the middle."""
391
+ half_size = size // 2
392
+
393
+ # Draw shadow first if enabled
394
+ if shadow:
395
+ # Top-left to center shadow
396
+ cv2.line(
397
+ frame,
398
+ (
399
+ center_x - half_size + HITMARKER_SHADOW_OFFSET,
400
+ center_y - half_size + HITMARKER_SHADOW_OFFSET,
401
+ ),
402
+ (
403
+ center_x - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,
404
+ center_y - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,
405
+ ),
406
+ HITMARKER_SHADOW_COLOR,
407
+ HITMARKER_THICKNESS,
408
+ )
409
+ # Top-right to center shadow
410
+ cv2.line(
411
+ frame,
412
+ (
413
+ center_x + half_size + HITMARKER_SHADOW_OFFSET,
414
+ center_y - half_size + HITMARKER_SHADOW_OFFSET,
415
+ ),
416
+ (
417
+ center_x + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,
418
+ center_y - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,
419
+ ),
420
+ HITMARKER_SHADOW_COLOR,
421
+ HITMARKER_THICKNESS,
422
+ )
423
+ # Bottom-left to center shadow
424
+ cv2.line(
425
+ frame,
426
+ (
427
+ center_x - half_size + HITMARKER_SHADOW_OFFSET,
428
+ center_y + half_size + HITMARKER_SHADOW_OFFSET,
429
+ ),
430
+ (
431
+ center_x - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,
432
+ center_y + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,
433
+ ),
434
+ HITMARKER_SHADOW_COLOR,
435
+ HITMARKER_THICKNESS,
436
+ )
437
+ # Bottom-right to center shadow
438
+ cv2.line(
439
+ frame,
440
+ (
441
+ center_x + half_size + HITMARKER_SHADOW_OFFSET,
442
+ center_y + half_size + HITMARKER_SHADOW_OFFSET,
443
+ ),
444
+ (
445
+ center_x + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,
446
+ center_y + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,
447
+ ),
448
+ HITMARKER_SHADOW_COLOR,
449
+ HITMARKER_THICKNESS,
450
+ )
451
+
452
+ # Draw main hitmarker
453
+ # Top-left to center
454
+ cv2.line(
455
+ frame,
456
+ (center_x - half_size, center_y - half_size),
457
+ (center_x - HITMARKER_GAP, center_y - HITMARKER_GAP),
458
+ color,
459
+ HITMARKER_THICKNESS,
460
+ )
461
+ # Top-right to center
462
+ cv2.line(
463
+ frame,
464
+ (center_x + half_size, center_y - half_size),
465
+ (center_x + HITMARKER_GAP, center_y - HITMARKER_GAP),
466
+ color,
467
+ HITMARKER_THICKNESS,
468
+ )
469
+ # Bottom-left to center
470
+ cv2.line(
471
+ frame,
472
+ (center_x - half_size, center_y + half_size),
473
+ (center_x - HITMARKER_GAP, center_y + HITMARKER_GAP),
474
+ color,
475
+ HITMARKER_THICKNESS,
476
+ )
477
+ # Bottom-right to center
478
+ cv2.line(
479
+ frame,
480
+ (center_x + half_size, center_y + half_size),
481
+ (center_x + HITMARKER_GAP, center_y + HITMARKER_GAP),
482
+ color,
483
+ HITMARKER_THICKNESS,
484
+ )
485
+
486
+
487
+ def draw_ad_boxes(frame, detected_objects, detect_keyword, model, box_style="censor"):
488
+ height, width = frame.shape[:2]
489
+
490
+ points = []
491
+ # Only get points if we need them for hitmarker or SAM styles
492
+ if box_style in ["hitmarker", "sam", "sam-fast"]:
493
+ frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
494
+ try:
495
+ point_response = model.point(frame_pil, detect_keyword)
496
+
497
+ if isinstance(point_response, dict) and 'points' in point_response:
498
+ points = point_response['points']
499
+ except Exception as e:
500
+ print(f"Error during point detection: {str(e)}")
501
+ points = []
502
+
503
+ # Only load SAM models and process points if we're using SAM styles and have points
504
+ if box_style in ["sam", "sam-fast"] and points:
505
+ # Start with the original PIL image
506
+ frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
507
+
508
+ # Collect all masks and points
509
+ all_masks = []
510
+ point_coords = []
511
+ point_labels = []
512
+
513
+ for point in points:
514
+ try:
515
+ center_x = int(float(point["x"]) * width)
516
+ center_y = int(float(point["y"]) * height)
517
+
518
+ # Get mask and visualization
519
+ mask, _ = process_sam_detection(frame_pil, center_x, center_y, slim=(box_style == "sam-fast"))
520
+
521
+ # Collect mask and point data
522
+ all_masks.append(mask)
523
+ point_coords.append((center_x, center_y))
524
+ point_labels.append(detect_keyword)
525
+
526
+ except Exception as e:
527
+ print(f"Error processing individual SAM point: {str(e)}")
528
+ print(f"Point data: {point}")
529
+
530
+ if all_masks:
531
+ # Create final visualization with all masks
532
+ result_pil = create_mask_overlay(frame_pil, all_masks, point_coords, point_labels)
533
+ frame = cv2.cvtColor(np.array(result_pil), cv2.COLOR_RGB2BGR)
534
+
535
+ # Process other visualization styles
536
+ for detection in detected_objects:
537
+ try:
538
+ # Handle both tracked and untracked detections
539
+ if len(detection) == 3: # Tracked detection with ID
540
+ box, keyword, track_id = detection
541
+ else: # Regular detection without tracking
542
+ box, keyword = detection
543
+ track_id = None
544
+
545
+ x1 = int(box[0] * width)
546
+ y1 = int(box[1] * height)
547
+ x2 = int(box[2] * width)
548
+ y2 = int(box[3] * height)
549
+
550
+ x1 = max(0, min(x1, width - 1))
551
+ y1 = max(0, min(y1, height - 1))
552
+ x2 = max(0, min(x2, width - 1))
553
+ y2 = max(0, min(y2, height - 1))
554
+
555
+ if x2 > x1 and y2 > y1:
556
+ if box_style == "censor":
557
+ cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 0), -1)
558
+ elif box_style == "bounding-box":
559
+ cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 3)
560
+
561
+ label = f"{detect_keyword}" if track_id is not None else detect_keyword
562
+ label_size = cv2.getTextSize(label, FONT, 0.7, 2)[0]
563
+ cv2.rectangle(
564
+ frame, (x1, y1 - 25), (x1 + label_size[0], y1), (0, 0, 255), -1
565
+ )
566
+ cv2.putText(
567
+ frame,
568
+ label,
569
+ (x1, y1 - 6),
570
+ FONT,
571
+ 0.7,
572
+ (255, 255, 255),
573
+ 2,
574
+ cv2.LINE_AA,
575
+ )
576
+ elif box_style == "fuzzy-blur":
577
+ # Extract ROI
578
+ roi = frame[y1:y2, x1:x2]
579
+ # Apply Gaussian blur with much larger kernel for intense blur
580
+ blurred_roi = cv2.GaussianBlur(roi, (125, 125), 0)
581
+ # Replace original ROI with blurred version
582
+ frame[y1:y2, x1:x2] = blurred_roi
583
+ elif box_style == "pixelated-blur":
584
+ # Extract ROI
585
+ roi = frame[y1:y2, x1:x2]
586
+ # Pixelate by resizing down and up
587
+ h, w = roi.shape[:2]
588
+ temp = cv2.resize(roi, (10, 10), interpolation=cv2.INTER_LINEAR)
589
+ pixelated = cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST)
590
+ # Mix up the pixelated frame slightly by adding random noise
591
+ noise = np.random.randint(0, 50, (h, w, 3), dtype=np.uint8)
592
+ pixelated = cv2.add(pixelated, noise)
593
+ # Apply stronger Gaussian blur to smooth edges
594
+ blurred_pixelated = cv2.GaussianBlur(pixelated, (15, 15), 0)
595
+ # Replace original ROI
596
+ frame[y1:y2, x1:x2] = blurred_pixelated
597
+ elif box_style == "obfuscated-pixel":
598
+ # Calculate expansion amount based on 10% of object dimensions
599
+ box_width = x2 - x1
600
+ box_height = y2 - y1
601
+ expand_x = int(box_width * 0.10)
602
+ expand_y = int(box_height * 0.10)
603
+
604
+ # Expand the bounding box by 10% in all directions
605
+ x1_expanded = max(0, x1 - expand_x)
606
+ y1_expanded = max(0, y1 - expand_y)
607
+ x2_expanded = min(width - 1, x2 + expand_x)
608
+ y2_expanded = min(height - 1, y2 + expand_y)
609
+
610
+ # Extract ROI with much larger padding for true background sampling
611
+ padding = 100 # Much larger padding to get true background
612
+ y1_pad = max(0, y1_expanded - padding)
613
+ y2_pad = min(height, y2_expanded + padding)
614
+ x1_pad = max(0, x1_expanded - padding)
615
+ x2_pad = min(width, x2_expanded + padding)
616
+
617
+ # Get the padded region including background
618
+ padded_roi = frame[y1_pad:y2_pad, x1_pad:x2_pad]
619
+
620
+ # Create mask that excludes a larger region around the detection
621
+ h, w = y2_expanded - y1_expanded, x2_expanded - x1_expanded
622
+ bg_mask = np.ones(padded_roi.shape[:2], dtype=bool)
623
+
624
+ # Exclude a larger region around the detection from background sampling
625
+ exclusion_padding = 50 # Area to exclude around detection
626
+ exclude_y1 = padding - exclusion_padding
627
+ exclude_y2 = padding + h + exclusion_padding
628
+ exclude_x1 = padding - exclusion_padding
629
+ exclude_x2 = padding + w + exclusion_padding
630
+
631
+ # Make sure exclusion coordinates are valid
632
+ exclude_y1 = max(0, exclude_y1)
633
+ exclude_y2 = min(padded_roi.shape[0], exclude_y2)
634
+ exclude_x1 = max(0, exclude_x1)
635
+ exclude_x2 = min(padded_roi.shape[1], exclude_x2)
636
+
637
+ # Mark the exclusion zone in the mask
638
+ bg_mask[exclude_y1:exclude_y2, exclude_x1:exclude_x2] = False
639
+
640
+ # If we have enough background pixels, calculate average color
641
+ if np.any(bg_mask):
642
+ bg_color = np.mean(padded_roi[bg_mask], axis=0).astype(np.uint8)
643
+ else:
644
+ # Fallback to edges if we couldn't get enough background
645
+ edge_samples = np.concatenate([
646
+ padded_roi[0], # Top edge
647
+ padded_roi[-1], # Bottom edge
648
+ padded_roi[:, 0], # Left edge
649
+ padded_roi[:, -1] # Right edge
650
+ ])
651
+ bg_color = np.mean(edge_samples, axis=0).astype(np.uint8)
652
+
653
+ # Create base pixelated version (of the expanded region)
654
+ temp = cv2.resize(frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded],
655
+ (6, 6), interpolation=cv2.INTER_LINEAR)
656
+ pixelated = cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST)
657
+
658
+ # Blend heavily towards background color
659
+ blend_factor = 0.9 # Much stronger blend with background
660
+ blended = cv2.addWeighted(
661
+ pixelated, 1 - blend_factor,
662
+ np.full((h, w, 3), bg_color, dtype=np.uint8), blend_factor,
663
+ 0
664
+ )
665
+
666
+ # Replace original ROI with blended version (using expanded coordinates)
667
+ frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded] = blended
668
+ elif box_style == "intense-pixelated-blur":
669
+ # Expand the bounding box by pixels in all directions
670
+ x1_expanded = max(0, x1 - 15)
671
+ y1_expanded = max(0, y1 - 15)
672
+ x2_expanded = min(width - 1, x2 + 25)
673
+ y2_expanded = min(height - 1, y2 + 25)
674
+
675
+ # Extract ROI
676
+ roi = frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded]
677
+ # Pixelate by resizing down and up
678
+ h, w = roi.shape[:2]
679
+ temp = cv2.resize(roi, (10, 10), interpolation=cv2.INTER_LINEAR)
680
+ pixelated = cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST)
681
+ # Mix up the pixelated frame slightly by adding random noise
682
+ noise = np.random.randint(0, 50, (h, w, 3), dtype=np.uint8)
683
+ pixelated = cv2.add(pixelated, noise)
684
+ # Apply stronger Gaussian blur to smooth edges
685
+ blurred_pixelated = cv2.GaussianBlur(pixelated, (15, 15), 0)
686
+ # Replace original ROI
687
+ frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded] = blurred_pixelated
688
+ elif box_style == "hitmarker":
689
+ if points:
690
+ for point in points:
691
+ try:
692
+ print(f"Processing point: {point}")
693
+ center_x = int(float(point["x"]) * width)
694
+ center_y = int(float(point["y"]) * height)
695
+ print(f"Converted coordinates: ({center_x}, {center_y})")
696
+
697
+ draw_hitmarker(frame, center_x, center_y)
698
+
699
+ label = f"{detect_keyword}" if track_id is not None else detect_keyword
700
+ label_size = cv2.getTextSize(label, FONT, 0.5, 1)[0]
701
+ cv2.putText(
702
+ frame,
703
+ label,
704
+ (center_x - label_size[0] // 2, center_y - HITMARKER_SIZE - 5),
705
+ FONT,
706
+ 0.5,
707
+ HITMARKER_COLOR,
708
+ 1,
709
+ cv2.LINE_AA,
710
+ )
711
+ except Exception as e:
712
+ print(f"Error processing individual point: {str(e)}")
713
+ print(f"Point data: {point}")
714
+
715
+ except Exception as e:
716
+ print(f"Error drawing {box_style} style box: {str(e)}")
717
+ print(f"Box data: {box}")
718
+ print(f"Keyword: {keyword}")
719
+
720
+ return frame
721
+
722
+
723
+ def filter_temporal_outliers(detections_dict):
724
+ """Filter out extremely large detections that take up most of the frame.
725
+ Only keeps detections that are reasonable in size.
726
+
727
+ Args:
728
+ detections_dict: Dictionary of {frame_number: [(box, keyword, track_id), ...]}
729
+ """
730
+ filtered_detections = {}
731
+
732
+ for t, detections in detections_dict.items():
733
+ # Only keep detections that aren't too large
734
+ valid_detections = []
735
+ for detection in detections:
736
+ # Handle both tracked and untracked detections
737
+ if len(detection) == 3: # Tracked detection with ID
738
+ box, keyword, track_id = detection
739
+ else: # Regular detection without tracking
740
+ box, keyword = detection
741
+ track_id = None
742
+
743
+ # Calculate box size as percentage of frame
744
+ width = box[2] - box[0]
745
+ height = box[3] - box[1]
746
+ area = width * height
747
+
748
+ # If box is less than 90% of frame, keep it
749
+ if area < 0.9:
750
+ if track_id is not None:
751
+ valid_detections.append((box, keyword, track_id))
752
+ else:
753
+ valid_detections.append((box, keyword))
754
+
755
+ if valid_detections:
756
+ filtered_detections[t] = valid_detections
757
+
758
+ return filtered_detections
759
+
760
+
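A minimal sketch of the detection dictionary this filter expects, with boxes normalized to [0, 1]. The values, the keyword, and the `main` module name are illustrative assumptions, not part of this commit:

```python
# Illustrative only: frame 1's detection covers ~95% of the frame and is dropped.
from main import filter_temporal_outliers  # module name is an assumption

detections = {
    0: [([0.10, 0.20, 0.30, 0.40], "cigarette", 1)],  # area 0.04 -> kept
    1: [([0.00, 0.00, 1.00, 0.95], "cigarette")],      # area 0.95 -> dropped
}
print(filter_temporal_outliers(detections))
# {0: [([0.1, 0.2, 0.3, 0.4], 'cigarette', 1)]}
```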
761
+ def describe_frames(video_path, model, tokenizer, detect_keyword, test_mode=False, test_duration=DEFAULT_TEST_MODE_DURATION, grid_rows=1, grid_cols=1):
762
+ """Extract and detect objects in frames."""
763
+ props = get_video_properties(video_path)
764
+ fps = props["fps"]
765
+
766
+ # Initialize DeepSORT tracker
767
+ tracker = DeepSORTTracker()
768
+
769
+ # If in test mode, only process first N seconds
770
+ if test_mode:
771
+ frame_count = min(int(fps * test_duration), props["frame_count"])
772
+ else:
773
+ frame_count = props["frame_count"]
774
+
775
+ ad_detections = {} # Store detection results by frame number
776
+
777
+ print("Extracting frames and detecting objects...")
778
+ video = cv2.VideoCapture(video_path)
779
+
780
+ # Detect scenes first
781
+ scenes = detect(video_path, ContentDetector(threshold=30.0))  # scene_detector is only created inside process_video, so build a detector here
782
+ scene_changes = set(end.get_frames() for _, end in scenes)
783
+ print(f"Detected {len(scenes)} scenes")
784
+
785
+ frame_count_processed = 0
786
+ with tqdm(total=frame_count) as pbar:
787
+ while frame_count_processed < frame_count:
788
+ ret, frame = video.read()
789
+ if not ret:
790
+ break
791
+
792
+ # Check if current frame is a scene change
793
+ if frame_count_processed in scene_changes:
794
+ # Detect objects in the frame
795
+ detected_objects = detect_objects_in_frame(
796
+ model, tokenizer, frame, detect_keyword, grid_rows=grid_rows, grid_cols=grid_cols
797
+ )
798
+
799
+ # Update tracker with current detections
800
+ tracked_objects = tracker.update(frame, detected_objects)
801
+
802
+ # Store results for every frame, even if empty
803
+ ad_detections[frame_count_processed] = tracked_objects
804
+
805
+ frame_count_processed += 1
806
+ pbar.update(1)
807
+
808
+ video.release()
809
+
810
+ if frame_count_processed == 0:
811
+ print("No frames could be read from video")
812
+ return {}
813
+
814
+ return ad_detections
815
+
816
+
817
+ def create_detection_video(
818
+ video_path,
819
+ ad_detections,
820
+ detect_keyword,
821
+ model,
822
+ output_path=None,
823
+ ffmpeg_preset="medium",
824
+ test_mode=False,
825
+ test_duration=DEFAULT_TEST_MODE_DURATION,
826
+ box_style="censor",
827
+ ):
828
+ """Create video with detection boxes while preserving audio."""
829
+ if output_path is None:
830
+ # Create outputs directory if it doesn't exist
831
+ outputs_dir = os.path.join(
832
+ os.path.dirname(os.path.abspath(__file__)), "outputs"
833
+ )
834
+ os.makedirs(outputs_dir, exist_ok=True)
835
+
836
+ # Clean the detect_keyword for filename
837
+ safe_keyword = "".join(
838
+ x for x in detect_keyword if x.isalnum() or x in (" ", "_", "-")
839
+ )
840
+ safe_keyword = safe_keyword.replace(" ", "_")
841
+
842
+ # Create output filename
843
+ base_name = os.path.splitext(os.path.basename(video_path))[0]
844
+ output_path = os.path.join(
845
+ outputs_dir, f"{box_style}_{safe_keyword}_{base_name}.mp4"
846
+ )
847
+
848
+ print(f"Will save output to: {output_path}")
849
+
850
+ props = get_video_properties(video_path)
851
+ fps, width, height = props["fps"], props["width"], props["height"]
852
+
853
+ # If in test mode, only process first few seconds
854
+ if test_mode:
855
+ frame_count = min(int(fps * test_duration), props["frame_count"])
856
+ print(f"Test mode enabled: Processing first {test_duration} seconds ({frame_count} frames)")
857
+ else:
858
+ frame_count = props["frame_count"]
859
+ print("Full video mode: Processing entire video")
860
+
861
+ video = cv2.VideoCapture(video_path)
862
+
863
+ # Create temp output path by adding _temp before the extension
864
+ base, ext = os.path.splitext(output_path)
865
+ temp_output = f"{base}_temp{ext}"
866
+ temp_audio = f"{base}_audio.aac" # Temporary audio file
867
+
868
+ out = cv2.VideoWriter(
869
+ temp_output, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
870
+ )
871
+
872
+ print("Creating detection video...")
873
+ frame_count_processed = 0
874
+
875
+ with tqdm(total=frame_count) as pbar:
876
+ while frame_count_processed < frame_count:
877
+ ret, frame = video.read()
878
+ if not ret:
879
+ break
880
+
881
+ # Get detections for this exact frame
882
+ if frame_count_processed in ad_detections:
883
+ current_detections = ad_detections[frame_count_processed]
884
+ if current_detections:
885
+ frame = draw_ad_boxes(
886
+ frame, current_detections, detect_keyword, model, box_style=box_style
887
+ )
888
+
889
+ out.write(frame)
890
+ frame_count_processed += 1
891
+ pbar.update(1)
892
+
893
+ video.release()
894
+ out.release()
895
+
896
+ # Extract audio from original video
897
+ try:
898
+ if test_mode:
899
+ # In test mode, extract only the required duration of audio
900
+ subprocess.run(
901
+ [
902
+ "ffmpeg",
903
+ "-y",
904
+ "-i",
905
+ video_path,
906
+ "-t",
907
+ str(test_duration),
908
+ "-vn", # No video
909
+ "-acodec",
910
+ "copy",
911
+ temp_audio,
912
+ ],
913
+ check=True,
914
+ )
915
+ else:
916
+ subprocess.run(
917
+ [
918
+ "ffmpeg",
919
+ "-y",
920
+ "-i",
921
+ video_path,
922
+ "-vn", # No video
923
+ "-acodec",
924
+ "copy",
925
+ temp_audio,
926
+ ],
927
+ check=True,
928
+ )
929
+ except subprocess.CalledProcessError as e:
930
+ print(f"Error extracting audio: {str(e)}")
931
+ if os.path.exists(temp_output):
932
+ os.remove(temp_output)
933
+ return None
934
+
935
+ # Merge processed video with original audio
936
+ try:
937
+ # Base FFmpeg command
938
+ ffmpeg_cmd = [
939
+ "ffmpeg",
940
+ "-y",
941
+ "-i",
942
+ temp_output,
943
+ "-i",
944
+ temp_audio,
945
+ "-c:v",
946
+ "libx264",
947
+ "-preset",
948
+ ffmpeg_preset,
949
+ "-crf",
950
+ "23",
951
+ "-c:a",
952
+ "aac",
953
+ "-b:a",
954
+ "192k",
955
+ "-movflags",
956
+ "+faststart", # Better web playback
957
+ ]
958
+
959
+ if test_mode:
960
+ # In test mode, ensure output duration matches test_duration
961
+ ffmpeg_cmd.extend([
962
+ "-t",
963
+ str(test_duration),
964
+ "-shortest" # Ensure output duration matches shortest input
965
+ ])
966
+
967
+ ffmpeg_cmd.extend([
968
+ "-loglevel",
969
+ "error",
970
+ output_path
971
+ ])
972
+
973
+ subprocess.run(ffmpeg_cmd, check=True)
974
+
975
+ # Clean up temporary files
976
+ os.remove(temp_output)
977
+ os.remove(temp_audio)
978
+
979
+ if not os.path.exists(output_path):
980
+ print(
981
+ f"Warning: FFmpeg completed but output file not found at {output_path}"
982
+ )
983
+ return None
984
+
985
+ return output_path
986
+
987
+ except subprocess.CalledProcessError as e:
988
+ print(f"Error merging audio with video: {str(e)}")
989
+ if os.path.exists(temp_output):
990
+ os.remove(temp_output)
991
+ if os.path.exists(temp_audio):
992
+ os.remove(temp_audio)
993
+ return None
994
+
995
+
996
+ def process_video(
997
+ video_path,
998
+ target_object,
999
+ test_mode=False,
1000
+ test_duration=DEFAULT_TEST_MODE_DURATION,
1001
+ ffmpeg_preset="medium",
1002
+ grid_rows=1,
1003
+ grid_cols=1,
1004
+ box_style="censor",
1005
+ ):
1006
+ """Process a video to detect and visualize specified objects."""
1007
+ try:
1008
+ print(f"\nProcessing: {video_path}")
1009
+ print(f"Looking for: {target_object}")
1010
+
1011
+ # Load model
1012
+ print("Loading Moondream model...")
1013
+ model, tokenizer = load_moondream()
1014
+
1015
+ # Get video properties
1016
+ props = get_video_properties(video_path)
1017
+
1018
+ # Initialize scene detector with ContentDetector
1019
+ scene_detector = ContentDetector(threshold=30.0) # Adjust threshold as needed
1020
+
1021
+ # Initialize DeepSORT tracker
1022
+ tracker = DeepSORTTracker()
1023
+
1024
+ # If in test mode, only process first N seconds
1025
+ if test_mode:
1026
+ frame_count = min(int(props["fps"] * test_duration), props["frame_count"])
1027
+ else:
1028
+ frame_count = props["frame_count"]
1029
+
1030
+ ad_detections = {} # Store detection results by frame number
1031
+
1032
+ print("Extracting frames and detecting objects...")
1033
+ video = cv2.VideoCapture(video_path)
1034
+
1035
+ # Detect scenes first
1036
+ scenes = detect(video_path, scene_detector)
1037
+ scene_changes = set(end.get_frames() for _, end in scenes)
1038
+ print(f"Detected {len(scenes)} scenes")
1039
+
1040
+ frame_count_processed = 0
1041
+ with tqdm(total=frame_count) as pbar:
1042
+ while frame_count_processed < frame_count:
1043
+ ret, frame = video.read()
1044
+ if not ret:
1045
+ break
1046
+
1047
+ # Check if current frame is a scene change
1048
+ if frame_count_processed in scene_changes:
1049
+ print(f"Scene change detected at frame {frame_count_processed}. Resetting tracker.")
1050
+ tracker.reset()
1051
+
1052
+ # Detect objects in the frame
1053
+ detected_objects = detect_objects_in_frame(
1054
+ model, tokenizer, frame, target_object, grid_rows=grid_rows, grid_cols=grid_cols
1055
+ )
1056
+
1057
+ # Update tracker with current detections
1058
+ tracked_objects = tracker.update(frame, detected_objects)
1059
+
1060
+ # Store results for every frame, even if empty
1061
+ ad_detections[frame_count_processed] = tracked_objects
1062
+
1063
+ frame_count_processed += 1
1064
+ pbar.update(1)
1065
+
1066
+ video.release()
1067
+
1068
+ if frame_count_processed == 0:
1069
+ print("No frames could be read from video")
1070
+ return {}
1071
+
1072
+ # Apply filtering
1073
+ filtered_ad_detections = filter_temporal_outliers(ad_detections)
1074
+
1075
+ # Build detection data structure
1076
+ detection_data = {
1077
+ "video_metadata": {
1078
+ "file_name": os.path.basename(video_path),
1079
+ "fps": props["fps"],
1080
+ "width": props["width"],
1081
+ "height": props["height"],
1082
+ "total_frames": props["frame_count"],
1083
+ "duration_sec": props["frame_count"] / props["fps"],
1084
+ "detect_keyword": target_object,
1085
+ "test_mode": test_mode,
1086
+ "grid_size": f"{grid_rows}x{grid_cols}",
1087
+ "box_style": box_style,
1088
+ "timestamp": datetime.now().isoformat()
1089
+ },
1090
+ "frame_detections": [
1091
+ {
1092
+ "frame": frame_num,
1093
+ "timestamp": frame_num / props["fps"],
1094
+ "objects": [
1095
+ {
1096
+ "keyword": kw,
1097
+ "bbox": list(box), # Convert numpy array to list if needed
1098
+ "track_id": track_id if len(detection) == 3 else None
1099
+ }
1100
+ for detection in filtered_ad_detections.get(frame_num, [])
1101
+ for box, kw, *track_id in [detection] # Unpack detection tuple, track_id will be empty list if not present
1102
+ ]
1103
+ }
1104
+ for frame_num in range(props["frame_count"] if not test_mode else min(int(props["fps"] * test_duration), props["frame_count"]))
1105
+ ]
1106
+ }
1107
+
1108
+ # Save filtered data
1109
+ outputs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "outputs")
1110
+ os.makedirs(outputs_dir, exist_ok=True)
1111
+ base_name = os.path.splitext(os.path.basename(video_path))[0]
1112
+ json_path = os.path.join(outputs_dir, f"{box_style}_{target_object}_{base_name}_detections.json")
1113
+
1114
+ from persistence import save_detection_data
1115
+ if not save_detection_data(detection_data, json_path):
1116
+ print("Warning: Failed to save detection data")
1117
+
1118
+ # Create video with filtered data
1119
+ output_path = create_detection_video(
1120
+ video_path,
1121
+ filtered_ad_detections,
1122
+ target_object,
1123
+ model,
1124
+ ffmpeg_preset=ffmpeg_preset,
1125
+ test_mode=test_mode,
1126
+ test_duration=test_duration,
1127
+ box_style=box_style,
1128
+ )
1129
+
1130
+ if output_path is None:
1131
+ print("\nError: Failed to create output video")
1132
+ return None
1133
+
1134
+ print(f"\nOutput saved to: {output_path}")
1135
+ print(f"Detection data saved to: {json_path}")
1136
+ return output_path
1137
+
1138
+ except Exception as e:
1139
+ print(f"Error processing video: {str(e)}")
1140
+ import traceback
1141
+ traceback.print_exc()
1142
+ return None
1143
+
1144
+
1145
+ def main():
1146
+ """Process all videos in the inputs directory."""
1147
+ parser = argparse.ArgumentParser(
1148
+ description="Detect objects in videos using Moondream2"
1149
+ )
1150
+ parser.add_argument(
1151
+ "--test", action="store_true", help="Process only first 3 seconds of each video"
1152
+ )
1153
+ parser.add_argument(
1154
+ "--test-duration",
1155
+ type=int,
1156
+ default=DEFAULT_TEST_MODE_DURATION,
1157
+ help=f"Number of seconds to process in test mode (default: {DEFAULT_TEST_MODE_DURATION})"
1158
+ )
1159
+ parser.add_argument(
1160
+ "--preset",
1161
+ choices=FFMPEG_PRESETS,
1162
+ default="medium",
1163
+ help="FFmpeg encoding preset (default: medium). Faster presets = lower quality",
1164
+ )
1165
+ parser.add_argument(
1166
+ "--detect",
1167
+ type=str,
1168
+ default="face",
1169
+ help='Object to detect in the video (default: face, use --detect "thing to detect" to override)',
1170
+ )
1171
+ parser.add_argument(
1172
+ "--rows",
1173
+ type=int,
1174
+ default=1,
1175
+ help="Number of rows to split each frame into (default: 1)",
1176
+ )
1177
+ parser.add_argument(
1178
+ "--cols",
1179
+ type=int,
1180
+ default=1,
1181
+ help="Number of columns to split each frame into (default: 1)",
1182
+ )
1183
+ parser.add_argument(
1184
+ "--box-style",
1185
+ choices=["censor", "bounding-box", "hitmarker", "sam", "sam-fast", "fuzzy-blur",
1186
+ "pixelated-blur", "intense-pixelated-blur", "obfuscated-pixel"],
1187
+ default="censor",
1188
+ help="Style of detection visualization (default: censor)",
1189
+ )
1190
+ args = parser.parse_args()
1191
+
1192
+ input_dir = "inputs"
1193
+ os.makedirs(input_dir, exist_ok=True)
1194
+ os.makedirs("outputs", exist_ok=True)
1195
+
1196
+ video_files = [
1197
+ f
1198
+ for f in os.listdir(input_dir)
1199
+ if f.lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm"))
1200
+ ]
1201
+
1202
+ if not video_files:
1203
+ print("No video files found in 'inputs' directory")
1204
+ return
1205
+
1206
+ print(f"Found {len(video_files)} videos to process")
1207
+ print(f"Will detect: {args.detect}")
1208
+ if args.test:
1209
+ print("Running in test mode - processing only first 3 seconds of each video")
1210
+ print(f"Using FFmpeg preset: {args.preset}")
1211
+ print(f"Grid size: {args.rows}x{args.cols}")
1212
+ print(f"Box style: {args.box_style}")
1213
+
1214
+ success_count = 0
1215
+ for video_file in video_files:
1216
+ video_path = os.path.join(input_dir, video_file)
1217
+ output_path = process_video(
1218
+ video_path,
1219
+ args.detect,
1220
+ test_mode=args.test,
1221
+ test_duration=args.test_duration,
1222
+ ffmpeg_preset=args.preset,
1223
+ grid_rows=args.rows,
1224
+ grid_cols=args.cols,
1225
+ box_style=args.box_style,
1226
+ )
1227
+ if output_path:
1228
+ success_count += 1
1229
+
1230
+ print(
1231
+ f"\nProcessing complete. Successfully processed {success_count} out of {len(video_files)} videos."
1232
+ )
1233
+
1234
+
1235
+ if __name__ == "__main__":
1236
+ main()
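For reference, a hedged sketch of driving the same pipeline from Python instead of the `inputs/` directory scan. The module name `main` and the input path are assumptions, not part of this commit:

```python
# Illustrative only: process a single clip with a natural-language prompt.
from main import process_video  # assumes the script above is saved as main.py

output = process_video(
    "inputs/homealone.mp4",  # placeholder path
    "cigarette",             # what to detect / moderate
    test_mode=True,          # only the first few seconds
    ffmpeg_preset="medium",
    box_style="censor",
)
print(output)                # path to the rendered video, or None on failure
```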
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ libvips
2
+ ffmpeg
persistence.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ def save_detection_data(data, output_file):
5
+ """
6
+ Saves the detection data to a JSON file.
7
+
8
+ Args:
9
+ data (dict): The complete detection data structure.
10
+ output_file (str): Path to the output JSON file.
11
+ """
12
+ try:
13
+ # Create directory if it doesn't exist
14
+ os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)  # dirname is "" for bare filenames
15
+
16
+ with open(output_file, "w") as f:
17
+ json.dump(data, f, indent=4)
18
+ print(f"Detection data saved to {output_file}")
19
+ return True
20
+ except Exception as e:
21
+ print(f"Error saving data: {str(e)}")
22
+ return False
23
+
24
+ def load_detection_data(input_file):
25
+ """
26
+ Loads the detection data from a JSON file.
27
+
28
+ Args:
29
+ input_file (str): Path to the JSON file.
30
+
31
+ Returns:
32
+ dict: The loaded detection data, or None if there was an error.
33
+ """
34
+ try:
35
+ with open(input_file, "r") as f:
36
+ return json.load(f)
37
+ except Exception as e:
38
+ print(f"Error loading data: {str(e)}")
39
+ return None
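A minimal round-trip sketch for these helpers; the payload and output path below are illustrative, mirroring the structure that `process_video` builds:

```python
# Illustrative only: save a tiny detection payload and read it back.
from persistence import save_detection_data, load_detection_data

sample = {
    "video_metadata": {"file_name": "clip.mp4", "fps": 30.0, "total_frames": 2},
    "frame_detections": [
        {"frame": 0, "timestamp": 0.0, "objects": []},
        {"frame": 1, "timestamp": 1 / 30.0, "objects": [
            {"keyword": "cigarette", "bbox": [0.1, 0.2, 0.3, 0.4], "track_id": 1},
        ]},
    ],
}

if save_detection_data(sample, "outputs/sample_detections.json"):
    loaded = load_detection_data("outputs/sample_detections.json")
    print(loaded["video_metadata"]["fps"])  # 30.0
```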
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ # if on windows: pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https://download.pytorch.org/whl/cu121
4
+ transformers>=4.36.0
5
+ opencv-python>=4.8.0
6
+ pillow>=10.0.0
7
+ numpy>=1.24.0
8
+ tqdm>=4.66.0
9
+ ffmpeg-python
10
+ einops
11
+ pyvips-binary
12
+ pyvips
13
+ accelerate
14
+ # for spaces
15
+ --extra-index-url https://download.pytorch.org/whl/cu113
16
+ spaces
17
+ # SAM dependencies
18
+ torchvision>=0.20.1
19
+ matplotlib>=3.7.0
20
+ pandas>=2.0.0
21
+ plotly
22
+ # DeepSORT dependencies
23
+ deep-sort-realtime>=1.3.2
24
+ scikit-learn # Required for deep-sort-realtime
25
+ # Scene detection dependencies (for intelligent scene-aware tracking)
26
+ scenedetect[opencv]>=0.6.2 # Provides scene change detection capabilities
video_visualization.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import subprocess
4
+ import matplotlib.pyplot as plt
5
+ import pandas as pd
6
+ import cv2
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ from persistence import load_detection_data
10
+
11
+ def create_frame_data(json_path):
12
+ """Create frame-by-frame detection data for visualization."""
13
+ try:
14
+ data = load_detection_data(json_path)
15
+ if not data:
16
+ print("No data loaded from JSON file")
17
+ return None
18
+
19
+ if "video_metadata" not in data or "frame_detections" not in data:
20
+ print("Invalid JSON structure: missing required fields")
21
+ return None
22
+
23
+ # Extract video metadata
24
+ metadata = data["video_metadata"]
25
+ if "fps" not in metadata or "total_frames" not in metadata:
26
+ print("Invalid metadata: missing fps or total_frames")
27
+ return None
28
+
29
+ fps = metadata["fps"]
30
+ total_frames = metadata["total_frames"]
31
+
32
+ # Create frame data
33
+ frame_counts = {}
34
+ for frame_data in data["frame_detections"]:
35
+ if "frame" not in frame_data or "objects" not in frame_data:
36
+ continue # Skip invalid frame data
37
+ frame_num = frame_data["frame"]
38
+ frame_counts[frame_num] = len(frame_data["objects"])
39
+
40
+ # Fill in missing frames with 0 detections
41
+ for frame in range(total_frames):
42
+ if frame not in frame_counts:
43
+ frame_counts[frame] = 0
44
+
45
+ if not frame_counts:
46
+ print("No valid frame data found")
47
+ return None
48
+
49
+ # Convert to DataFrame
50
+ df = pd.DataFrame(list(frame_counts.items()), columns=["frame", "detections"])
51
+ df["timestamp"] = df["frame"] / fps
52
+
53
+ return df, metadata
54
+
55
+ except Exception as e:
56
+ print(f"Error creating frame data: {str(e)}")
57
+ import traceback
58
+ traceback.print_exc()
59
+ return None
60
+
61
+ def generate_frame_image(df, frame_num, temp_dir, max_y):
62
+ """Generate and save a single frame of the visualization."""
63
+ # Set the style to dark background
64
+ plt.style.use('dark_background')
65
+
66
+ # Set global font to monospace
67
+ plt.rcParams['font.family'] = 'monospace'
68
+ plt.rcParams['font.monospace'] = ['DejaVu Sans Mono']
69
+
70
+ plt.figure(figsize=(10, 6))
71
+
72
+ # Plot data up to current frame
73
+ current_data = df[df['frame'] <= frame_num]
74
+ plt.plot(df['frame'], df['detections'], color='#1a1a1a', alpha=0.5) # Darker background line
75
+ plt.plot(current_data['frame'], current_data['detections'], color='#00ff41') # Matrix green
76
+
77
+ # Add vertical line for current position
78
+ plt.axvline(x=frame_num, color='#ff0000', linestyle='-', alpha=0.7) # Keep red for position
79
+
80
+ # Set consistent axes
81
+ plt.xlim(0, len(df) - 1)
82
+ plt.ylim(0, max_y * 1.1) # Add 10% padding
83
+
84
+ # Add labels with Matrix green color
85
+ plt.title(f'FRAME {frame_num:04d} - DETECTIONS OVER TIME', color='#00ff41', pad=20)
86
+ plt.xlabel('FRAME NUMBER', color='#00ff41')
87
+ plt.ylabel('NUMBER OF DETECTIONS', color='#00ff41')
88
+
89
+ # Add current stats in Matrix green with monospace formatting
90
+ current_detections = df[df['frame'] == frame_num]['detections'].iloc[0]
91
+ plt.text(0.02, 0.98, f'CURRENT DETECTIONS: {current_detections:02d}',
92
+ transform=plt.gca().transAxes, verticalalignment='top',
93
+ color='#00ff41', family='monospace')
94
+
95
+ # Style the grid and ticks
96
+ plt.grid(True, color='#1a1a1a', linestyle='-', alpha=0.3)
97
+ plt.tick_params(colors='#00ff41')
98
+
99
+ # Save frame
100
+ frame_path = os.path.join(temp_dir, f'frame_{frame_num:05d}.png')
101
+ plt.savefig(frame_path, bbox_inches='tight', dpi=100, facecolor='black', edgecolor='none')
102
+ plt.close()
103
+
104
+ return frame_path
105
+
106
+ def generate_gauge_frame(df, frame_num, temp_dir, detect_keyword="OBJECT"):
107
+ """Generate a modern square-style binary gauge visualization frame."""
108
+ # Set the style to dark background
109
+ plt.style.use('dark_background')
110
+
111
+ # Set global font to monospace
112
+ plt.rcParams['font.family'] = 'monospace'
113
+ plt.rcParams['font.monospace'] = ['DejaVu Sans Mono']
114
+
115
+ # Create figure with 16:9 aspect ratio
116
+ plt.figure(figsize=(16, 9))
117
+
118
+ # Get current detection state
119
+ current_detections = df[df['frame'] == frame_num]['detections'].iloc[0]
120
+ has_detection = current_detections > 0
121
+
122
+ # Create a simple gauge visualization
123
+ plt.axis('off')
124
+
125
+ # Set colors
126
+ if has_detection:
127
+ color = '#00ff41' # Matrix green for YES
128
+ status = 'YES'
129
+ indicator_pos = 0.8 # Right position
130
+ else:
131
+ color = '#ff0000' # Red for NO
132
+ status = 'NO'
133
+ indicator_pos = 0.2 # Left position
134
+
135
+ # Draw background rectangle
136
+ background = plt.Rectangle((0.1, 0.3), 0.8, 0.2,
137
+ facecolor='#1a1a1a',
138
+ edgecolor='#333333',
139
+ linewidth=2)
140
+ plt.gca().add_patch(background)
141
+
142
+ # Draw indicator
143
+ indicator_width = 0.05
144
+ indicator = plt.Rectangle((indicator_pos - indicator_width/2, 0.25),
145
+ indicator_width, 0.3,
146
+ facecolor=color,
147
+ edgecolor=None)
148
+ plt.gca().add_patch(indicator)
149
+
150
+ # Add tick marks
151
+ tick_positions = [0.2, 0.5, 0.8] # NO, CENTER, YES
152
+ for x in tick_positions:
153
+ plt.plot([x, x], [0.3, 0.5], color='#444444', linewidth=2)
154
+
155
+ # Add YES/NO labels
156
+ plt.text(0.8, 0.2, 'YES', color='#00ff41', fontsize=14,
157
+ ha='center', va='center', family='monospace')
158
+ plt.text(0.2, 0.2, 'NO', color='#ff0000', fontsize=14,
159
+ ha='center', va='center', family='monospace')
160
+
161
+ # Add status box at top with detection keyword
162
+ plt.text(0.5, 0.8, f'{detect_keyword.upper()} DETECTED?', color=color,
163
+ fontsize=16, ha='center', va='center', family='monospace',
164
+ bbox=dict(facecolor='#1a1a1a',
165
+ edgecolor=color,
166
+ linewidth=2,
167
+ pad=10))
168
+
169
+ # Add frame counter at bottom
170
+ plt.text(0.5, 0.1, f'FRAME: {frame_num:04d}', color='#00ff41',
171
+ fontsize=14, ha='center', va='center', family='monospace')
172
+
173
+ # Add subtle grid lines for depth
174
+ for x in np.linspace(0.2, 0.8, 7):
175
+ plt.plot([x, x], [0.3, 0.5], color='#222222', linewidth=1, zorder=0)
176
+
177
+ # Add glow effect to indicator
178
+ for i in range(3):
179
+ glow = plt.Rectangle((indicator_pos - (indicator_width + i*0.01)/2,
180
+ 0.25 - i*0.01),
181
+ indicator_width + i*0.01,
182
+ 0.3 + i*0.02,
183
+ facecolor=color,
184
+ alpha=0.1/(i+1))
185
+ plt.gca().add_patch(glow)
186
+
187
+ # Set consistent plot limits
188
+ plt.xlim(0, 1)
189
+ plt.ylim(0, 1)
190
+
191
+ # Save frame with 16:9 aspect ratio
192
+ frame_path = os.path.join(temp_dir, f'gauge_{frame_num:05d}.png')
193
+ plt.savefig(frame_path,
194
+ bbox_inches='tight',
195
+ dpi=100,
196
+ facecolor='black',
197
+ edgecolor='none',
198
+ pad_inches=0)
199
+ plt.close()
200
+
201
+ return frame_path
202
+
203
+ def create_video_visualization(json_path, style="timeline"):
204
+ """Create a video visualization of the detection data."""
205
+ try:
206
+ if not json_path:
207
+ return None, "No JSON file provided"
208
+
209
+ if not os.path.exists(json_path):
210
+ return None, f"File not found: {json_path}"
211
+
212
+ # Load and process data
213
+ result = create_frame_data(json_path)
214
+ if result is None:
215
+ return None, "Failed to load detection data from JSON file"
216
+
217
+ frame_data, metadata = result
218
+ if len(frame_data) == 0:
219
+ return None, "No frame data found in JSON file"
220
+
221
+ total_frames = metadata["total_frames"]
222
+ detect_keyword = metadata.get("detect_keyword", "OBJECT") # Get the detection keyword
223
+
224
+ # Create temporary directory for frames
225
+ with tempfile.TemporaryDirectory() as temp_dir:
226
+ max_y = frame_data['detections'].max()
227
+
228
+ # Generate each frame
229
+ print("Generating frames...")
230
+ frame_paths = []
231
+ with tqdm(total=total_frames, desc="Generating frames") as pbar:
232
+ for frame in range(total_frames):
233
+ try:
234
+ if style == "gauge":
235
+ frame_path = generate_gauge_frame(frame_data, frame, temp_dir, detect_keyword)
236
+ else: # default to timeline
237
+ frame_path = generate_frame_image(frame_data, frame, temp_dir, max_y)
238
+ if frame_path and os.path.exists(frame_path):
239
+ frame_paths.append(frame_path)
240
+ else:
241
+ print(f"Warning: Failed to generate frame {frame}")
242
+ pbar.update(1)
243
+ except Exception as e:
244
+ print(f"Error generating frame {frame}: {str(e)}")
245
+ continue
246
+
247
+ if not frame_paths:
248
+ return None, "Failed to generate any frames"
249
+
250
+ # Create output video path
251
+ output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "outputs")
252
+ os.makedirs(output_dir, exist_ok=True)
253
+ output_video = os.path.join(output_dir, f"detection_visualization_{style}.mp4")
254
+
255
+ # Create temp output path
256
+ base, ext = os.path.splitext(output_video)
257
+ temp_output = f"{base}_temp{ext}"
258
+
259
+ # First pass: Create video with OpenCV VideoWriter
260
+ print("Creating initial video...")
261
+ # Get frame size from first image
262
+ first_frame = cv2.imread(frame_paths[0])
263
+ height, width = first_frame.shape[:2]
264
+
265
+ out = cv2.VideoWriter(
266
+ temp_output,
267
+ cv2.VideoWriter_fourcc(*"mp4v"),
268
+ metadata["fps"],
269
+ (width, height)
270
+ )
271
+
272
+ with tqdm(total=len(frame_paths), desc="Creating video") as pbar:  # one update per successfully generated frame
273
+ for frame_path in frame_paths:
274
+ frame = cv2.imread(frame_path)
275
+ out.write(frame)
276
+ pbar.update(1)
277
+
278
+ out.release()
279
+
280
+ # Second pass: Convert to web-compatible format
281
+ print("Converting to web format...")
282
+ try:
283
+ subprocess.run(
284
+ [
285
+ "ffmpeg",
286
+ "-y",
287
+ "-i",
288
+ temp_output,
289
+ "-c:v",
290
+ "libx264",
291
+ "-preset",
292
+ "medium",
293
+ "-crf",
294
+ "23",
295
+ "-movflags",
296
+ "+faststart", # Better web playback
297
+ "-loglevel",
298
+ "error",
299
+ output_video,
300
+ ],
301
+ check=True,
302
+ )
303
+
304
+ os.remove(temp_output) # Remove the temporary file
305
+
306
+ if not os.path.exists(output_video):
307
+ print(f"Warning: FFmpeg completed but output file not found at {output_video}")
308
+ return None, "Failed to create video"
309
+
310
+ # Return video path and stats
311
+ stats = f"""Video Stats:
312
+ FPS: {metadata['fps']}
313
+ Total Frames: {metadata['total_frames']}
314
+ Duration: {metadata['duration_sec']:.2f} seconds
315
+ Max Detections in a Frame: {frame_data['detections'].max()}
316
+ Average Detections per Frame: {frame_data['detections'].mean():.2f}"""
317
+
318
+ return output_video, stats
319
+
320
+ except subprocess.CalledProcessError as e:
321
+ print(f"Error running FFmpeg: {str(e)}")
322
+ if os.path.exists(temp_output):
323
+ os.remove(temp_output)
324
+ return None, f"Error creating visualization: {str(e)}"
325
+
326
+ except Exception as e:
327
+ print(f"Error creating video visualization: {str(e)}")
328
+ import traceback
329
+ traceback.print_exc()
330
+ return None, f"Error creating visualization: {str(e)}"
visualization.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import matplotlib.pyplot as plt
3
+ from persistence import load_detection_data
4
+ import argparse
5
+
6
+ def visualize_detections(json_path):
7
+ """
8
+ Visualize detection data from a JSON file.
9
+
10
+ Args:
11
+ json_path (str): Path to the JSON file containing detection data.
12
+ """
13
+ # Load the persisted JSON data
14
+ data = load_detection_data(json_path)
15
+ if not data:
16
+ return
17
+
18
+ # Convert the frame detections to a DataFrame
19
+ rows = []
20
+ for frame_data in data["frame_detections"]:
21
+ frame = frame_data["frame"]
22
+ timestamp = frame_data["timestamp"]
23
+ for obj in frame_data["objects"]:
24
+ rows.append({
25
+ "frame": frame,
26
+ "timestamp": timestamp,
27
+ "keyword": obj["keyword"],
28
+ "x1": obj["bbox"][0],
29
+ "y1": obj["bbox"][1],
30
+ "x2": obj["bbox"][2],
31
+ "y2": obj["bbox"][3],
32
+ "area": (obj["bbox"][2] - obj["bbox"][0]) * (obj["bbox"][3] - obj["bbox"][1])
33
+ })
34
+
35
+ if not rows:
36
+ print("No detections found in the data")
37
+ return
38
+
39
+ df = pd.DataFrame(rows)
40
+
41
+ # Create a figure with multiple subplots
42
+ fig = plt.figure(figsize=(15, 10))
43
+
44
+ # Plot 1: Number of detections per frame
45
+ plt.subplot(2, 2, 1)
46
+ detections_per_frame = df.groupby("frame").size()
47
+ plt.plot(detections_per_frame.index, detections_per_frame.values)
48
+ plt.xlabel("Frame")
49
+ plt.ylabel("Number of Detections")
50
+ plt.title("Detections Per Frame")
51
+
52
+ # Plot 2: Distribution of detection areas
53
+ plt.subplot(2, 2, 2)
54
+ df["area"].hist(bins=30)
55
+ plt.xlabel("Detection Area (normalized)")
56
+ plt.ylabel("Count")
57
+ plt.title("Distribution of Detection Areas")
58
+
59
+ # Plot 3: Average detection area over time
60
+ plt.subplot(2, 2, 3)
61
+ avg_area = df.groupby("frame")["area"].mean()
62
+ plt.plot(avg_area.index, avg_area.values)
63
+ plt.xlabel("Frame")
64
+ plt.ylabel("Average Detection Area")
65
+ plt.title("Average Detection Area Over Time")
66
+
67
+ # Plot 4: Heatmap of detection centers
68
+ plt.subplot(2, 2, 4)
69
+ df["center_x"] = (df["x1"] + df["x2"]) / 2
70
+ df["center_y"] = (df["y1"] + df["y2"]) / 2
71
+ plt.hist2d(df["center_x"], df["center_y"], bins=30)
72
+ plt.colorbar()
73
+ plt.xlabel("X Position")
74
+ plt.ylabel("Y Position")
75
+ plt.title("Detection Center Heatmap")
76
+
77
+ # Adjust layout and display
78
+ plt.tight_layout()
79
+ plt.show()
80
+
81
+ # Print summary statistics
82
+ print("\nSummary Statistics:")
83
+ print(f"Total frames analyzed: {len(data['frame_detections'])}")
84
+ print(f"Total detections: {len(df)}")
85
+ print(f"Average detections per frame: {len(df) / len(data['frame_detections']):.2f}")
86
+ print(f"\nVideo metadata:")
87
+ for key, value in data["video_metadata"].items():
88
+ print(f"{key}: {value}")
89
+
90
+ def main():
91
+ parser = argparse.ArgumentParser(description="Visualize object detection data")
92
+ parser.add_argument("json_file", help="Path to the JSON file containing detection data")
93
+ args = parser.parse_args()
94
+
95
+ visualize_detections(args.json_file)
96
+
97
+ if __name__ == "__main__":
98
+ main()
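And the equivalent for the static summary plots, called directly rather than through the CLI entry point; the path is again a placeholder:

```python
# Illustrative only: plot summary charts for a saved detection run.
from visualization import visualize_detections

visualize_detections("outputs/censor_cigarette_homealone_detections.json")
```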