reeeeemo committed on
Commit
019f9fc
·
0 Parent(s):

Added YOLO model

Browse files
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.onnx filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ .Python
6
+ env/
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ eggs/
11
+ *.egg-info/
12
+
13
+ # Data
14
+ data/
15
+
16
+ # Models
17
+ runs/
18
+ yolov8n.pt
19
+
20
+ # Visual Studio
21
+ *.vs
models/best.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5511f2d7d2c6d85e3011c133c3284025c3bfb4ab2e39c9ef70af1a0d0e8e7ea
3
+ size 12274092
models/best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbb4fcbd99e83e1fae188d27a443754981e271e1a0a8b3e4903b40f94014075b
3
+ size 6244707
src/detectobjects.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ from ultralytics.utils.plotting import Annotator, colors
3
+ import cv2
4
+ from typing import List
5
+ import numpy as np
6
+
7
+ class ObjectDetector():
8
+ def __init__(self, pretrained_model: str = 'yolov8n.pt', debug: bool = False):
9
+ self.model = YOLO(pretrained_model)
10
+ self.debug = debug
11
+
12
+ self.color_map = {
13
+ 0: (0, 255, 0), # player: green
14
+ 1: (255, 128, 0), # Storm Timer: light blue
15
+ 2: (0, 0, 255), # Killfeed: red
16
+ 3: (255, 0, 0), # Player Count: blue
17
+ 4: (0, 255, 255), # Minimap: yellow
18
+ 5: (128, 0, 255), # Storm Shrink Warning: dark red
19
+ 6: (255, 0, 255), # Eliminations: magenta
20
+ 7: (0, 128, 255), # Health: orange
21
+ 8: (255, 255, 0), # Shield: cyan
22
+ 9: (128, 255, 0), # Inventory: light green
23
+ 10: (0, 165, 255), # Buildings: orange-yellow
24
+ 11: (139, 69, 19), # Wood Material: brown
25
+ 12: (128, 128, 128),# Brick Material: gray
26
+ 13: (192, 192, 192),# Metal Material: light gray
27
+ 14: (255, 191, 0), # Compass: deep sky blue
28
+ 15: (255, 0, 128), # Equipped Item: purple
29
+ 16: (0, 255, 191), # Waypoint: yellow-green
30
+ 17: (128, 128, 0), # Sprint Meter: teal
31
+ 18: (0, 140, 255), # Safe Zone: orange-red
32
+ 19: (0, 215, 255), # playerIcon: gold
33
+ 20: (34, 139, 34), # Tree: forest green
34
+ 21: (75, 75, 75), # Stone: dark gray
35
+ 22: (0, 69, 255), # Building: orange-red
36
+ 23: (122, 61, 0), # Wood Building: dark brown
37
+ 24: (108, 108, 108),# Stone Building: medium gray
38
+ 25: (211, 211, 211), # Metal Building: silver
39
+ 26: (0, 43, 27), # Wall: dark green
40
+ 27: (22, 22, 22), # Ramp: dark gray
41
+ 28: (17, 211, 0), # Pyramid: bright green
42
+ 29: (121, 132, 9) # Floor: olive green
43
+ }
44
+
45
+ def train_model(self, yaml_filepath):
46
+ self.model.train(data=yaml_filepath, epochs=100, imgsz=640, batch=16, patience=50)
47
+
48
+ def detect_object(self, frames: List[np.ndarray]):
49
+ for frame in frames:
50
+ results = self.model.track(frame, stream=True)
51
+
52
+ for result in results:
53
+ class_names = result.names
54
+ annotated_frame = frame.copy()
55
+
56
+ for box in result.boxes:
57
+ if box.conf[0] > 0.4:
58
+ [x1, y1, x2, y2] = box.xyxy[0] # coords
59
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
60
+
61
+ cls = int(box.cls[0]) # class
62
+ color = self.color_map.get(cls, (0,255,0))
63
+
64
+ cv2.rectangle(annotated_frame, (x1,y1), (x2,y2), color=color, thickness=2)
65
+
66
+ text = f'{class_names[cls]} {box.conf[0]:.2f}'
67
+ (text_width, text_height), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
68
+ cv2.rectangle(annotated_frame, (x1, y1-text_height-5), (x1+text_width, y1), color, -1)
69
+
70
+ cv2.putText(annotated_frame, text, (x1, y1-5),
71
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), thickness=2)
72
+ while self.debug:
73
+ cv2.imshow('frame', annotated_frame)
74
+ if cv2.waitKey(1) & 0xFF == ord('q'):
75
+ break
76
+ cv2.destroyAllWindows()
77
+
78
+ def export_model(self, format: str = 'onnx'):
79
+ self.model.export(format=format)
src/vidprocessing.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from typing import List
3
+ from os.path import dirname, abspath
4
+ from pathlib import Path
5
+ import numpy as np
6
+ from detectobjects import ObjectDetector
7
+ import yaml
8
+
9
+ base_dir = Path(dirname(dirname(abspath(__file__))))
10
+
11
+ def get_video(path: str) -> cv2.VideoCapture:
12
+ video = cv2.VideoCapture(path)
13
+ if not video.isOpened():
14
+ raise ValueError(f'Could not open video file: {path}')
15
+
16
+ fps = video.get(cv2.CAP_PROP_FPS)
17
+ frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
18
+ duration = frame_count/fps
19
+
20
+ print(f'FPS: {fps}\nFrame Count: {frame_count}\nDuration: {duration}')
21
+
22
+ return video
23
+
24
+ def get_frames(video: cv2.VideoCapture, frame_start: int, frame_end: int) -> List[np.ndarray]:
25
+
26
+ frames = []
27
+ for i in range(frame_start, frame_end+1):
28
+ video.set(cv2.CAP_PROP_POS_FRAMES, i)
29
+ ret, frame = video.read()
30
+ if not ret:
31
+ raise ValueError(f'Could not read frame {i}')
32
+
33
+ frames.append(frame)
34
+
35
+ return frames
36
+
37
+
38
+ def create_images_of_video(video: cv2.VideoCapture, interval: int = 100):
39
+ frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
40
+
41
+ # read frame_count / interval of frames
42
+ for i in range(0, frame_count, interval):
43
+ video.set(cv2.CAP_PROP_POS_FRAMES, i)
44
+ success, frame = video.read()
45
+ if not success:
46
+ raise ValueError(f'Could not read frame {i}')
47
+
48
+ # write photo to file
49
+ cv2.imwrite(base_dir / "data" / 'model_data' / 'temp_vid_folder' /f'{i}.png', frame)
50
+
51
+
52
+ if __name__ == "__main__":
53
+ vid = get_video(base_dir / "data" / "video_data" / "fortnite_remo_three.mp4")
54
+ frames = get_frames(vid, 21100, 21110)
55
+
56
+ ### PRETRAINED MODEL DETECTION CODE
57
+ yolo = ObjectDetector(pretrained_model=(base_dir / 'best.pt'), debug=True)
58
+ yolo.detect_object(frames)
59
+
60
+
61
+ ### TRAINING CODE FOR YOLO MODEL
62
+ # yolo = ObjectDetector()
63
+ # yolo.train_model(base_dir / 'data' / 'model_data' / 'fortnite_train.yaml')
64
+
65
+ ### VIDEO FRAME CUTTING CODE
66
+ # create_images_of_video(vid)
67
+
68
+ ### MODEL EXPORT + VIDEO RELEASE
69
+ yolo.export_model()
70
+ vid.release()