|
import os
import time
from tempfile import NamedTemporaryFile

import cv2
import IPython
import numpy as np
import streamlit as st
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from torchvision import models
from torchvision import transforms

from models.experimental import attempt_load
from utils.general import non_max_suppression
|
|
|
# Trained YOLOv5 checkpoint, resolved relative to the current working directory.
yolov5_weight_file = 'model100e.pt'

# Use CUDA when PyTorch reports it is available, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# NOTE(review): Streamlit re-executes the whole script on every widget
# interaction, so this load appears to repeat on each rerun -- consider caching
# it (e.g. st.cache_resource) once the Streamlit version in use is confirmed.
yolov5_model = attempt_load(yolov5_weight_file, device=device, inplace=True, fuse=True)

# Let cuDNN benchmark convolution algorithms; this pays off when the input
# shape repeats between calls (video frames are resized to frame_size below).
cudnn.benchmark = True

# A wrapped model (one exposing `.module`) keeps the class names on the inner
# model; otherwise they are read from the model itself.
names = getattr(yolov5_model, 'module', yolov5_model).names

# Passed to non_max_suppression as its second (confidence threshold) argument.
conf_set = 0.1

# (width, height) handed to cv2.resize for every video frame.
frame_size = (800, 480)

# Box/label colour per class name; classes missing here are drawn in white by
# the drawing code. The tuple is applied in whatever channel order the target
# array uses (RGB vs BGR), so the visible colour depends on that order.
colors = dict(
    helmet=(255, 0, 0),
    rider=(0, 255, 0),
    number=(0, 0, 255),
    no_helmet=(0, 100, 255),
)
|
|
|
def detect_objects(frame, *, stride=32):
    """Run the YOLOv5 model on a single image and return its detections.

    The image is padded on its bottom/right edges up to a multiple of
    ``stride`` so that arbitrary upload sizes do not break the network's
    down-/up-sampling path. Because the padding is anchored at the top-left
    corner, box coordinates need no rescaling; they are only clamped back to
    the original image area.

    Args:
        frame: ``HxW`` (grayscale), ``HxWx3`` or ``HxWx4`` uint8 array. A
            fourth (alpha) channel is dropped and grayscale is replicated to
            three channels. Channel order is passed to the model unchanged;
            YOLOv5 checkpoints are normally trained on RGB -- presumably this
            one too (TODO confirm).
        stride: Positive value that height and width are padded up to a
            multiple of. 32 is the largest stride of standard YOLOv5 models
            (assumption -- check ``yolov5_model.stride`` for this checkpoint).

    Returns:
        List of ``(x1, y1, x2, y2, conf, name)`` tuples: integer pixel
        corners clamped to the image, confidence rounded to two decimals, and
        the class name looked up in ``names``.

    ``frame`` itself is not modified; use ``display_detections`` to draw.
    """
    # Normalise to exactly three channels so the HWC -> CHW permute works.
    if frame.ndim == 2:
        frame = np.repeat(frame[:, :, None], 3, axis=2)
    elif frame.shape[2] == 4:
        frame = frame[:, :, :3]

    height, width = frame.shape[:2]
    pad_bottom = -height % stride
    pad_right = -width % stride
    if pad_bottom or pad_right:
        # 114 matches the grey YOLOv5's own letterbox() pads with by default.
        frame = np.pad(frame, ((0, pad_bottom), (0, pad_right), (0, 0)), constant_values=114)

    img = torch.from_numpy(np.ascontiguousarray(frame)).to(device)
    img = img.permute(2, 0, 1).float() / 255.0  # HWC uint8 -> CHW float in [0, 1]
    img = img.unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        pred = yolov5_model(img, augment=False)[0]
    pred = non_max_suppression(pred, conf_set, 0.30)

    detections = []
    for det in pred:  # one result tensor per image in the batch
        for row in det.tolist():
            x1, y1, x2, y2 = (int(v) for v in row[:4])
            # Boxes may reach into the padding added above; keep them on the image.
            x1, x2 = min(max(x1, 0), width), min(max(x2, 0), width)
            y1, y2 = min(max(y1, 0), height), min(max(y2, 0), height)
            conf = round(row[4], 2)
            detected_name = names[int(row[5])]
            detections.append((x1, y1, x2, y2, conf, detected_name))
    return detections
|
|
|
def display_detections(input_image, output_image, detections):
    """Draw each detection's box and "name (conf)" label onto ``output_image``.

    Args:
        input_image: Not used; kept so existing call sites keep working.
        output_image: Image array that OpenCV draws on in place.
        detections: Iterable of ``(x1, y1, x2, y2, conf, name)`` tuples, as
            returned by ``detect_objects``.

    Returns:
        ``output_image`` itself (the same array object, now annotated).
    """
    for x1, y1, x2, y2, conf, detected_name in detections:
        color = colors.get(detected_name, (255, 255, 255))  # unknown classes in white
        cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)
        # The label baseline sits 10 px above the box; for boxes touching the
        # top edge that would be off-image, so keep it at least 15 px down.
        label_origin = (x1, max(y1 - 10, 15))
        cv2.putText(output_image, f"{detected_name} ({conf:.2f})", label_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return output_image
|
|
|
def app():
    """Render the Streamlit UI and run helmet detection on the chosen input.

    "Image" mode annotates one uploaded picture. "Video" mode reads frames
    from a webcam or an uploaded MP4 and annotates them one by one.
    """
    st.title("Helmet Detection App")
    st.write("This app uses YOLOv5 to detect helmets and riders in images and videos.")

    input_type = st.radio("Select input type:", options=["Image", "Video"])

    if input_type == "Image":
        _run_image_mode()
    elif input_type == "Video":
        _run_video_mode()


def _run_image_mode():
    """Ask for an image, run detection on it and show input and annotated output."""
    uploaded_file = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # convert("RGB") gives every upload (RGBA/palette PNG, grayscale JPEG)
    # exactly three channels, which the detector's HWC -> CHW permute needs.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    detections = detect_objects(np.array(image))
    output_image = display_detections(np.array(image), np.array(image), detections)
    st.image(output_image, caption="Output Image", use_column_width=True)


def _run_video_mode():
    """Pick a video source, then annotate its frames until the source ends."""
    st.write("Select an option to get the input video:")
    video_option = st.radio("", options=["Webcam", "Upload video"])

    temp_path = None
    if video_option == "Webcam":
        # Camera 0 of the machine running the Streamlit server.
        source = 0
    else:
        uploaded_file = st.file_uploader("Upload video", type=["mp4"])
        if uploaded_file is None:
            return
        # Close the temp file before OpenCV opens it by name: closing flushes
        # every byte to disk (and Windows refuses a second open of an open file).
        with NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
            temp_file.write(uploaded_file.getvalue())
            temp_path = temp_file.name
        st.write("Video uploaded successfully!")
        source = temp_path

    show_video = st.checkbox("Show video", value=True)
    save_video = st.checkbox("Save video", value=False)

    cap = cv2.VideoCapture(source)
    try:
        if not cap.isOpened():
            st.error("Could not open the selected video source.")
            return
        _process_stream(cap, show_video, save_video)
    finally:
        # Also runs if the loop is left via an exception (e.g. Streamlit
        # stopping the script for a rerun), so the device/file is released.
        cap.release()
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass  # best-effort cleanup of the temporary upload


def _process_stream(cap, show_video, save_video, output_path="output.mp4"):
    """Annotate frames read from ``cap`` until it stops delivering frames.

    Each frame is resized to ``frame_size``, converted from OpenCV's BGR to
    RGB (the order the image path feeds the model), passed to
    ``detect_objects`` and annotated with boxes and an FPS counter. With
    ``save_video`` set, annotated frames are written to ``output_path``.
    """
    font = cv2.FONT_HERSHEY_DUPLEX
    stframe = st.empty()  # one placeholder, overwritten by every frame
    writer = None
    start_time = time.time()
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break  # end of file, or the camera stopped delivering frames
            frame = cv2.resize(frame, frame_size)
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Copy before detection so anything drawn onto the detector's
            # input can never show up twice in the output.
            display_frame = rgb.copy()
            detections = detect_objects(rgb)
            display_frame = display_detections(rgb, display_frame, detections)

            now = time.time()
            fps = 1.0 / max(now - start_time, 1e-6)  # guard against a zero interval
            start_time = now
            cv2.putText(display_frame, f'FPS: {fps:.2f}', (10, 30), font, 1, (0, 255, 0), 2, cv2.LINE_AA)

            if show_video:
                stframe.image(display_frame, channels="RGB")
            if save_video:
                if writer is None:
                    source_fps = cap.get(cv2.CAP_PROP_FPS)
                    if not (source_fps > 0):  # some sources report 0 or NaN
                        source_fps = 20.0
                    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
                                             source_fps, frame_size)
                # VideoWriter expects BGR frames.
                writer.write(cv2.cvtColor(display_frame, cv2.COLOR_RGB2BGR))
    finally:
        if writer is not None:
            writer.release()

    if writer is not None:
        st.write(f"Annotated video saved to {output_path}")


if __name__ == "__main__":
    app()