# Hugging Face Space app (runs on ZeroGPU)
import os
import cv2
import tqdm
import uuid
import logging
import torch
import spaces
import trackers
import numpy as np
import gradio as gr
import imageio.v3 as iio
import supervision as sv
from pathlib import Path
from functools import lru_cache
from typing import List, Optional, Tuple
from transformers import AutoModelForObjectDetection, AutoImageProcessor
# Configuration constants
CHECKPOINTS = [
    "ustc-community/dfine-xlarge-obj2coco"
]
DEFAULT_CHECKPOINT = CHECKPOINTS[0]
DEFAULT_CONFIDENCE_THRESHOLD = 0.3
TORCH_DTYPE = torch.float32

# Video
MAX_NUM_FRAMES = 250
BATCH_SIZE = 4
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}
VIDEO_OUTPUT_DIR = Path("static/videos")
VIDEO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
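# Supported multi-object tracking algorithms; process_video selects one of
# these by its display name (or skips tracking when none is given).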
class TrackingAlgorithm:
    BYTETRACK = "ByteTrack (2021)"
    DEEPSORT = "DeepSORT (2017)"
    SORT = "SORT (2016)"

# Create a color palette for visualization
# These hex color codes define different colors for tracking different objects
color = sv.ColorPalette.from_hex([
    "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
    "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
])
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def get_model_and_processor(checkpoint: str):
    # Cache the loaded model and processor so repeated calls reuse them
    model = AutoModelForObjectDetection.from_pretrained(checkpoint, torch_dtype=TORCH_DTYPE)
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    return model, image_processor
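# detect_objects runs batched D-FINE inference on a list of RGB frames and
# returns one dict per frame ("scores", "labels", "boxes", moved to CPU)
# plus the model's id2label mapping. Detections are filtered to the
# vehicle/person classes listed inside the function.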
@spaces.GPU  # ZeroGPU: allocate a GPU only for the duration of this call
def detect_objects(
    images: List[np.ndarray] | np.ndarray,
    target_size: Optional[Tuple[int, int]] = None,
    batch_size: int = BATCH_SIZE,
):
    checkpoint = DEFAULT_CHECKPOINT
    confidence_threshold = DEFAULT_CONFIDENCE_THRESHOLD
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, image_processor = get_model_and_processor(checkpoint)
    model = model.to(device)
    classes = ["person", "aeroplane", "bicycle", "car", "motorbike", "bus", "train", "truck", "boat"]
    if classes is not None:
        wrong_classes = [cls for cls in classes if cls not in model.config.label2id]
        if wrong_classes:
            gr.Warning(f"Classes not found in model config: {wrong_classes}")
        keep_ids = [model.config.label2id[cls] for cls in classes if cls in model.config.label2id]
    else:
        keep_ids = None
    if isinstance(images, np.ndarray) and images.ndim == 4:
        images = [x for x in images]

    batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
    results = []
    for batch in tqdm.tqdm(batches, desc="Processing frames"):

        # preprocess images
        inputs = image_processor(images=batch, return_tensors="pt")
        inputs = inputs.to(device).to(TORCH_DTYPE)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # postprocess outputs
        if target_size:
            target_sizes = [target_size] * len(batch)
        else:
            target_sizes = [(image.shape[0], image.shape[1]) for image in batch]

        batch_results = image_processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=confidence_threshold
        )
        results.extend(batch_results)

    # move results to cpu
    for i, result in enumerate(results):
        results[i] = {k: v.cpu() for k, v in result.items()}
        if keep_ids is not None:
            keep = torch.isin(results[i]["labels"], torch.tensor(keep_ids))
            results[i] = {k: v[keep] for k, v in results[i].items()}

    return results, model.config.id2label
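# Fit a frame inside a square of side `max_size` while preserving the aspect
# ratio; width and height are rounded down to even numbers for H.264 encoding.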
def get_target_size(image_height, image_width, max_size: int):
    if image_height < max_size and image_width < max_size:
        new_height, new_width = image_height, image_width
    elif image_height > image_width:
        new_height = max_size
        new_width = int(image_width * max_size / image_height)
    else:
        new_width = max_size
        new_height = int(image_height * max_size / image_width)

    # make even (for video codec compatibility)
    new_height = new_height // 2 * 2
    new_width = new_width // 2 * 2
    return new_width, new_height
def read_video_k_frames(video_path: str, k: int, read_every_i_frame: int = 1):
    cap = cv2.VideoCapture(video_path)
    frames = []
    i = 0
    progress_bar = tqdm.tqdm(total=k, desc="Reading frames")
    while cap.isOpened() and len(frames) < k:
        ret, frame = cap.read()
        if not ret:
            break
        if i % read_every_i_frame == 0:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            progress_bar.update(1)
        i += 1
    cap.release()
    progress_bar.close()
    return frames
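# Tracker factory: ByteTrack comes from the supervision package, while SORT
# and DeepSORT come from the `trackers` package (DeepSORT extracts appearance
# features with a small timm backbone on CPU).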
def get_tracker(tracker: str, fps: float):
    if tracker == TrackingAlgorithm.SORT:
        return trackers.SORTTracker(frame_rate=fps)
    elif tracker == TrackingAlgorithm.DEEPSORT:
        feature_extractor = trackers.DeepSORTFeatureExtractor.from_timm(
            "mobilenetv4_conv_small.e1200_r224_in1k", device="cpu"
        )
        return trackers.DeepSORTTracker(feature_extractor, frame_rate=fps)
    elif tracker == TrackingAlgorithm.BYTETRACK:
        return sv.ByteTrack(frame_rate=int(fps))
    else:
        raise ValueError(f"Invalid tracker: {tracker}")
def update_tracker(tracker, detections, frame):
    tracker_name = tracker.__class__.__name__
    if tracker_name == "SORTTracker":
        return tracker.update(detections)
    elif tracker_name == "DeepSORTTracker":
        return tracker.update(detections, frame)
    elif tracker_name == "ByteTrack":
        return tracker.update_with_detections(detections)
    else:
        raise ValueError(f"Invalid tracker: {tracker}")
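# End-to-end pipeline: validate the input video, read and resize frames,
# run detection, optionally track objects across frames, draw annotations,
# and write the result as an H.264 MP4 that Gradio can serve.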
def process_video(
    video_path: str,
    tracker_algorithm: Optional[str] = None,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> str:
    if not video_path or not os.path.isfile(video_path):
        raise ValueError(f"Invalid video path: {video_path}")

    ext = os.path.splitext(video_path)[1].lower()
    if ext not in ALLOWED_VIDEO_EXTENSIONS:
        raise ValueError(f"Unsupported video format: {ext}, supported formats: {ALLOWED_VIDEO_EXTENSIONS}")

    video_info = sv.VideoInfo.from_video_path(video_path)
    read_each_i_frame = max(1, video_info.fps // 25)
    target_fps = video_info.fps / read_each_i_frame
    target_width, target_height = get_target_size(video_info.height, video_info.width, 1080)

    n_frames_to_read = min(MAX_NUM_FRAMES, video_info.total_frames // read_each_i_frame)
    frames = read_video_k_frames(video_path, n_frames_to_read, read_each_i_frame)
    frames = [cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_CUBIC) for frame in frames]
    # Set the color lookup mode to assign colors by track ID
    # This means objects with the same track ID are drawn in the same color
    color_lookup = sv.ColorLookup.TRACK if tracker_algorithm else sv.ColorLookup.CLASS
    box_annotator = sv.BoxAnnotator(color, color_lookup=color_lookup, thickness=1)
    label_annotator = sv.LabelAnnotator(color, color_lookup=color_lookup, text_scale=0.5)
    trace_annotator = sv.TraceAnnotator(color, color_lookup=color_lookup, thickness=1, trace_length=100)
    results, id2label = detect_objects(
        images=np.array(frames),
        target_size=(target_height, target_width),
    )

    annotated_frames = []

    # detections
    if tracker_algorithm:
        tracker = get_tracker(tracker_algorithm, target_fps)
        for frame, result in progress.tqdm(zip(frames, results), desc="Tracking objects", total=len(frames)):
            detections = sv.Detections.from_transformers(result, id2label=id2label)
            detections = detections.with_nms(threshold=0.95, class_agnostic=True)
            detections = update_tracker(tracker, detections, frame)
            labels = [
                f"#{tracker_id} {id2label[class_id]}"
                for class_id, tracker_id in zip(detections.class_id, detections.tracker_id)
            ]
            annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
            annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
            annotated_frame = trace_annotator.annotate(scene=annotated_frame, detections=detections)
            annotated_frames.append(annotated_frame)
    else:
        for frame, result in tqdm.tqdm(zip(frames, results), desc="Annotating frames", total=len(frames)):
            detections = sv.Detections.from_transformers(result, id2label=id2label)
            detections = detections.with_nms(threshold=0.95, class_agnostic=True)
            annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
            annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections)
            annotated_frames.append(annotated_frame)

    output_filename = os.path.join(VIDEO_OUTPUT_DIR, f"output_{uuid.uuid4()}.mp4")
    iio.imwrite(output_filename, annotated_frames, fps=target_fps, codec="h264")
    return output_filename
def create_video_inputs() -> List[gr.components.Component]:
    return [
        gr.Video(
            label="Upload Video",
            sources=["upload"],
            interactive=True,
            format="mp4",  # Ensure MP4 format
            elem_classes="input-component",
        )
    ]
def create_button_row() -> List[gr.Button]:
    return [
        gr.Button(
            "Detect Objects", variant="primary", elem_classes="action-button"
        ),
        gr.Button("Clear", variant="secondary", elem_classes="action-button"),
    ]
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Vehicle & People Detection Demo
        ## Upload a video and see the detected objects!
        """,
        elem_classes="header-text",
    )
    with gr.Tabs():
        with gr.Tab("Video"):
            gr.Markdown(
                f"The input video is processed at ~25 FPS (up to {MAX_NUM_FRAMES} frames in the result)."
            )
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        video_input = create_video_inputs()[0]
                    video_detect_button, video_clear_button = create_button_row()
                with gr.Column(scale=2):
                    video_output = gr.Video(
                        label="Detection Results",
                        format="mp4",  # Explicit MP4 format
                        elem_classes="output-component",
                    )
    video_clear_button.click(
        fn=lambda: (None, None),
        outputs=[
            video_input,
            video_output,
        ],
    )

    video_detect_button.click(
        fn=process_video,
        inputs=[video_input],
        outputs=[video_output],
    )
if __name__ == "__main__":
    demo.queue(max_size=20).launch()
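# To run locally (e.g. `python app.py`), Gradio serves the UI at
# http://localhost:7860 by default; the queue limits pending requests to 20.
# On a ZeroGPU Space, the @spaces.GPU-decorated function receives a GPU only
# for the duration of each call.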