ethanrom committed
Commit 4b3d8ad · 1 parent: ed52a8d

Upload 3 files

Files changed (3)
  1. app.py +110 -0
  2. model100e.pt +3 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,110 @@
+ import cv2
+ import time
+ import numpy as np
+ import torch
+ import torch.backends.cudnn as cudnn
+ import streamlit as st
+ from PIL import Image
+ from tempfile import NamedTemporaryFile
+ from models.experimental import attempt_load      # from the YOLOv5 repo
+ from utils.general import non_max_suppression     # from the YOLOv5 repo
+
+ yolov5_weight_file = 'model100e.pt'
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ yolov5_model = attempt_load(yolov5_weight_file, device=device, inplace=True, fuse=True)
+ cudnn.benchmark = True
+ names = yolov5_model.module.names if hasattr(yolov5_model, 'module') else yolov5_model.names
+
+ conf_set = 0.1              # confidence threshold
+ frame_size = (800, 480)     # width, height fed to the model
+
+ # BGR colors per class for drawing
+ colors = {
+     'helmet': (255, 0, 0),
+     'rider': (0, 255, 0),
+     'number': (0, 0, 255),
+     'no_helmet': (0, 100, 255),
+ }
+
+
+ def detect_objects(frame):
+     """Run YOLOv5 on an HxWx3 uint8 frame, draw boxes in place, return detections."""
+     img = torch.from_numpy(frame)
+     img = img.permute(2, 0, 1).float().to(device)
+     img /= 255.0
+     if img.ndimension() == 3:
+         img = img.unsqueeze(0)
+     with torch.no_grad():
+         pred = yolov5_model(img, augment=False)[0]
+     pred = non_max_suppression(pred, conf_set, 0.30)
+     detections = []
+     for det in pred:
+         for d in det:  # d = (x1, y1, x2, y2, conf, cls)
+             x1, y1, x2, y2 = (int(v.item()) for v in d[:4])
+             conf = round(d[4].item(), 2)
+             detected_name = names[int(d[5].item())]
+             detections.append((x1, y1, x2, y2, conf, detected_name))
+
+             color = colors.get(detected_name, (255, 255, 255))
+             cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
+             cv2.putText(frame, detected_name, (x1, y1), cv2.FONT_HERSHEY_DUPLEX, 1, color, 2)
+
+     return detections
+
+
+ def app():
+     st.title("Helmet Detection App")
+     st.write("This app uses YOLOv5 to detect helmets and riders in images and videos.")
+
+     # Select input type
+     input_type = st.radio("Select input type:", options=["Image", "Video"])
+
+     # Upload file or use webcam
+     if input_type == "Image":
+         uploaded_file = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"])
+         if uploaded_file is not None:
+             image = Image.open(uploaded_file).convert("RGB")
+             st.image(image, caption="Uploaded Image", use_column_width=True)
+             frame = cv2.resize(np.array(image), frame_size)
+             detect_objects(frame)                     # boxes are drawn on the frame in place
+             st.image(frame, caption="Detections", use_column_width=True)
+
+     elif input_type == "Video":
+         st.write("Select an option to get the input video:")
+         video_option = st.radio("", options=["Webcam", "Upload video"])
+
+         if video_option == "Webcam":
+             cap = cv2.VideoCapture(0)
+         elif video_option == "Upload video":
+             uploaded_file = st.file_uploader("Upload video", type=["mp4"])
+             if uploaded_file is not None:
+                 temp_file = NamedTemporaryFile(delete=False, suffix=".mp4")
+                 temp_file.write(uploaded_file.read())
+                 temp_file.flush()
+                 st.write("Video uploaded successfully!")
+                 cap = cv2.VideoCapture(temp_file.name)
+
+         if 'cap' in locals():
+             show_video = st.checkbox("Show video", value=True)
+             save_video = st.checkbox("Save video", value=False)
+             font = cv2.FONT_HERSHEY_DUPLEX
+             stframe = st.empty()                      # placeholder the frames are streamed into
+             out = None
+             if save_video:
+                 # 'output.mp4' is a placeholder output path
+                 out = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, frame_size)
+             start_time = time.time()
+
+             while True:
+                 ret, frame = cap.read()
+                 if not ret:
+                     break
+                 frame = cv2.resize(frame, frame_size)
+                 detect_objects(frame)                 # boxes are drawn on the frame in place
+                 fps = 1 / (time.time() - start_time)
+                 start_time = time.time()
+                 cv2.putText(frame, f'FPS: {fps:.2f}', (10, 30), font, 1, (0, 255, 0), 2, cv2.LINE_AA)
+                 if show_video:
+                     stframe.image(frame, channels="BGR")
+                 if save_video:
+                     out.write(frame)
+
+             cap.release()
+             if save_video:
+                 out.release()
+
+
+ if __name__ == '__main__':
+     app()
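
A minimal sanity check of the same pipeline outside Streamlit can help before wiring up the app. The sketch below assumes the YOLOv5 repository (which provides models.experimental and utils.general, as imported in app.py) is on the Python path, and 'sample.jpg' is a placeholder for any local test image:

import cv2
import torch
from models.experimental import attempt_load   # provided by the YOLOv5 repo, as in app.py
from utils.general import non_max_suppression  # provided by the YOLOv5 repo, as in app.py

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = attempt_load('model100e.pt', device=device, inplace=True, fuse=True)
names = model.module.names if hasattr(model, 'module') else model.names

# 'sample.jpg' is a placeholder path; 800x480 matches frame_size in app.py
frame = cv2.resize(cv2.imread('sample.jpg'), (800, 480))
img = torch.from_numpy(frame).permute(2, 0, 1).float().to(device) / 255.0
with torch.no_grad():
    pred = model(img.unsqueeze(0), augment=False)[0]

for det in non_max_suppression(pred, 0.1, 0.30):  # conf and IoU thresholds from app.py
    for *xyxy, conf, cls in det:
        print(names[int(cls)], round(float(conf), 2), [int(v) for v in xyxy])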
model100e.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2a9d58c38bf2cf302bc1a4d75e42e16de1c9dabe3d7fe94d8b415a08d4e5857a
+ size 3906877
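
model100e.pt is stored as a Git LFS pointer, so the repository only records the object id and size above; the weights themselves are fetched by LFS. A small sketch (assuming the real file has already been pulled into the working directory) to confirm a local copy matches the recorded oid and size:

import hashlib
import os

path = 'model100e.pt'
expected_oid = '2a9d58c38bf2cf302bc1a4d75e42e16de1c9dabe3d7fe94d8b415a08d4e5857a'
expected_size = 3906877

with open(path, 'rb') as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print('size ok:', os.path.getsize(path) == expected_size)
print('oid ok :', digest == expected_oid)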
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ Pillow # PIL
+ opencv-python
+ torch
+ torchvision
+ numpy
+ tqdm
+ pandas
+ matplotlib
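
The list uses PyPI names; the import names differ for Pillow (PIL) and opencv-python (cv2), and streamlit is imported by app.py but not listed here, presumably because the hosting runtime provides it. A quick environment check along those lines (a sketch, not part of the commit):

import importlib

# Import names corresponding to the packages in requirements.txt
for name in ('PIL', 'cv2', 'torch', 'torchvision', 'numpy', 'tqdm', 'pandas', 'matplotlib'):
    mod = importlib.import_module(name)
    print(name, getattr(mod, '__version__', 'installed'))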