import os
import cv2
import numpy as np
import tensorflow as tf
import gradio as gr
import time
import re
import torch
from ultralytics import YOLO
from transformers import AutoImageProcessor, AutoModelForObjectDetection
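# Dependency note (assumed, not from the original file): a typical
# requirements.txt for this app would list opencv-python, numpy, tensorflow,
# gradio, torch, ultralytics, and transformers; possibly also timm if the
# weapon checkpoint is DETR-based.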
# ================= Load Models =================
# Violence detection model (binary classifier on 128x128 frames)
violence_model = tf.keras.models.load_model("modelnew.h5") if os.path.exists("modelnew.h5") else None

# Hugging Face DETR weapon detection
weapon_model_id = "KIRANKALLA/WeaponDetection"
weapon_processor = AutoImageProcessor.from_pretrained(weapon_model_id)
weapon_model = AutoModelForObjectDetection.from_pretrained(weapon_model_id)
id2label = weapon_model.config.id2label

# YOLOv8 person model
person_yolo = YOLO("yolov8n.pt") if os.path.exists("yolov8n.pt") else None
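# Note: ultralytics can fetch "yolov8n.pt" automatically on first use, so the
# os.path.exists guard above is conservative. If network access is available
# (an assumption about the deployment environment), loading unconditionally
# also works:
#   person_yolo = YOLO("yolov8n.pt")  # downloads the weights if missing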
# ================= Detection Functions =================
def draw_label(frame, text, x, y, color):
    # Filled background sized by a rough 12px-per-character estimate, then white text on top
    cv2.rectangle(frame, (x, y - 25), (x + len(text) * 12, y), color, -1)
    cv2.putText(frame, text, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
def detect_violence(frame):
    if violence_model is None:
        return frame, "Violence model missing!"
    resized = cv2.resize(frame, (128, 128)) / 255.0
    prediction = violence_model.predict(np.expand_dims(resized, axis=0), verbose=0)[0][0]
    violence = prediction > 0.5
    color = (0, 0, 255) if violence else (0, 255, 0)
    draw_label(frame, f"Violence: {prediction:.2f}", 10, 30, color)
    return frame, ("⚠ ALERT: Violence Detected!" if violence else "No Violence Detected")
def detect_weapon(frame):
    # Convert BGR to RGB for the DETR processor
    image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    inputs = weapon_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = weapon_model(**inputs)
    # Post-process detections back to the original (height, width)
    target_sizes = torch.tensor([image.shape[:2]])
    results = weapon_processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.5)[0]
    alert = "No Weapon Detected"
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        x1, y1, x2, y2 = [int(i) for i in box.tolist()]
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 3)
        draw_label(frame, f"{id2label[label.item()]}: {score:.2f}", x1, y1, (0, 0, 255))
        alert = "⚠ ALERT: Weapon Detected!"
    return frame, alert
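# DETR inference is heavy for per-frame webcam streaming. Below is a minimal
# frame-skipping sketch; _WEAPON_EVERY_N and _weapon_state are hypothetical
# names, not part of the original app, and it assumes a slightly stale alert
# between runs is acceptable:
_WEAPON_EVERY_N = 5
_weapon_state = {"frame_idx": 0, "alert": "No Weapon Detected"}

def detect_weapon_throttled(frame):
    """Run detect_weapon only on every Nth frame, reusing the last alert otherwise."""
    _weapon_state["frame_idx"] += 1
    if _weapon_state["frame_idx"] % _WEAPON_EVERY_N == 1:
        frame, _weapon_state["alert"] = detect_weapon(frame)
    return frame, _weapon_state["alert"]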
def detect_person(frame):
    if person_yolo is None:
        return frame, "YOLOv8 model missing!"
    results = person_yolo(frame, stream=True)
    count = 0
    for r in results:
        for box in r.boxes:
            cls = int(box.cls[0])
            conf = float(box.conf[0])
            if cls == 0 and conf > 0.5:  # class 0 = "person" in COCO
                count += 1
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 3)
                draw_label(frame, "Person", x1, y1, (0, 255, 0))
    draw_label(frame, f"Count: {count}", 10, 30, (255, 255, 0))
    return frame, f"Total Persons: {count}"
def track_person(frame):
    if person_yolo is None:
        return frame, "YOLOv8 model missing!"
    results = person_yolo.track(frame, persist=True)
    if results:
        boxes = results[0].boxes
        ids = boxes.id  # None until the tracker has assigned IDs
        for i, box in enumerate(boxes):
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            label = f"ID {int(ids[i])}" if ids is not None else "Person"
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 255), 3)
            draw_label(frame, label, x1, y1, (0, 255, 255))
    return frame, "Tracking Active"
def parse_person_count(person_text):
    try:
        m = re.search(r"(\d+)", person_text)
        return int(m.group(1)) if m else 0
    except Exception:
        return 0
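# Example: parse_person_count("Total Persons: 3") -> 3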
# ================= Live Inference =================
current_mode = {"mode": "Violence Detection"}

def set_mode(new_mode):
    current_mode["mode"] = new_mode
    return f"Mode switched to {new_mode}"
def live_inference(video_frame):
    if video_frame is None:  # Gradio can deliver empty frames when the stream starts or stops
        return None, "Waiting for webcam...", ""
    frame = cv2.cvtColor(video_frame, cv2.COLOR_RGB2BGR)
    # Primary detection depending on mode
    if current_mode["mode"] == "Violence Detection":
        frame, alert = detect_violence(frame)
    elif current_mode["mode"] == "Weapon Detection":
        frame, alert = detect_weapon(frame)
    elif current_mode["mode"] == "Person Counting":
        frame, alert = detect_person(frame)
    elif current_mode["mode"] == "Person Tracking":
        frame, alert = track_person(frame)
    else:
        alert = "Invalid Mode"
    # Always try to get a person count for the report
    person_count = 0
    tracked_ids_str = "N/A"  # placeholder: track_person does not report IDs back
    if current_mode["mode"] == "Person Counting":
        person_count = parse_person_count(alert)  # reuse the detection pass already run above
    elif person_yolo is not None:
        try:
            _, person_text = detect_person(frame.copy())
            person_count = parse_person_count(person_text)
        except Exception:
            person_count = 0
    # Build the auto-report text
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
    violence_status = "Yes" if "⚠ ALERT: Violence" in alert else "No"
    weapon_status = "Yes" if "⚠ ALERT: Weapon" in alert else "No"
    report = (
        f"[{timestamp}] Violence detected: {violence_status} | "
        f"Weapon detected: {weapon_status} | "
        f"Persons at location: {person_count} | "
        f"Tracked IDs: {tracked_ids_str}"
    )
    # Convert the annotated image back to RGB for Gradio
    out_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return out_img, alert, report
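# Quick offline smoke test (assumes a local image "sample.jpg" exists; the
# filename is hypothetical):
#   img = cv2.cvtColor(cv2.imread("sample.jpg"), cv2.COLOR_BGR2RGB)
#   out_img, alert, report = live_inference(img)
#   print(alert)
#   print(report)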
# ================= Gradio UI =================
with gr.Blocks(css="""
#alert-box {font-size:24px;font-weight:bold;text-align:center;}
#report-box {font-family:monospace; white-space:pre-wrap; height:140px; overflow:auto;}
.blink {animation: blink 1s infinite;}
@keyframes blink {0%{background:red;color:white;}50%{background:white;color:red;}100%{background:red;color:white;}}
""") as demo:
    gr.Markdown("# 🚨 Live AI Surveillance Dashboard")
    with gr.Row():
        violence_btn = gr.Button("Violence Detection")
        weapon_btn = gr.Button("Weapon Detection")
        count_btn = gr.Button("Person Counting")
        track_btn = gr.Button("Person Tracking")
    with gr.Row():
        webcam_input = gr.Image(sources=["webcam"], streaming=True, type="numpy", label="Live Webcam", height=720)
        output_image = gr.Image(label="Output", height=720)
    alert_box = gr.Textbox(label="Alert", elem_id="alert-box", interactive=False)
    report_box = gr.Textbox(label="Auto-Generated Report", elem_id="report-box", interactive=False)
    # Button events: switch the shared mode and confirm in the alert box
    violence_btn.click(lambda: set_mode("Violence Detection"), outputs=alert_box)
    weapon_btn.click(lambda: set_mode("Weapon Detection"), outputs=alert_box)
    count_btn.click(lambda: set_mode("Person Counting"), outputs=alert_box)
    track_btn.click(lambda: set_mode("Person Tracking"), outputs=alert_box)

    # Live stream: every webcam frame runs through live_inference
    webcam_input.stream(
        live_inference,
        inputs=[webcam_input],
        outputs=[output_image, alert_box, report_box],
    )

# 0.0.0.0:7860 is the standard Spaces binding; share=True has no effect when running on Spaces itself
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)