reeeeemo committed
Commit 019f9fc
Added YOLO model
Files changed:
- .gitattributes +2 -0
- .gitignore +21 -0
- models/best.onnx +3 -0
- models/best.pt +3 -0
- src/detectobjects.py +79 -0
- src/vidprocessing.py +70 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,21 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+.Python
+env/
+build/
+develop-eggs/
+dist/
+eggs/
+*.egg-info/
+
+# Data
+data/
+
+# Models
+runs/
+yolov8n.pt
+
+# Visual Studio
+*.vs
models/best.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5511f2d7d2c6d85e3011c133c3284025c3bfb4ab2e39c9ef70af1a0d0e8e7ea
+size 12274092
models/best.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbb4fcbd99e83e1fae188d27a443754981e271e1a0a8b3e4903b40f94014075b
+size 6244707
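With the .gitattributes rules above, Git keeps only these three-line pointer files in the repository; the actual .pt and .onnx binaries live in LFS storage. A minimal sketch for reading such a pointer into a dict (the helper name is an assumption, not part of this commit):

# minimal sketch: parse a Git LFS pointer file (hypothetical helper)
def parse_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(' ')
            fields[key] = value  # e.g. fields['oid'] == 'sha256:...'
    return fields

# only yields the pointer on a checkout where the LFS smudge filter did not run
print(parse_lfs_pointer('models/best.onnx'))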
src/detectobjects.py
ADDED
@@ -0,0 +1,79 @@
+from ultralytics import YOLO
+from ultralytics.utils.plotting import Annotator, colors
+import cv2
+from typing import List
+import numpy as np
+
+class ObjectDetector():
+    def __init__(self, pretrained_model: str = 'yolov8n.pt', debug: bool = False):
+        self.model = YOLO(pretrained_model)
+        self.debug = debug
+
+        # BGR color per class id, used when drawing boxes
+        self.color_map = {
+            0: (0, 255, 0),       # Player: green
+            1: (255, 128, 0),     # Storm Timer: light blue
+            2: (0, 0, 255),       # Killfeed: red
+            3: (255, 0, 0),       # Player Count: blue
+            4: (0, 255, 255),     # Minimap: yellow
+            5: (128, 0, 255),     # Storm Shrink Warning: dark red
+            6: (255, 0, 255),     # Eliminations: magenta
+            7: (0, 128, 255),     # Health: orange
+            8: (255, 255, 0),     # Shield: cyan
+            9: (128, 255, 0),     # Inventory: light green
+            10: (0, 165, 255),    # Buildings: orange-yellow
+            11: (139, 69, 19),    # Wood Material: brown
+            12: (128, 128, 128),  # Brick Material: gray
+            13: (192, 192, 192),  # Metal Material: light gray
+            14: (255, 191, 0),    # Compass: deep sky blue
+            15: (255, 0, 128),    # Equipped Item: purple
+            16: (0, 255, 191),    # Waypoint: yellow-green
+            17: (128, 128, 0),    # Sprint Meter: teal
+            18: (0, 140, 255),    # Safe Zone: orange-red
+            19: (0, 215, 255),    # playerIcon: gold
+            20: (34, 139, 34),    # Tree: forest green
+            21: (75, 75, 75),     # Stone: dark gray
+            22: (0, 69, 255),     # Building: orange-red
+            23: (122, 61, 0),     # Wood Building: dark brown
+            24: (108, 108, 108),  # Stone Building: medium gray
+            25: (211, 211, 211),  # Metal Building: silver
+            26: (0, 43, 27),      # Wall: dark green
+            27: (22, 22, 22),     # Ramp: dark gray
+            28: (17, 211, 0),     # Pyramid: bright green
+            29: (121, 132, 9),    # Floor: olive green
+        }
+
+    def train_model(self, yaml_filepath):
+        self.model.train(data=yaml_filepath, epochs=100, imgsz=640, batch=16, patience=50)
+
+    def detect_object(self, frames: List[np.ndarray]):
+        for frame in frames:
+            results = self.model.track(frame, stream=True)
+
+            for result in results:
+                class_names = result.names
+                annotated_frame = frame.copy()
+
+                for box in result.boxes:
+                    if box.conf[0] > 0.4:  # confidence threshold
+                        x1, y1, x2, y2 = box.xyxy[0]  # box coords
+                        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+
+                        cls = int(box.cls[0])  # class id
+                        color = self.color_map.get(cls, (0, 255, 0))
+
+                        cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color=color, thickness=2)
+
+                        # label with class name and confidence, on a filled background
+                        text = f'{class_names[cls]} {box.conf[0]:.2f}'
+                        (text_width, text_height), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
+                        cv2.rectangle(annotated_frame, (x1, y1 - text_height - 5), (x1 + text_width, y1), color, -1)
+
+                        cv2.putText(annotated_frame, text, (x1, y1 - 5),
+                                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), thickness=2)
+
+                # in debug mode, display the annotated frame until 'q' is pressed
+                while self.debug:
+                    cv2.imshow('frame', annotated_frame)
+                    if cv2.waitKey(1) & 0xFF == ord('q'):
+                        break
+                cv2.destroyAllWindows()
+
+    def export_model(self, format: str = 'onnx'):
+        self.model.export(format=format)
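A minimal usage sketch for the detector above (the image path is hypothetical, not part of this commit):

# minimal sketch: run ObjectDetector on a single image
import cv2
from detectobjects import ObjectDetector

frame = cv2.imread('sample_frame.png')                    # BGR ndarray, as detect_object expects
detector = ObjectDetector('models/best.pt', debug=True)   # debug=True opens the preview window
detector.detect_object([frame])                           # takes a list of frames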
src/vidprocessing.py
ADDED
@@ -0,0 +1,70 @@
+import cv2
+from typing import List
+from os.path import dirname, abspath
+from pathlib import Path
+import numpy as np
+from detectobjects import ObjectDetector
+import yaml
+
+# repository root: parent of the src/ directory holding this file
+base_dir = Path(dirname(dirname(abspath(__file__))))
+
+def get_video(path: str) -> cv2.VideoCapture:
+    video = cv2.VideoCapture(str(path))  # cv2 expects a string, not a Path
+    if not video.isOpened():
+        raise ValueError(f'Could not open video file: {path}')
+
+    fps = video.get(cv2.CAP_PROP_FPS)
+    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+    duration = frame_count / fps
+
+    print(f'FPS: {fps}\nFrame Count: {frame_count}\nDuration: {duration}')
+
+    return video
+
+def get_frames(video: cv2.VideoCapture, frame_start: int, frame_end: int) -> List[np.ndarray]:
+    frames = []
+    for i in range(frame_start, frame_end + 1):
+        video.set(cv2.CAP_PROP_POS_FRAMES, i)
+        ret, frame = video.read()
+        if not ret:
+            raise ValueError(f'Could not read frame {i}')
+
+        frames.append(frame)
+
+    return frames
+
+
+def create_images_of_video(video: cv2.VideoCapture, interval: int = 100):
+    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # read every `interval`-th frame
+    for i in range(0, frame_count, interval):
+        video.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, frame = video.read()
+        if not success:
+            raise ValueError(f'Could not read frame {i}')
+
+        # write frame to file; cv2.imwrite needs a string path
+        cv2.imwrite(str(base_dir / 'data' / 'model_data' / 'temp_vid_folder' / f'{i}.png'), frame)
+
+
+if __name__ == "__main__":
+    vid = get_video(base_dir / "data" / "video_data" / "fortnite_remo_three.mp4")
+    frames = get_frames(vid, 21100, 21110)
+
+    ### PRETRAINED MODEL DETECTION CODE
+    # weights are added under models/ in this commit
+    yolo = ObjectDetector(pretrained_model=(base_dir / 'models' / 'best.pt'), debug=True)
+    yolo.detect_object(frames)
+
+
+    ### TRAINING CODE FOR YOLO MODEL
+    # yolo = ObjectDetector()
+    # yolo.train_model(base_dir / 'data' / 'model_data' / 'fortnite_train.yaml')
+
+    ### VIDEO FRAME CUTTING CODE
+    # create_images_of_video(vid)
+
+    ### MODEL EXPORT + VIDEO RELEASE
+    yolo.export_model()
+    vid.release()
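One design note on get_frames: setting CAP_PROP_POS_FRAMES before every read forces a seek per frame, which is slow on long videos. A sequential-read sketch under the same cv2 API (the function name is an assumption, not part of this commit):

# minimal sketch: seek once, then read frames back-to-back
import cv2
from typing import List
import numpy as np

def get_frames_sequential(video: cv2.VideoCapture, frame_start: int, frame_end: int) -> List[np.ndarray]:
    video.set(cv2.CAP_PROP_POS_FRAMES, frame_start)  # single seek
    frames = []
    for i in range(frame_start, frame_end + 1):
        ret, frame = video.read()  # reads advance the position automatically
        if not ret:
            raise ValueError(f'Could not read frame {i}')
        frames.append(frame)
    return frames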