Spaces:
Build error
Build error
| import asyncio | |
| from datetime import datetime | |
| import logging | |
| import cv2 | |
| import numpy as np | |
| from pathlib import Path | |
| import torch | |
| from zoneinfo import ZoneInfo | |
| from starlette.middleware import Middleware | |
| from starlette.responses import StreamingResponse, Response | |
| from starlette.requests import Request | |
| from starlette.routing import Mount, Route | |
| from starlette.staticfiles import StaticFiles | |
| from starlette.templating import Jinja2Templates | |
| from sse_starlette import EventSourceResponse | |
| from asgi_htmx import HtmxMiddleware | |
| from asgi_htmx import HtmxRequest | |
| from ultralytics import YOLO | |
| from ultralytics_solutions_modified import object_counter, speed_estimation | |
| from vidgear.gears import CamGear | |
| from vidgear.gears.asyncio import WebGear | |
| from vidgear.gears.asyncio.helper import reducer | |
| from helper import ( | |
| draw_text, make_table_from_dict_multiselect, make_table_from_dict, try_site | |
| ) | |
# Asset folders are resolved relative to this file, not the working dir.
HERE = Path(__file__).parent
static = StaticFiles(directory=HERE.joinpath(".vidgear", "webgear", "static"))
templates = Jinja2Templates(
    directory=HERE.joinpath(".vidgear", "webgear", "templates")
)

# Timing constants; units are given in the trailing comments.
EVT_STREAM_DELAY_SEC = 0.05  # seconds
RETRY_TIMEOUT_MILSEC = 15000  # milliseconds

# Root-logger configuration for the whole app: every record is prefixed
# with timestamp, logger name, module, function and line number.
logging.basicConfig(
    level=logging.INFO,
    format=(
        '%(asctime)s %(name)-8s->%(module)-20s->%(funcName)-20s:'
        '%(lineno)-4s::%(levelname)-8s %(message)s'
    ),
)
| class DemoCase: | |
def __init__(
    self,
    FRAME_WIDTH: int = 1280,
    FRAME_HEIGHT: int = 720,
    YOLO_VERBOSE: bool = True
):
    """
    Initialise the demo state.

    Sets up the output frame size, the catalogues of selectable YOLOv8
    weights / YouTube live streams / COCO object classes, and the run-time
    parameters that the HTTP handlers modify later.

    Args:
        FRAME_WIDTH: width (px) the frame producers resize frames to.
        FRAME_HEIGHT: height (px) the frame producers resize frames to.
        YOLO_VERBOSE: forwarded as ``verbose`` to ``YOLO.track``.
    """
    self.FRAME_WIDTH: int = FRAME_WIDTH
    self.FRAME_HEIGHT: int = FRAME_HEIGHT
    self.YOLO_VERBOSE: bool = YOLO_VERBOSE
    self.STREAM_RESOLUTION: str = "720p"
    # predefined yolov8 model references (choice key -> weights path)
    self.model_dict: dict = {
        "y8nano": "./data/models/yolov8n.pt",
        "y8small": "./data/models/yolov8s.pt",
        "y8medium": "./data/models/yolov8m.pt",
        "y8large": "./data/models/yolov8l.pt",
        "y8huge": "./data/models/yolov8x.pt",
    }
    self.model_choice_default: str = "y8small"
    self.model_choice: str = self.model_choice_default
    # The YOLO model is only created by load_model(). Initialise it here so
    # the `model is not None` checks in the producers work before a model
    # is loaded, instead of raising AttributeError.
    self.model: YOLO = None
    # predefined youtube live stream urls (display name -> url)
    self.url_dict: dict = {
        "Peace Bridge US": "https://youtu.be/9En2186vo5g",
        "Peace Bridge CA": "https://youtu.be/WPMgP2C3_co",
        "San Marcos TX": "https://youtu.be/E8LsKcVpL5A",
        "4Corners Downtown": "https://youtu.be/ByED80IKdIU",
        "Gangnam Seoul": "https://youtu.be/3ottn7kfRuc",
        "Time Square NY": "https://youtu.be/QTTTY_ra2Tg",
        "Port Everglades-1": "https://youtu.be/67-73mgWDf0",
        "Port Everglades-2": "https://youtu.be/Nhuu1QsW5LI",
        "Port Everglades-3": "https://youtu.be/Lpm-C_Gz6yM",
    }
    # selectable object classes (name -> class id passed to YOLO.track)
    self.obj_dict: dict = {
        "person": 0,
        "bicycle": 1,
        "car": 2,
        "motorcycle": 3,
        "airplane": 4,
        "bus": 5,
        "train": 6,
        "truck": 7,
        "boat": 8,
        "traffic light": 9,
        "fire hydrant": 10,
        "stop sign": 11,
        "parking meter": 12
    }
    self.cam_loc_default: str = "Peace Bridge US"
    self.cam_loc: str = self.cam_loc_default
    self.frame_reduction: int = 35
    # run time parameters that are from user input
    self.roi_height_default: int = int(FRAME_HEIGHT / 2)
    self.roi_height: int = self.roi_height_default
    self.roi_thickness_half_default: int = 30
    self.roi_thickness_half: int = self.roi_thickness_half_default
    # default classes: car, motorcycle, bus, truck (see obj_dict)
    self.obj_class_id_default: list[int] = [2, 3, 5, 7]
    # copy, so in-place edits of the active selection never alter the default
    self.obj_class_id: list[int] = list(self.obj_class_id_default)
    self.conf_threshold: float = 0.25
    self.iou_threshold: float = 0.7
    self.use_FP16: bool = False
    self.use_stream_buffer: bool = True
    # CamGear streams for the raw (0) and processed (1) producers
    self.stream0: CamGear = None
    self.stream1: CamGear = None
    # object counter / speed estimator, built lazily by frame1_producer
    self.counter = None
    self.speed_obj = None
    # define some logic flow control booleans
    self._is_running: bool = False
    self._is_tracking: bool = False
    self._roi_changed: bool = False
def load_model(
    self,
    model_choice: str = "y8small",
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.7,
    use_FP16: bool = False,
    use_stream_buffer: bool = False
) -> None:
    """
    Load the YOLOv8 weights selected by ``model_choice``.

    Also stores the inference settings that the producers later pass to
    ``model.track``.

    Args:
        model_choice: key of ``self.model_dict``. Unknown keys fall back to
            ``self.model_choice_default`` with a warning.
        conf_threshold: detection confidence threshold. Values outside
            (0, 1] are replaced by 0.25.
        iou_threshold: IoU threshold. Values outside (0, 1] are replaced
            by 0.7.
        use_FP16: stored and later forwarded as ``half`` to ``model.track``.
        use_stream_buffer: stored and later forwarded as ``stream_buffer``.
    """
    if model_choice not in self.model_dict:
        logging.warning(
            f'"{model_choice}" not found in the model_dict, use '
            f'"{self.model_dict[self.model_choice_default]}" instead!'
        )
        self.model_choice = self.model_choice_default
    else:
        self.model_choice = model_choice
    model_path = self.model_dict[self.model_choice]
    self.model = YOLO(f"{model_path}")
    # push the model to GPU 0 if CUDA is available, otherwise stay on CPU
    if torch.cuda.is_available():
        torch.cuda.set_device(0)
        self.model.to("cuda")
        logging.info(f"{model_path} loaded using torch w GPU0")
    else:
        logging.info(f"{model_path} loaded using CPU")
    # Both thresholds are fractions. Previously only <= 0 was rejected, so
    # e.g. conf=5 would silently discard every detection.
    self.conf_threshold = (
        conf_threshold if 0.0 < conf_threshold <= 1.0 else 0.25
    )
    self.iou_threshold = (
        iou_threshold if 0.0 < iou_threshold <= 1.0 else 0.7
    )
    self.use_FP16 = use_FP16
    self.use_stream_buffer = use_stream_buffer
    logging.info(
        f"{self.model_choice}: conf={self.conf_threshold:.2f} | "
        f"iou={self.iou_threshold:.2f} | FP16={self.use_FP16} | "
        f"stream_buffer={self.use_stream_buffer}"
    )
def select_cam_loc(
    self,
    cam_loc_key: str = "Peace Bridge US",
    cam_loc_val: str = "https://www.youtube.com/watch?v=9En2186vo5g"
) -> None:
    """
    Select the camera feed used by the frame producers.

    * An empty key or url selects ``self.cam_loc_default``.
    * An unknown key has its url ``cam_loc_val`` probed with ``try_site``.
      If the probe succeeds, the pair is added to ``url_dict`` and
      selected; otherwise the default camera is selected.
    * A known key is selected as-is (``cam_loc_val`` is ignored).
    """
    if not cam_loc_key or not cam_loc_val:
        self.cam_loc = self.cam_loc_default
        logging.warning(
            f'input cam_loc_key, cam_loc_val pair invalid, use default '
            f'{{{self.cam_loc_default}: '
            f'{self.url_dict[self.cam_loc_default]}}}'
        )
    elif cam_loc_key not in self.url_dict:
        # Probe the NEW url. The previous code probed the currently
        # selected url, so unplayable urls could be registered.
        if try_site(cam_loc_val):
            self.url_dict[cam_loc_key] = cam_loc_val
            self.cam_loc = cam_loc_key
            logging.info(
                f'input cam_loc key:val pair is new and playable, add '
                f'{{{cam_loc_key}:{cam_loc_val}}} into url_dict'
            )
        else:
            self.cam_loc = self.cam_loc_default
            logging.warning(
                f'input cam_loc key:val pair is new but not playable, '
                f'roll back to default {{{self.cam_loc_default}: '
                f'{self.url_dict[self.cam_loc_default]}}}'
            )
    else:
        self.cam_loc = cam_loc_key
    logging.info(
        f'use {{{self.cam_loc}: {self.url_dict[self.cam_loc]}}} as source'
    )
def select_obj_class_id(
    self,
    obj_names: "list[str] | tuple[str, ...]" = (
        "person", "bicycle", "car", "motorcycle", "airplane", "bus",
        "train", "truck", "boat", "traffic light", "fire hydrant",
        "stop sign", "parking meter"
    )
) -> None:
    """
    Set ``self.obj_class_id`` from a list of object names.

    Names are looked up in ``self.obj_dict`` and unknown names are skipped.
    If the input is empty or no name is known, a fresh copy of
    ``self.obj_class_id_default`` is used and a warning is logged.

    The default is an immutable tuple (previously a shared mutable list)
    covering every name in the class catalogue.
    """
    class_ids = [
        self.obj_dict[name] for name in (obj_names or ()) if name in self.obj_dict
    ]
    if not class_ids:
        # copy so later in-place edits never touch the default list
        self.obj_class_id = list(self.obj_class_id_default)
        logging.warning(
            f'input obj_names invalid, use default id '
            f'{self.obj_class_id_default}'
        )
    else:
        self.obj_class_id = class_ids
        logging.info(f'object class id set as {self.obj_class_id}')
def set_roi(self, roi_height: int = 360):
    """
    Set the row on which the counting band and speed line are centred.

    ``frame1_producer`` builds its counting region as
    ``roi_height +/- roi_thickness_half`` and its speed line at
    ``roi_height``. ``roi_thickness_half`` is always reset to its default.

    Accepted heights lie within 120..600 (inclusive) and must also keep the
    band inside the frame. For the default 720 px frame and 30 px
    half-thickness this is exactly the 120..600 range. Any other value falls
    back to the frame centre.
    """
    half = self.roi_thickness_half_default
    lowest = max(120, half)
    # cap so the band never extends past the bottom edge of smaller frames
    highest = min(600, self.FRAME_HEIGHT - half)
    if lowest <= roi_height <= highest:
        self.roi_height = roi_height
        logging.info(f'roi_height is set at {self.roi_height}')
    else:
        self.roi_height = int(self.FRAME_HEIGHT / 2)
        logging.warning(
            f'roi_height invalid, use default {int(self.FRAME_HEIGHT / 2)}'
        )
    self.roi_thickness_half = self.roi_thickness_half_default
def set_frame_reduction(self, frame_reduction: int = 35):
    """
    Set the percentage by which streamed frames are shrunk before encoding.

    The value is passed to vidgear's ``reducer``, which only accepts
    0 < percentage < 90 and raises ValueError otherwise. The old 0..100
    range therefore let 0 and 90..100 crash the frame producers. Out-of-range
    values now fall back to 35.
    """
    if 0 < frame_reduction < 90:
        self.frame_reduction = frame_reduction
        logging.info(f'frame_reduction is set at {self.frame_reduction}')
    else:
        self.frame_reduction = 35
        logging.warning(
            f'frame_reduction:{frame_reduction} invalid, '
            f'use default value 35'
        )
async def frame0_producer(self):
    """
    Async generator of MJPEG parts for the raw (un-annotated) camera feed.

    While ``self._is_running`` is true, it lazily opens ``self.stream0`` on
    ``self.url_dict[self.cam_loc]``, asking for ``self.STREAM_RESOLUTION``
    first and falling back to the best available resolution. For each frame
    it:

    * stamps the current America/Los_Angeles time;
    * shrinks the frame by ``self.frame_reduction`` percent;
    * yields it as one JPEG part of a
      ``multipart/x-mixed-replace; boundary=frame`` stream.

    While not running, it yields random-noise placeholder frames.

    On cancellation (``asyncio.CancelledError``, e.g. client disconnect)
    the stream is released, ``_is_running`` is cleared and the cancellation
    is re-raised so the response task can finish.

    Yields:
        bytes: one ``--frame`` multipart chunk holding a JPEG image.
    """
    la_tz = ZoneInfo("America/Los_Angeles")  # built once, not per frame

    def _noise_frame():
        # Uniform uint8 static. The old `standard_normal(...) * 255` cast
        # negative floats to uint8, which is implementation-defined.
        return np.random.randint(
            0, 256, (self.FRAME_HEIGHT, self.FRAME_WIDTH, 3), dtype=np.uint8
        )

    async def _to_part(frame):
        # reducer() raises ValueError outside (0, 90), so skip it there
        if 0 < self.frame_reduction < 90:
            frame = await reducer(frame, percentage=self.frame_reduction)
        img_encoded = cv2.imencode(".jpg", frame)[1].tobytes()
        # the payload is JPEG, so the part is labelled image/jpeg
        return (
            b"--frame\r\nContent-Type:image/jpeg\r\n\r\n"
            + img_encoded + b"\r\n"
        )

    def _release_stream():
        if self.stream0 is not None:
            self.stream0.stop()
            # drain until the stopped CamGear returns no more frames
            while self.stream0.read() is not None:
                continue
            self.stream0 = None
        self._is_running = False

    while True:
        if not self._is_running:
            # emitted for every placeholder frame: DEBUG avoids log flooding
            logging.debug(
                f"_is_running is {self._is_running} in frame0_producer"
            )
            yield await _to_part(_noise_frame())
            await asyncio.sleep(0.00001)
            continue

        if self.stream0 is None:
            source = self.url_dict[self.cam_loc]
            try:
                # preferred resolution first
                self.stream0 = CamGear(
                    source=source,
                    colorspace=None,
                    stream_mode=True,
                    logging=True,
                    STREAM_RESOLUTION=self.STREAM_RESOLUTION,
                ).start()
            except Exception:
                # fall back to the best resolution the stream offers
                self.stream0 = CamGear(
                    source=source,
                    colorspace=None,
                    stream_mode=True,
                    logging=True,
                ).start()
                logging.warning(
                    f"failed to connect {source} "
                    f"at {self.STREAM_RESOLUTION} resolution, "
                    f"use best resolution"
                )

        try:
            while self.stream0 is not None and self._is_running:
                frame = self.stream0.read()
                if frame is None:
                    frame = _noise_frame()
                elif frame.shape != (self.FRAME_HEIGHT, self.FRAME_WIDTH, 3):
                    # cv2.resize takes dsize as (width, height); the old
                    # (height, width) order produced transposed frames
                    frame = cv2.resize(
                        frame, (self.FRAME_WIDTH, self.FRAME_HEIGHT)
                    )
                draw_text(
                    img=frame,
                    # %Z yields PST or PDT as appropriate (was a fixed "PDT")
                    text=datetime.now(tz=la_tz).strftime(
                        "%m/%d/%Y %H:%M:%S %Z"
                    ),
                    pos=(int(self.FRAME_WIDTH - 500), 50),
                    font=cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale=1,
                    font_thickness=2,
                    line_type=cv2.LINE_AA,
                    text_color=(0, 255, 255),
                    text_color_bg=(0, 0, 0),
                )
                yield await _to_part(frame)
                await asyncio.sleep(0.00001)
        except asyncio.CancelledError:
            logging.warning("client disconnected in frame0_producer")
            # re-raise: swallowing cancellation kept a dead response alive
            raise
        finally:
            # runs on normal exit, cancellation and generator close alike
            _release_stream()
async def frame1_producer(self):
    """
    Async generator of MJPEG parts for the processed (tracked) camera feed.

    While ``self._is_running`` is true, it lazily opens ``self.stream1`` on
    ``self.url_dict[self.cam_loc]``. For each frame it:

    * stamps the America/Los_Angeles time;
    * when ``self._is_tracking`` is set and a model is loaded, runs
      ``model.track`` and feeds the results to the object counter and the
      speed estimator;
    * shrinks the frame by ``self.frame_reduction`` percent and yields it as
      one JPEG part of a ``multipart/x-mixed-replace; boundary=frame``
      stream.

    While not running, it yields random-noise placeholder frames.

    The counter and speed estimator are always (re)built together:

    * when tracking is on and either one is missing (including when
      tracking is switched on mid-stream);
    * whenever ``self._roi_changed`` is set.

    On cancellation the stream is released, ``_is_running`` and
    ``_is_tracking`` are cleared, and the cancellation is re-raised.

    Yields:
        bytes: one ``--frame`` multipart chunk holding a JPEG image.
    """
    la_tz = ZoneInfo("America/Los_Angeles")  # built once, not per frame

    def _noise_frame():
        # uniform uint8 static (avoids an undefined negative float->uint8 cast)
        return np.random.randint(
            0, 256, (self.FRAME_HEIGHT, self.FRAME_WIDTH, 3), dtype=np.uint8
        )

    async def _to_part(frame):
        # reducer() raises ValueError outside (0, 90), so skip it there
        if 0 < self.frame_reduction < 90:
            frame = await reducer(frame=frame, percentage=self.frame_reduction)
        img_encoded = cv2.imencode(".jpg", frame)[1].tobytes()
        # the payload is JPEG, so the part is labelled image/jpeg
        return (
            b"--frame\r\nContent-Type:image/jpeg\r\n\r\n"
            + img_encoded + b"\r\n"
        )

    def _build_trackers(model):
        # Counter region: band of +/- roi_thickness_half around roi_height.
        # Speed line: horizontal line at roi_height. Both are rebuilt here
        # so an ROI change can never leave one of them on stale geometry.
        top = self.roi_height - self.roi_thickness_half
        bottom = self.roi_height + self.roi_thickness_half
        right = self.FRAME_WIDTH - 5
        self.counter = object_counter.ObjectCounter()
        self.counter.set_args(
            view_img=False,
            reg_pts=[(5, top), (5, bottom), (right, bottom), (right, top)],
            classes_names=model.names,
            draw_tracks=False,
            draw_boxes=False,
            draw_reg_pts=True,
        )
        self.speed_obj = speed_estimation.SpeedEstimator()
        self.speed_obj.set_args(
            reg_pts=[(5, self.roi_height), (right, self.roi_height)],
            names=model.names,
            view_img=False,
        )
        self._roi_changed = False

    def _release_stream():
        if self.stream1 is not None:
            self.stream1.stop()
            # drain until the stopped CamGear returns no more frames
            while self.stream1.read() is not None:
                continue
            self.stream1 = None
        self._is_tracking = False
        self._is_running = False

    while True:
        if not self._is_running:
            yield await _to_part(_noise_frame())
            await asyncio.sleep(0.00001)
            continue

        if self.stream1 is None:
            source = self.url_dict[self.cam_loc]
            try:
                # preferred resolution first
                self.stream1 = CamGear(
                    source=source,
                    colorspace=None,
                    stream_mode=True,
                    logging=True,
                    STREAM_RESOLUTION=self.STREAM_RESOLUTION,
                ).start()
            except Exception:
                # fall back to the best resolution the stream offers
                self.stream1 = CamGear(
                    source=source,
                    colorspace=None,
                    stream_mode=True,
                    logging=True,
                ).start()
                logging.warning(
                    f"failed to connect {source} "
                    f"at {self.STREAM_RESOLUTION} resolution, "
                    f"use best resolution"
                )

        try:
            while self.stream1 is not None and self._is_running:
                # getattr: `model` only exists once load_model() has run
                model = getattr(self, "model", None)
                needs_build = self._roi_changed or (
                    self._is_tracking
                    and (self.counter is None or self.speed_obj is None)
                )
                if model is not None and needs_build:
                    _build_trackers(model)

                frame = self.stream1.read()
                if frame is None:
                    frame = _noise_frame()
                elif frame.shape != (self.FRAME_HEIGHT, self.FRAME_WIDTH, 3):
                    # cv2.resize takes dsize as (width, height)
                    frame = cv2.resize(
                        frame, (self.FRAME_WIDTH, self.FRAME_HEIGHT)
                    )
                draw_text(
                    img=frame,
                    # %Z yields PST or PDT as appropriate (was a fixed "PDT")
                    text=datetime.now(tz=la_tz).strftime(
                        "%m/%d/%Y %H:%M:%S %Z"
                    ),
                    pos=(self.FRAME_WIDTH - 500, 50),
                    font=cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale=1,
                    font_thickness=2,
                    line_type=cv2.LINE_AA,
                    text_color=(0, 255, 255),
                    text_color_bg=(0, 0, 0),
                )
                if (
                    self._is_tracking and model is not None
                    and self.speed_obj is not None
                    and self.counter is not None
                    and not self._roi_changed
                ):
                    # YOLOv8 tracking, persisting tracks between frames
                    results = model.track(
                        source=frame,
                        classes=self.obj_class_id,
                        conf=self.conf_threshold,
                        iou=self.iou_threshold,
                        half=self.use_FP16,
                        stream_buffer=self.use_stream_buffer,
                        persist=True,
                        show=False,
                        verbose=self.YOLO_VERBOSE,
                    )
                    if results[0].boxes.id is not None:
                        # return values are unused; presumably both annotate
                        # `frame` in place (TODO confirm in
                        # ultralytics_solutions_modified)
                        self.speed_obj.estimate_speed(frame, results)
                        self.counter.start_counting(frame, results)
                yield await _to_part(frame)
                await asyncio.sleep(0.00001)
        except asyncio.CancelledError:
            logging.warning("client disconnected in frame1_producer")
            # re-raise: swallowing cancellation kept a dead response alive
            raise
        finally:
            # runs on normal exit, cancellation and generator close alike
            _release_stream()
async def custom_video_response(self, scope):
    """
    Serve ``frame1_producer`` as an MJPEG streaming response.

    A fresh producer generator is created per call and wrapped in a
    ``StreamingResponse`` with ``multipart/x-mixed-replace; boundary=frame``.

    Tip1: use BackgroundTask to handle the async cleanup
        https://github.com/tiangolo/fastapi/discussions/11022
    Tip2: use is_disconnected to check client disconnection
        https://www.starlette.io/requests/#body
        https://github.com/encode/starlette/pull/320/files/d56c917460a1e6488e1206c428445c39854859c1
    """
    # NOTE(review): ASGI scope types are "http"/"websocket"/"lifespan";
    # "https" never occurs, so only "http" can actually pass — confirm intent
    assert scope["type"] in ("http", "https")
    await asyncio.sleep(0.00001)
    producer = self.frame1_producer()
    return StreamingResponse(
        media_type="multipart/x-mixed-replace; boundary=frame",
        content=producer,
    )
async def models(self, request: HtmxRequest) -> Response:
    """
    Render the table of selectable YOLO models.

    If ``model_dict`` is empty, render an acknowledgement message instead.
    Both cases respond with status 200.
    """
    if self.model_dict:
        template = "partials/yolo_models.html"
        rows = make_table_from_dict(self.model_dict, self.model_choice)
    else:
        template = "partials/ack.html"
        rows = ["model list unavailable!"]
    response = templates.TemplateResponse(
        template, {"request": request, "table": rows}, status_code=200
    )
    await asyncio.sleep(0.001)
    return response
async def urls(self, request: HtmxRequest) -> Response:
    """
    Render the table of registered camera streams.

    The current ``cam_loc`` is passed as the selected entry. If ``url_dict``
    is empty, render an acknowledgement message instead. Both cases respond
    with status 200.
    """
    if self.url_dict:
        template = "partials/camera_streams.html"
        rows = make_table_from_dict(self.url_dict, self.cam_loc)
    else:
        template = "partials/ack.html"
        rows = ["streaming url list unavailable!"]
    response = templates.TemplateResponse(
        template, {"request": request, "table": rows}, status_code=200
    )
    await asyncio.sleep(0.01)
    return response
async def objects(self, request: HtmxRequest) -> Response:
    """
    Render the multi-select table of trackable object classes.

    The current ``obj_class_id`` list is passed as the selection. If
    ``obj_dict`` is empty, render an acknowledgement message instead. Both
    cases respond with status 200.
    """
    if self.obj_dict:
        template = "partials/object_list.html"
        rows = make_table_from_dict_multiselect(
            self.obj_dict, self.obj_class_id
        )
    else:
        template = "partials/ack.html"
        rows = ["object list unavailable!"]
    response = templates.TemplateResponse(
        template, {"request": request, "table": rows}, status_code=200
    )
    await asyncio.sleep(0.001)
    return response
async def geturl(self, request: HtmxRequest) -> Response:
    """
    Acknowledge the currently selected camera stream.

    Renders ``partials/ack.html`` with:

    * status 201 when ``cam_loc`` is a registered stream;
    * status 200 with an explanatory message when ``cam_loc`` is not
      registered or ``url_dict`` is empty.
    """
    if not self.url_dict:
        rows, status = ["streaming url list unavailable!"], 200
    elif self.cam_loc in self.url_dict:
        rows, status = [f"{self.cam_loc} selected"], 201
    else:
        rows = [f"{self.cam_loc} is not in the registered url_list"]
        status = 200
    response = templates.TemplateResponse(
        "partials/ack.html",
        {"request": request, "table": rows},
        status_code=status,
    )
    await asyncio.sleep(0.01)
    return response
async def addurl(self, request: HtmxRequest) -> Response:
    """
    Register (and select) a new camera stream.

    Expects a JSON body ``{"payload": {"CamLoc": <name>, "URL": <url>}}``.
    If the url passes ``try_site``, the pair is handed to
    ``select_cam_loc`` and the camera table is re-rendered with status 201.
    Every failure renders ``partials/ack.html`` (status 200) with an
    ``Hx-Retarget: #add-url-ack`` header so htmx shows the message there.
    """
    def _ack(message):
        # failure acknowledgement, retargeted into the #add-url-ack element
        response = templates.TemplateResponse(
            "partials/ack.html",
            {"request": request, "table": [message]},
            status_code=200,
        )
        response.headers['Hx-Retarget'] = '#add-url-ack'
        return response

    try:
        req_json = await request.json()
    except RuntimeError:
        await asyncio.sleep(0.01)
        return _ack("receive channel unavailable!")
    except ValueError:
        # Malformed JSON (json.JSONDecodeError is a ValueError) used to
        # escape as an unhandled server error; report it as a bad request.
        req_json = None

    # Only dict payloads are accepted. With a str payload, `in` would do a
    # substring test and the item lookup below would then fail.
    payload = req_json.get("payload") if isinstance(req_json, dict) else None
    if not (
        isinstance(payload, dict) and "CamLoc" in payload and "URL" in payload
    ):
        response = _ack("invalid POST request!")
    else:
        cam_loc = payload["CamLoc"]
        cam_url = payload["URL"]
        if not (
            isinstance(cam_loc, str) and isinstance(cam_url, str)
            and cam_loc != "" and cam_url != ""
        ):
            response = _ack("empty or invalid inputs!")
        elif not try_site(cam_url):
            # any falsy probe result is treated as unplayable
            response = _ack("invalid video URL!")
        else:
            self.select_cam_loc(cam_loc_key=cam_loc, cam_loc_val=cam_url)
            context = {
                "request": request,
                "table": make_table_from_dict(self.url_dict, self.cam_loc),
            }
            response = templates.TemplateResponse(
                "partials/camera_streams.html", context, status_code=201
            )
    await asyncio.sleep(0.01)
    return response
async def seturl(self, request: HtmxRequest) -> Response:
    """
    Select a registered camera stream by its url.

    Expects a JSON body ``{"payload": {"cam_url": <url>}}``. The request is
    refused while streaming or tracking is active. Always renders
    ``partials/ack.html``: status 201 means the camera was selected, and
    status 200 carries an error message.
    """
    def _ack(message, status_code=200):
        return templates.TemplateResponse(
            "partials/ack.html",
            {"request": request, "table": [message]},
            status_code=status_code,
        )

    try:
        req_json = await request.json()
    except RuntimeError:
        await asyncio.sleep(0.01)
        return _ack("receive channel unavailable!")
    except ValueError:
        # malformed JSON is reported as a bad request instead of a 500
        req_json = None

    payload = req_json.get("payload") if isinstance(req_json, dict) else None
    if not (isinstance(payload, dict) and "cam_url" in payload):
        response = _ack("invalid POST request!")
    else:
        logging.info(
            f"seturl: _is_running = {self._is_running}, "
            f"_is_tracking = {self._is_tracking}"
        )
        if self._is_running or self._is_tracking:
            # Implicit concatenation: the old backslash continuation put the
            # next source line's indentation inside the message.
            response = _ack(
                "turn off streaming and tracking before "
                "setting a new camera stream!"
            )
        else:
            cam_url = payload["cam_url"]
            # first registered camera (insertion order) with this url
            cam_loc = next(
                (key for key, url in self.url_dict.items() if url == cam_url),
                None,
            )
            if cam_loc is None:
                response = _ack(f"{cam_url} is not in the registered url_list")
            else:
                self.cam_loc = cam_loc
                response = _ack(f"{self.cam_loc} selected", status_code=201)
    await asyncio.sleep(0.01)
    return response
async def getmodel(self, request: HtmxRequest) -> Response:
    """
    Acknowledge the currently selected YOLO model.

    Renders ``partials/ack.html`` with:

    * status 201 when ``model_choice`` is a registered model;
    * status 200 with an explanatory message when ``model_choice`` is not
      registered or ``model_dict`` is empty.
    """
    if not self.model_dict:
        rows, status = ["model list unavailable!"], 200
    elif self.model_choice in self.model_dict:
        rows, status = [f"{self.model_choice} selected"], 201
    else:
        rows = [f"{self.model_choice} is not in the registered model_list"]
        status = 200
    response = templates.TemplateResponse(
        "partials/ack.html",
        {"request": request, "table": rows},
        status_code=status,
    )
    await asyncio.sleep(0.01)
    return response
async def setmodel(self, request: HtmxRequest) -> Response:
    """
    Select and load a registered YOLO model by its weights path.

    Expects a JSON body ``{"payload": {"model_path": <path>}}``. The request
    is refused while tracking is active. On success, ``load_model`` is called
    with the current thresholds/flags. Always renders ``partials/ack.html``:
    status 201 on success, 200 with an error message otherwise.
    """
    def _ack(message, status_code=200):
        return templates.TemplateResponse(
            "partials/ack.html",
            {"request": request, "table": [message]},
            status_code=status_code,
        )

    try:
        req_json = await request.json()
    except RuntimeError:
        await asyncio.sleep(0.01)
        return _ack("receive channel unavailable!")
    except ValueError:
        # malformed JSON is reported as a bad request instead of a 500
        req_json = None

    payload = req_json.get("payload") if isinstance(req_json, dict) else None
    if not (isinstance(payload, dict) and "model_path" in payload):
        response = _ack("invalid POST request!")
    else:
        logging.info(
            f"setmodel: _is_running = {self._is_running}, "
            f"_is_tracking = {self._is_tracking}"
        )
        if self._is_tracking:
            # Implicit concatenation: the old backslash continuation put the
            # next source line's indentation inside the message.
            response = _ack(
                "turn off tracking before setting a new YOLO model!"
            )
        else:
            model_path = payload["model_path"]
            # first registered model (insertion order) with this path
            choice = next(
                (key for key, path in self.model_dict.items()
                 if path == model_path),
                None,
            )
            if choice is None:
                response = _ack(
                    f"{model_path} is not in the registered model_list"
                )
            else:
                self.model_choice = choice
                self.load_model(
                    model_choice=self.model_choice,
                    conf_threshold=self.conf_threshold,
                    iou_threshold=self.iou_threshold,
                    use_FP16=self.use_FP16,
                    use_stream_buffer=self.use_stream_buffer
                )
                response = _ack(
                    f"{self.model_choice} selected", status_code=201
                )
    await asyncio.sleep(0.01)
    return response
async def selectobjects(self, request: HtmxRequest) -> Response:
    """Select which object classes (by class id) are tracked and counted.

    Expects a JSON body ``{"payload": {"object_id": [<id>, ...]}}``. Ids
    that cannot be parsed as ``int`` or are not among
    ``self.obj_dict.values()`` are dropped. If none survive,
    ``self.obj_class_id`` falls back to ``self.obj_class_id_default``.
    A missing, null or empty ``object_id`` leaves the selection unchanged.

    Always answers with the ``partials/ack.html`` fragment (HTTP 200)
    carrying a one-line status message.
    """
    template = "partials/ack.html"
    try:
        req_json = await request.json()
    except RuntimeError:
        table_contents = ["receive channel unavailable!"]
        context = {"request": request, "table": table_contents}
        response = templates.TemplateResponse(
            template, context, status_code=200
        )
        await asyncio.sleep(0.01)
        return response
    except ValueError:
        # Empty or malformed body (json.JSONDecodeError is a ValueError):
        # reply with the "invalid POST request!" ack instead of an HTTP 500.
        req_json = None

    payload = req_json.get("payload") if isinstance(req_json, dict) else None
    req_ids = []
    if isinstance(payload, dict) and "object_id" in payload:
        logging.info(f"requested_ids: {payload}")
        raw = payload["object_id"]
        if isinstance(raw, (list, tuple)):
            req_ids = raw
        elif raw is not None:
            # A single selection may arrive as a scalar instead of a list
            # (NOTE(review): depends on how the front end serializes the
            # multiselect). Iterating a string like "12" would otherwise
            # yield the ids 1 and 2.
            req_ids = [raw]

    if not req_ids:
        # Covers both a missing key and an empty list. Previously one of
        # these paths left ``table_contents`` unbound (UnboundLocalError).
        table_contents = ["invalid POST request! need at least one object type"]
    else:
        known_ids = self.obj_dict.values()
        selected = []
        for raw_id in req_ids:
            try:
                class_id = int(raw_id)
            except (TypeError, ValueError):
                continue  # skip ids that are not integers
            if class_id in known_ids:
                selected.append(class_id)
        if selected:
            self.obj_class_id = selected
            table_contents = [f"{len(selected)} object types selected"]
        else:
            self.obj_class_id = self.obj_class_id_default
            table_contents = [
                "invalid objects selection, use default object types"
            ]
    context = {"request": request, "table": table_contents}
    response = templates.TemplateResponse(
        template, context, status_code=200
    )
    await asyncio.sleep(0.01)
    return response
async def setroi(self, request: HtmxRequest) -> Response:
    """Move the horizontal ROI (counting) line.

    Expects a JSON body ``{"payload": {"roi_height": <int>}}``. An accepted
    request is stored as ``self.roi_height = self.FRAME_HEIGHT - roi_height``
    (presumably: the request counts up from the bottom edge, the stored
    value is a pixel row from the top — confirm against the front end).

    Accepted values lie in 120..600 px and are also below
    ``self.FRAME_HEIGHT``. Any other integer resets the line to
    ``self.roi_height_default``. In both cases ``self._roi_changed`` is set.
    A missing or non-integer ``roi_height`` changes nothing.

    Always answers with the ``partials/ack.html`` fragment (HTTP 200).
    """
    roi_min_px, roi_max_px = 120, 600  # accepted request range, inclusive
    template = "partials/ack.html"
    try:
        req_json = await request.json()
    except RuntimeError:
        table_contents = ["receive channel unavailable!"]
        context = {"request": request, "table": table_contents}
        response = templates.TemplateResponse(
            template, context, status_code=200
        )
        await asyncio.sleep(0.01)
        return response
    except ValueError:
        # Empty or malformed body (json.JSONDecodeError is a ValueError).
        req_json = None

    payload = req_json.get("payload") if isinstance(req_json, dict) else None
    req_height = None
    if isinstance(payload, dict) and "roi_height" in payload:
        logging.info(f"{payload}")
        try:
            req_height = int(payload["roi_height"])
        except (TypeError, ValueError):
            # e.g. "abc" or null; this used to surface as an HTTP 500
            req_height = None

    if req_height is None:
        table_contents = ["invalid POST request! need a valid roi_height"]
    else:
        if (
            roi_min_px <= req_height <= roi_max_px
            and req_height < self.FRAME_HEIGHT
        ):
            self.roi_height = self.FRAME_HEIGHT - req_height
            table_contents = [
                f"roi_height set at {self.FRAME_HEIGHT - self.roi_height}px"
            ]
        else:
            self.roi_height = self.roi_height_default
            table_contents = [
                "invalid roi_height request, use default "
                f"{self.FRAME_HEIGHT - self.roi_height_default}px"
            ]
        self._roi_changed = True
    context = {"request": request, "table": table_contents}
    response = templates.TemplateResponse(
        template, context, status_code=200
    )
    await asyncio.sleep(0.01)
    return response
async def streamswitch(self, request: HtmxRequest) -> Response:
    """Start or stop the video stream.

    Expects a JSON body ``{"payload": {"stream_switch": "on"}}`` to start
    streaming. Any other ``payload`` stops it. In both cases tracking is
    switched off, so it must be re-enabled explicitly afterwards.

    Returns the ``partials/ack.html`` fragment: ``"on"``/``"off"`` with
    HTTP 201 for a request carrying ``payload``, otherwise an error message
    with HTTP 200.
    """
    template = "partials/ack.html"
    try:
        req_json = await request.json()
    except RuntimeError:
        context = {
            "request": request, "table": ["receive channel unavailable!"]
        }
        await asyncio.sleep(0.01)
        return templates.TemplateResponse(template, context, status_code=200)
    except ValueError:
        # Empty or malformed body (json.JSONDecodeError is a ValueError):
        # fall through to the "invalid POST request!" reply below.
        req_json = None

    if isinstance(req_json, dict) and "payload" in req_json:
        payload = req_json["payload"]
        logging.info(f"payload = {payload}")
        # Tracking depends on a running stream, so it is always reset here.
        self._is_tracking = False
        self._is_running = (
            isinstance(payload, dict)
            and payload.get("stream_switch") == "on"
        )
        table_contents = ["on" if self._is_running else "off"]
        status_code = 201
    else:
        table_contents = ["invalid POST request!"]
        status_code = 200
    context = {"request": request, "table": table_contents}
    await asyncio.sleep(0.1)
    return templates.TemplateResponse(
        template, context, status_code=status_code
    )
async def trackingswitch(self, request: HtmxRequest) -> Response:
    """Turn object tracking (counting and speed estimation) on or off.

    Expects a JSON body ``{"payload": {"tracking_switch": "on"}}`` to enable
    tracking. Any other ``payload`` disables it. Tracking is only enabled
    while the stream is running (``self._is_running``).

    Enabling it builds a fresh ``ObjectCounter`` over a band 20 px above and
    below the ROI line, and a fresh ``SpeedEstimator`` on the ROI line
    itself. Previously accumulated counts are therefore discarded.

    Returns the ``partials/ack.html`` fragment: ``"on"``/``"off"`` with
    HTTP 201 for a request carrying ``payload``, otherwise an error message
    with HTTP 200.
    """
    template = "partials/ack.html"
    try:
        req_json = await request.json()
    except RuntimeError:
        context = {
            "request": request, "table": ["receive channel unavailable!"]
        }
        await asyncio.sleep(0.01)
        return templates.TemplateResponse(template, context, status_code=200)
    except ValueError:
        # Empty or malformed body (json.JSONDecodeError is a ValueError):
        # fall through to the "invalid POST request!" reply below.
        req_json = None

    if isinstance(req_json, dict) and "payload" in req_json:
        payload = req_json["payload"]
        logging.info(f"payload = {payload}")
        wants_tracking = (
            isinstance(payload, dict)
            and payload.get("tracking_switch") == "on"
        )
        if wants_tracking and self._is_running:
            # Build the counter/estimator *before* raising the flag, so
            # ``_is_tracking`` is never True next to a counter configured
            # for an older ROI or model.
            region_points = [
                (5, -20 + self.roi_height),
                (5, 20 + self.roi_height),
                (self.FRAME_WIDTH - 5, 20 + self.roi_height),
                (self.FRAME_WIDTH - 5, -20 + self.roi_height),
            ]
            self.counter = object_counter.ObjectCounter()
            self.counter.set_args(
                view_img=False,
                reg_pts=region_points,
                classes_names=self.model.names,
                draw_tracks=False,
                draw_boxes=False,
                draw_reg_pts=True,
            )
            line_points = [
                (5, self.roi_height),
                (self.FRAME_WIDTH - 5, self.roi_height),
            ]
            self.speed_obj = speed_estimation.SpeedEstimator()
            self.speed_obj.set_args(
                reg_pts=line_points,
                names=self.model.names,
                view_img=False,
            )
            self._is_tracking = True
            table_contents = ["on"]
        else:
            self._is_tracking = False
            table_contents = ["off"]
        status_code = 201
    else:
        table_contents = ["invalid POST request!"]
        status_code = 200
    context = {"request": request, "table": table_contents}
    await asyncio.sleep(0.1)
    return templates.TemplateResponse(
        template, context, status_code=status_code
    )
async def sse_incounts(self, request: Request):
    """Stream the object counter's *in* count to the browser over SSE.

    Every ``EVT_STREAM_DELAY_SEC`` seconds the generator looks at the app
    state:

    * stream running and tracking on: push ``self.counter.in_counts`` as an
      ``evt_in_counts`` event whenever ``incounts_updated()`` reports a
      change;
    * any other state: push a single ``"---"`` placeholder, then stay quiet
      until counting resumes (so the page never shows a stale count);
    * client disconnected: push a final ``"..."`` event and stop.

    Event ids are America/Los_Angeles wall-clock timestamps.
    """

    def _event(data: str) -> dict:
        # One SSE message; the id is a human-readable local timestamp.
        return {
            "event": "evt_in_counts",
            "id": datetime.now(
                tz=ZoneInfo("America/Los_Angeles")
            ).strftime("%m/%d/%Y %H:%M:%S"),
            "retry": RETRY_TIMEOUT_MILSEC,
            "data": data,
        }

    async def event_generator():
        placeholder_sent = False
        while True:
            # If client closes connection, stop sending events
            if await request.is_disconnected():
                yield _event("...")
                break
            if self._is_running and self._is_tracking:
                placeholder_sent = False
                # Check for None *before* touching the counter. The original
                # called incounts_updated() first, which made the guard dead.
                counter = getattr(self, "counter", None)
                if counter is not None and counter.incounts_updated():
                    yield _event(f"{counter.in_counts}")
            elif not placeholder_sent:
                yield _event("---")
                placeholder_sent = True
            await asyncio.sleep(EVT_STREAM_DELAY_SEC)

    return EventSourceResponse(event_generator())
async def sse_outcounts(self, request: Request):
    """Stream the object counter's *out* count to the browser over SSE.

    Every ``EVT_STREAM_DELAY_SEC`` seconds the generator looks at the app
    state:

    * stream running and tracking on: push ``self.counter.out_counts`` as
      an ``evt_out_counts`` event whenever ``outcounts_updated()`` reports a
      change;
    * any other state: push a single ``"---"`` placeholder, then stay quiet
      until counting resumes (so the page never shows a stale count);
    * client disconnected: push a final ``"..."`` event and stop.

    Event ids are America/Los_Angeles wall-clock timestamps.
    """

    def _event(data: str) -> dict:
        # One SSE message; the id is a human-readable local timestamp.
        return {
            "event": "evt_out_counts",
            "id": datetime.now(
                tz=ZoneInfo("America/Los_Angeles")
            ).strftime("%m/%d/%Y %H:%M:%S"),
            "retry": RETRY_TIMEOUT_MILSEC,
            "data": data,
        }

    async def event_generator():
        placeholder_sent = False
        while True:
            # If client closes connection, stop sending events
            if await request.is_disconnected():
                yield _event("...")
                break
            if self._is_running and self._is_tracking:
                placeholder_sent = False
                # Check for None *before* touching the counter. The original
                # called outcounts_updated() first, which made the guard dead.
                counter = getattr(self, "counter", None)
                if counter is not None and counter.outcounts_updated():
                    yield _event(f"{counter.out_counts}")
            elif not placeholder_sent:
                yield _event("---")
                placeholder_sent = True
            await asyncio.sleep(EVT_STREAM_DELAY_SEC)

    return EventSourceResponse(event_generator())
# ---------------------------------------------------------------------------
# Application wiring: build the demo case, load the default YOLO model and
# register every HTTP / SSE endpoint on a vidgear WebGear app.
# ---------------------------------------------------------------------------
demo_case = DemoCase(YOLO_VERBOSE=False)
demo_case.set_frame_reduction(frame_reduction=35)
demo_case.load_model(
    model_choice="y8small",
    conf_threshold=0.1,
    iou_threshold=0.6,
    use_FP16=False,
    use_stream_buffer=True,
)
logging.info(f"url_dict: {demo_case.url_dict}")
logging.info(f"model_dict: {demo_case.model_dict}")
logging.info(f"obj_dict: {demo_case.obj_dict}")
logging.info(f"obj_class_id: {demo_case.obj_class_id}")

# setup webgear server; its .vidgear/webgear data tree is looked up
# relative to the current working directory
options = {
    "custom_data_location": "./",
}
web = WebGear(logging=True, **options)

# config webgear server
web.config["generator"] = demo_case.frame1_producer
# WebGear hands ``web.middleware`` to Starlette; a "middleware" key in
# ``web.config`` is never read, so HtmxMiddleware was silently not installed.
web.middleware = [Middleware(HtmxMiddleware)]
web.routes.append(Mount("/static", static, name="static"))

# name -> (endpoint, allowed HTTP methods); each is served at "/<name>"
routes_dict = {
    "models": (demo_case.models, ["GET"]),
    "getmodel": (demo_case.getmodel, ["GET"]),
    "setmodel": (demo_case.setmodel, ["POST"]),
    "urls": (demo_case.urls, ["GET"]),
    "addurl": (demo_case.addurl, ["POST"]),
    "geturl": (demo_case.geturl, ["GET"]),
    "seturl": (demo_case.seturl, ["POST"]),
    "objects": (demo_case.objects, ["GET"]),
    "selectobjects": (demo_case.selectobjects, ["POST"]),
    "setroi": (demo_case.setroi, ["POST"]),
    "streamswitch": (demo_case.streamswitch, ["POST"]),
    "trackingswitch": (demo_case.trackingswitch, ["POST"]),
}
for route_name, (endpoint, methods) in routes_dict.items():
    web.routes.append(
        Route(path=f"/{route_name}", endpoint=endpoint, name=route_name,
              methods=methods)
    )

# server-sent-event streams (Starlette's default methods: GET/HEAD)
sse_routes = {
    "sseincounts": demo_case.sse_incounts,
    "sseoutcounts": demo_case.sse_outcounts,
}
for route_name, endpoint in sse_routes.items():
    web.routes.append(
        Route(path=f"/{route_name}", endpoint=endpoint, name=route_name)
    )

# How to serve the app. Calling ``web()`` returns the Starlette ASGI app.
#   python: uvicorn.run(web(), host="localhost", port=8080, log_level="info")
#           and afterwards web.shutdown() to close the app safely
#   cli:    uvicorn webapp:web --factory --host "localhost" --port 8080 --reload
# NOTE(review): --factory added because ``web`` is an app factory rather than
# an ASGI callable itself; confirm against the deployed start command.