import os

import cv2
import gradio as gr

from core.chessboard_detector import ChessboardDetector

# Detector bundles three ONNX models: board edge detection, corner pose
# estimation, and full-board piece classification.
detector = ChessboardDetector(
    det_model_path="onnx/det/v1.onnx",
    pose_model_path="onnx/pose/4_v2.onnx",
    full_classifier_model_path="onnx/layout_recognition/v2.onnx",
)

# Mapping from Chinese piece names to the single-character codes used in the
# classifier's layout string. '.' is an ordinary empty intersection and 'x'
# marks an occluded position.
dict_cate_names = {
    '.': '.',
    'x': 'x',
    '红帅': 'K',
    '红士': 'A',
    '红相': 'B',
    '红马': 'N',
    '红车': 'R',
    '红炮': 'C',
    '红兵': 'P',
    '黑将': 'k',
    '黑仕': 'a',
    '黑象': 'b',
    '黑傌': 'n',
    '黑車': 'r',
    '黑砲': 'c',
    '黑卒': 'p',
}

# Inverse mapping: single-character code -> Chinese piece name.
dict_cate_names_reverse = {v: k for k, v in dict_cate_names.items()}


def build_examples():
    """Collect demo inputs from the ``examples`` directory.

    Returns:
        list[list]: ``[image_path, None]`` pairs for ``.jpg`` files and
        ``[None, video_path]`` pairs for ``.mp4`` files, matching the
        ``gr.Examples`` inputs ``[image_input, video_input]``.
    """
    examples = []
    for file in os.listdir("examples"):
        if file.endswith(".jpg"):
            examples.append([os.path.join("examples", file), None])
        elif file.endswith(".mp4"):
            examples.append([None, os.path.join("examples", file)])
    return examples


full_examples = build_examples()


def get_video_frame_with_processs(video_data, process: str = '00:00') -> cv2.UMat:
    """Extract the frame at a given ``MM:SS`` timestamp from a video.

    Args:
        video_data: Path to the video file (as provided by ``gr.Video``).
        process: Timestamp in ``MM:SS`` form (seconds may be fractional).

    Returns:
        The frame at that position converted to RGB, or ``None`` (with a UI
        warning) when the video cannot be opened, the timestamp is invalid,
        or the frame cannot be read.
    """
    cap = cv2.VideoCapture(video_data)
    if not cap.isOpened():
        gr.Warning("无法打开视频")
        return None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)

        # Fix: a malformed timestamp (empty box, "1:2:3", non-numeric parts)
        # previously raised an uncaught ValueError/IndexError in the UI
        # callback; validate it and warn instead.
        try:
            minutes_str, seconds_str = process.split(":")
            target_seconds = int(minutes_str) * 60 + float(seconds_str)
        except (AttributeError, ValueError):
            gr.Warning("时间格式错误, 应为 00:00")
            return None

        # Some containers report fps as 0/NaN; a frame index computed from
        # that would be meaningless, so fall back to frame 0.
        current_frame = int(target_seconds * fps) if fps and fps > 0 else 0

        cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
        ret, frame = cap.read()
    finally:
        # Fix: release the capture even if reading/seeking raises, so the
        # file handle is never leaked.
        cap.release()

    if not ret:
        gr.Warning("无法读取视频帧")
        return None
    # OpenCV decodes to BGR; Gradio image components expect RGB.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


with gr.Blocks(
    # JS payload: mirror the video player's current time into the
    # "#video_process" textbox in MM:SS form. (String is passed verbatim to
    # the browser; its embedded comments are intentionally left untouched.)
    js="""
    async () => {
        document.addEventListener('timeupdate', function(e) {
            // 检查事件源是否是视频元素
            if (e.target.matches('#video_player video')) {
                const video = e.target;
                const currentTime = video.currentTime;
                // 转换成 00:00 格式
                let minutes = Math.floor(currentTime / 60);
                let seconds = Math.floor(currentTime % 60);
                let formattedTime = `${minutes.toString().padStart(2,'0')}:${seconds.toString().padStart(2,'0')}`;
                // 更新输入框值
                let processInput = document.querySelector("#video_process textarea");
                if (processInput) {
                    processInput.value = formattedTime;
                    processInput.text = formattedTime;
                    processInput.dispatchEvent(new Event("input"));
                }
            }
        }, true); // 使用捕获阶段
    }
    """,
    css="""
    .image img {
        max-height: 512px;
    }
    """,
) as demo:
    gr.Markdown("""
    ## 棋盘检测, 棋子识别
    x 表示 有遮挡位置
    . 表示 棋盘上的普通交叉点

    步骤:
    1. 流程分成两步,第一步检测边缘
    2. 对整个棋盘画面进行棋子分类预测
    """)

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(
                label="上传视频",
                interactive=True,
                elem_id="video_player",
                height=356,
            )
            video_process = gr.Textbox(
                label="当前时间",
                interactive=True,
                elem_id="video_process",
                value="00:00",
            )
            extract_frame_btn = gr.Button("从视频提取当前帧")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")
        with gr.Column():
            original_image_with_keypoints = gr.Image(
                label="step1: 原图带关键点",
                interactive=False,
                visible=True,
                elem_classes="image",
            )

    with gr.Row():
        with gr.Column():
            transformed_image = gr.Image(
                label="step2: 拉伸棋盘",
                interactive=False,
                visible=True,
                elem_classes="image",
            )
        with gr.Column():
            use_time = gr.Textbox(
                label="用时",
                interactive=False,
                visible=True,
            )
            layout_pred_info = gr.Dataframe(
                label="棋子识别",
                interactive=False,
                visible=True,
            )

    with gr.Row():
        with gr.Column():
            gr.Examples(full_examples, inputs=[image_input, video_input], label="示例视频、图片")

    def detect_chessboard(image):
        """Run board detection and piece classification on one image.

        Args:
            image: RGB image as a numpy array, or ``None``.

        Returns:
            Tuple of (keypoint overlay image, warped board image,
            10x9 grid of Chinese piece labels, timing string), or four
            ``None`` values when the input is empty or detection fails.
        """
        if image is None:
            return None, None, None, None
        try:
            (
                original_image_with_keypoints,
                transformed_image,
                cells_labels_str,
                _scores,  # confidence scores are not displayed
                time_info,
            ) = detector.pred_detect_board_and_classifier(image)

            # The classifier returns 10 newline-separated rows of 9 piece
            # codes; map each code back to its Chinese label for display.
            annotation_arr_10_9 = [
                [dict_cate_names_reverse[code] for code in row]
                for row in cells_labels_str.split("\n")
            ]
        except Exception as e:
            # Fix: the caught exception was discarded and the f-string had no
            # placeholder — include the reason so failures are diagnosable.
            gr.Warning(f"检测失败 图片或者视频布局错误: {e}")
            return None, None, None, None
        return original_image_with_keypoints, transformed_image, annotation_arr_10_9, time_info

    image_input.change(
        fn=detect_chessboard,
        inputs=[image_input],
        outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time],
    )
    extract_frame_btn.click(
        fn=get_video_frame_with_processs,
        inputs=[video_input, video_process],
        outputs=[image_input],
    )


if __name__ == "__main__":
    demo.launch()