import warnings
warnings.filterwarnings("ignore")
import gradio as gr
import cv2
import numpy as np
import json
import os
from datetime import datetime
from ultralytics import YOLO
from insightface.app import FaceAnalysis
import torchreid
import torch
import logging
import uuid
# ========== Logging Configuration ==========
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# ========== Configuration ==========
DETECTION_THRESHOLD = 0.75

# Create output directory for Gradio so processed files can be served
OUTPUT_DIR = os.path.join(os.getcwd(), "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ========== Video Processing Class ==========
class VideoProcessor:
    def __init__(self):
        try:
            self.model = YOLO('detection.pt')
            self.face_app = FaceAnalysis(name='buffalo_l', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
            self.face_app.prepare(ctx_id=0)
            self.reid_extractor = torchreid.utils.FeatureExtractor(
                model_name='osnet_x0_25',
                model_path=None,
                device='cuda' if torch.cuda.is_available() else 'cpu'
            )
            self.models_loaded = True
            logger.info("Models loaded successfully.")
        except Exception:
            logger.exception("Model loading failed.")
            self.models_loaded = False
        self.reset_tracking()

    def reset_tracking(self):
        self.known_embeddings = []
        self.known_ids = []
        self.next_global_id = 1
        self.track_to_global = {}
        self.tracking_data = {
            "metadata": {
                "total_frames": 0,
                "total_people": 0,
                "id_mapping": {}
            },
            "frames": []
        }
        logger.info("Tracking state reset.")
    def extract_embeddings(self, person_crop):
        face_embedding, body_embedding = None, None
        try:
            faces = self.face_app.get(person_crop)
            if faces:
                face_embedding = faces[0].embedding
        except Exception:
            logger.debug("Face embedding failed.")
        try:
            body_input = cv2.resize(person_crop, (128, 256))
            body_input = cv2.cvtColor(body_input, cv2.COLOR_BGR2RGB)
            body_embedding = self.reid_extractor(body_input)[0].cpu().numpy()
        except Exception:
            logger.debug("Body embedding failed.")
        if face_embedding is not None and body_embedding is not None:
            return np.concatenate((face_embedding, body_embedding)).astype(np.float32)
        elif face_embedding is not None:
            return face_embedding.astype(np.float32)
        elif body_embedding is not None:
            return body_embedding.astype(np.float32)
        return None
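
    # Re-identification: compare the new embedding against stored embeddings of
    # the same dimensionality (face-only, body-only, and concatenated vectors
    # are not directly comparable), reuse the best match above the similarity
    # threshold, and otherwise mint a new global ID.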
    def assign_global_id(self, embedding, track_id):
        if embedding is None:
            return self.track_to_global.get(track_id, f"T{track_id}")
        match_found = False
        if self.known_embeddings:
            matching_embeddings = [
                (emb, gid) for emb, gid in zip(self.known_embeddings, self.known_ids)
                if emb.shape[0] == embedding.shape[0]
            ]
            if matching_embeddings:
                embs, gids = zip(*matching_embeddings)
                embs = np.array(embs)
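                # Vectorized cosine similarity against all candidates; the 1e-6
                # epsilon guards against division by zero for degenerate vectors.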
                sims = np.dot(embs, embedding) / (
                    np.linalg.norm(embs, axis=1) * np.linalg.norm(embedding) + 1e-6
                )
                best_match = np.argmax(sims)
                if sims[best_match] > 0.6:
                    global_id = gids[best_match]
                    match_found = True
        if not match_found:
            global_id = self.next_global_id
            self.next_global_id += 1
            self.known_embeddings.append(embedding)
            self.known_ids.append(global_id)
        if track_id is not None:
            self.track_to_global[track_id] = global_id
        return global_id

    def process_video(self, input_video_path, progress_callback=None):
        if not self.models_loaded:
            raise Exception("Models not loaded properly")
        self.reset_tracking()
        # Create output files with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        unique_id = str(uuid.uuid4())[:8]
        # Write into OUTPUT_DIR rather than a temp directory so Gradio can serve the files
        output_video_path = os.path.join(OUTPUT_DIR, f"tracked_video_{timestamp}_{unique_id}.mp4")
        output_json_path = os.path.join(OUTPUT_DIR, f"tracking_data_{timestamp}_{unique_id}.json")
        cap = cv2.VideoCapture(input_video_path)
        if not cap.isOpened():
            raise Exception("Could not open video file")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 FPS; fall back to a sane default so the writer can open
        if fps <= 0:
            fps = 30.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Try mp4v first for broad compatibility, then fall back to XVID, then H264
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
        # Verify the video writer initialized properly
        if not out.isOpened():
            logger.warning("mp4v codec failed, trying XVID")
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            output_video_path = output_video_path.replace('.mp4', '.avi')
            out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
            if not out.isOpened():
                logger.warning("XVID codec failed, trying H264")
                fourcc = cv2.VideoWriter_fourcc(*'H264')
                output_video_path = output_video_path.replace('.avi', '.mp4')
                out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            if progress_callback:
                # max() guards against containers that report a zero frame count
                progress_callback(frame_count / max(total_frames, 1), f"Processing frame {frame_count}/{total_frames}")
            frame_data = {"frame": frame_count, "people": []}
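            # persist=True keeps the ByteTrack state across calls so track IDs
            # stay stable from frame to frame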
            try:
                results = self.model.track(
                    frame, tracker="bytetrack.yaml", persist=True, verbose=False, conf=DETECTION_THRESHOLD
                )
                for result in results:
                    if result.boxes is not None:
                        boxes = result.boxes.xyxy.cpu().numpy()
                        confidences = result.boxes.conf.cpu().numpy()
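                        # ByteTrack may not have confirmed tracks yet (e.g. in the
                        # first frames), in which case the boxes carry no IDs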
                        track_ids = result.boxes.id.int().cpu().tolist() if result.boxes.id is not None else [None] * len(boxes)
                        for box, conf, track_id in zip(boxes, confidences, track_ids):
                            x1, y1, x2, y2 = map(int, box)
                            # Clamp to frame bounds so slicing never wraps on negative coordinates
                            x1, y1 = max(x1, 0), max(y1, 0)
                            x2, y2 = min(x2, width), min(y2, height)
                            person_crop = frame[y1:y2, x1:x2]
                            if person_crop.size > 0:
                                embedding = self.extract_embeddings(person_crop)
                                global_id = self.assign_global_id(embedding, track_id)
                                frame_data["people"].append({
                                    "person_id": global_id,
                                    "center_x": (x1 + x2) / 2,
                                    "center_y": (y1 + y2) / 2,
                                    "confidence": float(conf),
                                    "bbox": {"x1": float(x1), "y1": float(y1), "x2": float(x2), "y2": float(y2)}
                                })
                                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                                cv2.putText(frame, f"ID {global_id}", (x1, y1 - 10),
                                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
            except Exception:
                logger.exception(f"Error processing frame {frame_count}.")
            self.tracking_data["frames"].append(frame_data)
            out.write(frame)
        cap.release()
        out.release()
        # Verify the output file was created and has content
        if not os.path.exists(output_video_path) or os.path.getsize(output_video_path) == 0:
            raise Exception("Output video file was not created properly")
        self.tracking_data["metadata"]["total_frames"] = frame_count
        self.tracking_data["metadata"]["total_people"] = len(set(self.known_ids))
self.tracking_data["metadata"]["id_mapping"] = {str(k): v for k, v in self.track_to_global.items()} | |
# Save JSON file | |
with open(output_json_path, 'w') as f: | |
json.dump(self.tracking_data, f, indent=2) | |
logger.info(f"Video processing completed. Saved to {output_video_path}") | |
logger.info(f"Video file size: {os.path.getsize(output_video_path)} bytes") | |
return output_video_path, output_json_path | |

# ========== Processor ==========
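# A single shared instance so the models are loaded once per process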
processor = VideoProcessor()

# ========== Gradio Handler ==========
def process_video_gradio(input_video, progress=gr.Progress()):
    if input_video is None:
        return None, None, "Please upload a video file."
    try:
        def progress_callback(prog, message):
            progress(prog, desc=message)
        # Process video
        output_video_path, output_json_path = processor.process_video(input_video, progress_callback)
        # Verify files exist and are accessible
        if not os.path.exists(output_video_path):
            raise Exception(f"Output video not found at {output_video_path}")
        if not os.path.exists(output_json_path):
            raise Exception(f"Output JSON not found at {output_json_path}")
        # Read tracking data for stats
        with open(output_json_path, 'r') as f:
            data = json.load(f)
        stats = f"""
**Processing Complete!** ✅

- **Total Frames Processed:** {data['metadata']['total_frames']}
- **Total People Detected:** {data['metadata']['total_people']}
- **Unique IDs Assigned:** {len(data['metadata']['id_mapping'])}
- **Output Video Size:** {os.path.getsize(output_video_path) / (1024*1024):.1f} MB

📹 **Output video** is ready for download
📊 **JSON tracking data** contains frame-by-frame detection results
"""
        logger.info(f"Returning video path: {output_video_path}")
        logger.info(f"Video exists: {os.path.exists(output_video_path)}")
        return output_video_path, output_json_path, stats
    except Exception as e:
        logger.exception("Video processing failed.")
        return None, None, f"❌ **Error processing video:** {str(e)}"

# ========== Gradio Interface ==========
def create_interface():
    with gr.Blocks(title="Video Person Detection & Tracking", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎥 Video Person Detection & Tracking with ReID")
        gr.Markdown("Upload a video to detect and track people using YOLOv8, InsightFace, and ReID models for consistent person identification across frames.")
        with gr.Row():
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="📁 Upload Input Video",
                    height=400,
                    interactive=True
                )
                process_btn = gr.Button(
                    "🚀 Process Video",
                    variant="primary",
                    size="lg"
                )
            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="🎬 Processed Video (with tracking)",
                    height=400,
                    interactive=False,
                    show_download_button=True  # Enable download button
                )
                download_json = gr.File(
                    label="📊 Download Tracking Data (JSON)",
                    interactive=False
                )
        with gr.Row():
            status_text = gr.Markdown("📤 Upload a video and click **'Process Video'** to start tracking people.")
        # Event handler
        process_btn.click(
            fn=process_video_gradio,
            inputs=[input_video],
            outputs=[output_video, download_json, status_text],
            show_progress=True
        )
        # Additional information
        with gr.Accordion("📋 How it works", open=False):
            gr.Markdown("""
            ### 🔧 **Technology Stack:**
            - **YOLOv8:** Real-time person detection
            - **ByteTrack:** Multi-object tracking algorithm
            - **InsightFace:** Facial feature extraction for person identification
            - **OSNet:** Full-body re-identification features

            ### 🔄 **Process:**
            1. **Detection:** YOLOv8 detects people in each frame
            2. **Tracking:** ByteTrack assigns temporary tracking IDs
            3. **Feature Extraction:** InsightFace + OSNet extract identifying features
            4. **Re-identification:** Combines face and body features for consistent global IDs
            5. **Output:** Generates annotated video + detailed JSON tracking data

            ### 📁 **Supported Formats:**
            - **Input:** MP4, AVI, MOV, WEBM
            - **Output:** MP4 video + JSON metadata
            """)
        with gr.Accordion("⚙️ Model Configuration", open=False):
            gr.Markdown(f"""
            - **Detection Threshold:** {DETECTION_THRESHOLD}
            - **Similarity Threshold:** 0.6 (for person re-identification)
            - **Device:** {"CUDA" if torch.cuda.is_available() else "CPU"}
            - **Output Directory:** {OUTPUT_DIR}
            """)
        with gr.Accordion("🔧 Troubleshooting", open=False):
            gr.Markdown("""
            **If the video doesn't display:**
            1. Check whether the output file exists in the outputs directory
            2. Try downloading the video manually
            3. Ensure proper video codec support

            **Common issues:**
            - Large video files may take time to load
            - Some browsers may not support certain video formats
            - Network issues can affect video streaming
            """)

    return demo

# ========== Launch ==========
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=True,
        allowed_paths=[OUTPUT_DIR]  # let Gradio serve files saved in the outputs directory
    )