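"""Gradio demo app for Sapiens pose estimation (ECCV 2024).

Detects people with an RTMDet person detector, runs a TorchScript Sapiens pose
model on each person crop, and draws the 308 Goliath keypoints and skeleton.
"""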
import os
from typing import List

import spaces
import gradio as gr
import numpy as np
import torch
import json
import tempfile
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import cv2
from gradio.themes.utils import sizes

from classes_and_palettes import (
    COCO_KPTS_COLORS,
    COCO_WHOLEBODY_KPTS_COLORS,
    GOLIATH_KPTS_COLORS,
    GOLIATH_SKELETON_INFO,
    GOLIATH_KEYPOINTS
)

import sys
import subprocess
import importlib.util
def is_package_installed(package_name):
    return importlib.util.find_spec(package_name) is not None


def find_wheel(package_path):
    dist_dir = os.path.join(package_path, "dist")
    if os.path.exists(dist_dir):
        wheel_files = [f for f in os.listdir(dist_dir) if f.endswith('.whl')]
        if wheel_files:
            return os.path.join(dist_dir, wheel_files[0])
    return None


def install_from_wheel(package_name, package_path):
    wheel_file = find_wheel(package_path)
    if wheel_file:
        print(f"Installing {package_name} from wheel: {wheel_file}")
        subprocess.check_call([sys.executable, "-m", "pip", "install", wheel_file])
    else:
        print(f"{package_name} wheel not found in {package_path}. Please build it first.")
        sys.exit(1)


def install_local_packages():
    packages = [
        ("mmengine", "./external/engine"),
        ("mmcv", "./external/cv"),
        ("mmdet", "./external/det"),
    ]
    for package_name, package_path in packages:
        if not is_package_installed(package_name):
            print(f"Installing {package_name}...")
            install_from_wheel(package_name, package_path)
        else:
            print(f"{package_name} is already installed.")


# Run the installation at the start of your app
install_local_packages()
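# NOTE: detector_utils presumably relies on the mm* packages installed above,
# which is why this import is placed after install_local_packages().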
from detector_utils import (
    adapt_mmdet_pipeline,
    init_detector,
    process_images_detector,
)
class Config:
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'assets')
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_goliath_best_goliath_AP_575_torchscript.pt2",
        "0.6b": "sapiens_0.6b_goliath_best_goliath_AP_600_torchscript.pt2",
        "1b": "sapiens_1b_goliath_best_goliath_AP_640_torchscript.pt2",
    }
    DETECTION_CHECKPOINT = os.path.join(CHECKPOINTS_DIR, 'rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth')
    DETECTION_CONFIG = os.path.join(ASSETS_DIR, 'rtmdet_m_640-8xb32_coco-person_no_nms.py')
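# The pose checkpoints are TorchScript exports (.pt2), so they are loaded with
# torch.jit.load and moved to the GPU for inference.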
class ModelManager:
    @staticmethod
    def load_model(checkpoint_name: str):
        if checkpoint_name is None:
            return None
        checkpoint_path = os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name)
        model = torch.jit.load(checkpoint_path)
        model.eval()
        model.to("cuda")
        return model

    @staticmethod
    @torch.inference_mode()  # no gradients are needed at inference time
    def run_model(model, input_tensor):
        return model(input_tensor)
class ImageProcessor:
    def __init__(self):
        self.transform = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
                                 std=[58.5 / 255, 57.0 / 255, 57.5 / 255]),
        ])
        self.detector = init_detector(
            Config.DETECTION_CONFIG, Config.DETECTION_CHECKPOINT, device='cpu'
        )
        self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)
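    # Person detection: the RTMDet detector initialized above returns candidate
    # boxes for the image; only detections above the score threshold are kept.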
    def detect_persons(self, image: Image.Image):
        # Convert the PIL image to a numpy array with a leading batch dimension
        image = np.array(image)
        image = np.expand_dims(image, axis=0)

        # Perform person detection
        bboxes_batch = process_images_detector(
            image,
            self.detector
        )
        bboxes = self.get_person_bboxes(bboxes_batch[0])  # bboxes for the first (and only) image
        return bboxes
    def get_person_bboxes(self, bboxes_batch, score_thr=0.3):
        person_bboxes = []
        for bbox in bboxes_batch:
            if len(bbox) == 5:  # [x1, y1, x2, y2, score]
                if bbox[4] > score_thr:
                    person_bboxes.append(bbox)
            elif len(bbox) == 4:  # [x1, y1, x2, y2]
                person_bboxes.append(bbox + [1.0])  # Add a default score of 1.0
        return person_bboxes
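    # Pose estimation runs once per detected person: the bbox is cropped out of
    # the image, resized and normalized to the model input size, pushed through
    # the TorchScript model, and the resulting heatmaps are decoded to keypoints.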
    def estimate_pose(self, image: Image.Image, bboxes: List[List[float]], model_name: str, kpt_threshold: float):
        pose_model = ModelManager.load_model(Config.CHECKPOINTS[model_name])

        result_image = image.copy()
        all_keypoints = []  # keypoints for all detected persons

        for bbox in bboxes:
            cropped_img = self.crop_image(result_image, bbox)
            input_tensor = self.transform(cropped_img).unsqueeze(0).to("cuda")
            heatmaps = ModelManager.run_model(pose_model, input_tensor)
            keypoints = self.heatmaps_to_keypoints(heatmaps[0].cpu().numpy())
            all_keypoints.append(keypoints)
            result_image = self.draw_keypoints(result_image, keypoints, bbox, kpt_threshold)

        return result_image, all_keypoints

    def process_image(self, image: Image.Image, model_name: str, kpt_threshold: str):
        bboxes = self.detect_persons(image)
        result_image, keypoints = self.estimate_pose(image, bboxes, model_name, float(kpt_threshold))
        return result_image, keypoints
    def crop_image(self, image, bbox):
        if len(bbox) == 4:
            x1, y1, x2, y2 = map(int, bbox)
        elif len(bbox) >= 5:
            x1, y1, x2, y2, _ = map(int, bbox[:5])
        else:
            raise ValueError(f"Unexpected bbox format: {bbox}")

        crop = image.crop((x1, y1, x2, y2))
        return crop
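    # Heatmap decoding: each of the 308 Goliath keypoints has its own heatmap
    # channel; the argmax location gives the keypoint position and the peak value
    # is used as its confidence. Coordinates are in heatmap grid units here and
    # are rescaled to image space in draw_keypoints.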
    def heatmaps_to_keypoints(self, heatmaps):
        num_joints = heatmaps.shape[0]  # expected to be 308
        keypoints = {}
        for i, name in enumerate(GOLIATH_KEYPOINTS):
            if i < num_joints:
                heatmap = heatmaps[i]
                y, x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
                conf = heatmap[y, x]
                keypoints[name] = (float(x), float(y), float(conf))
        return keypoints
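    # Drawing: keypoint coordinates are mapped from the heatmap grid back into the
    # original image. The divisors 192 and 256 below imply a heatmap 192 wide and
    # 256 tall (inferred from the scaling here, not stated elsewhere in this file),
    # consistent with the 768x1024 model input downsampled by a factor of 4.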
    def draw_keypoints(self, image, keypoints, bbox, kpt_threshold):
        image = np.array(image)

        # Handle both 4- and 5-element bounding boxes
        if len(bbox) == 4:
            x1, y1, x2, y2 = map(int, bbox)
        elif len(bbox) >= 5:
            x1, y1, x2, y2, _ = map(int, bbox[:5])
        else:
            raise ValueError(f"Unexpected bbox format: {bbox}")

        # Calculate adaptive radius and thickness based on bounding box size
        bbox_width = x2 - x1
        bbox_height = y2 - y1
        bbox_size = np.sqrt(bbox_width * bbox_height)

        radius = max(1, int(bbox_size * 0.006))     # minimum 1 pixel
        thickness = max(1, int(bbox_size * 0.006))  # minimum 1 pixel
        bbox_thickness = max(1, thickness // 4)

        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), bbox_thickness)

        # Draw keypoints
        for i, (name, (x, y, conf)) in enumerate(keypoints.items()):
            if conf > kpt_threshold and i < len(GOLIATH_KPTS_COLORS):
                x_coord = int(x * bbox_width / 192) + x1
                y_coord = int(y * bbox_height / 256) + y1
                color = GOLIATH_KPTS_COLORS[i]
                cv2.circle(image, (x_coord, y_coord), radius, color, -1)

        # Draw skeleton
        for _, link_info in GOLIATH_SKELETON_INFO.items():
            pt1_name, pt2_name = link_info['link']
            color = link_info['color']

            if pt1_name in keypoints and pt2_name in keypoints:
                pt1 = keypoints[pt1_name]
                pt2 = keypoints[pt2_name]

                if pt1[2] > kpt_threshold and pt2[2] > kpt_threshold:
                    x1_coord = int(pt1[0] * bbox_width / 192) + x1
                    y1_coord = int(pt1[1] * bbox_height / 256) + y1
                    x2_coord = int(pt2[0] * bbox_width / 192) + x1
                    y2_coord = int(pt2[1] * bbox_height / 256) + y1
                    cv2.line(image, (x1_coord, y1_coord), (x2_coord, y2_coord), color, thickness=thickness)

        return Image.fromarray(image)
class GradioInterface:
    def __init__(self):
        self.image_processor = ImageProcessor()

    def create_interface(self):
        app_styles = """
        <style>
            /* Global Styles */
            body, #root {
                font-family: Helvetica, Arial, sans-serif;
                background-color: #1a1a1a;
                color: #fafafa;
            }

            /* Header Styles */
            .app-header {
                background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
                padding: 24px;
                border-radius: 8px;
                margin-bottom: 24px;
                text-align: center;
            }

            .app-title {
                font-size: 48px;
                margin: 0;
                color: #fafafa;
            }

            .app-subtitle {
                font-size: 24px;
                margin: 8px 0 16px;
                color: #fafafa;
            }

            .app-description {
                font-size: 16px;
                line-height: 1.6;
                opacity: 0.8;
                margin-bottom: 24px;
            }

            /* Button Styles */
            .publication-links {
                display: flex;
                justify-content: center;
                flex-wrap: wrap;
                gap: 8px;
                margin-bottom: 16px;
            }

            .publication-link {
                display: inline-flex;
                align-items: center;
                padding: 8px 16px;
                background-color: #333;
                color: #fff !important;
                text-decoration: none !important;
                border-radius: 20px;
                font-size: 14px;
                transition: background-color 0.3s;
            }

            .publication-link:hover {
                background-color: #555;
            }

            .publication-link i {
                margin-right: 8px;
            }

            /* Content Styles */
            .content-container {
                background-color: #2a2a2a;
                border-radius: 8px;
                padding: 24px;
                margin-bottom: 24px;
            }

            /* Image Styles */
            .image-preview img {
                max-width: 512px;
                max-height: 512px;
                margin: 0 auto;
                border-radius: 4px;
                display: block;
                object-fit: contain;
            }

            /* Control Styles */
            .control-panel {
                background-color: #333;
                padding: 16px;
                border-radius: 8px;
                margin-top: 16px;
            }

            /* Gradio Component Overrides */
            .gr-button {
                background-color: #4a4a4a;
                color: #fff;
                border: none;
                border-radius: 4px;
                padding: 8px 16px;
                cursor: pointer;
                transition: background-color 0.3s;
            }

            .gr-button:hover {
                background-color: #5a5a5a;
            }

            .gr-input, .gr-dropdown {
                background-color: #3a3a3a;
                color: #fff;
                border: 1px solid #4a4a4a;
                border-radius: 4px;
                padding: 8px;
            }

            .gr-form {
                background-color: transparent;
            }

            .gr-panel {
                border: none;
                background-color: transparent;
            }

            /* Override any conflicting styles from Bulma */
            .button.is-normal.is-rounded.is-dark {
                color: #fff !important;
                text-decoration: none !important;
            }
        </style>
        """

        header_html = f"""
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
        <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
        {app_styles}
        <div class="app-header">
            <h1 class="app-title">Sapiens: Pose Estimation</h1>
            <h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
            <p class="app-description">
                Meta presents Sapiens, foundation models for human tasks pretrained on 300 million human images.
                This demo showcases the finetuned pose estimation model. <br>
            </p>
            <div class="publication-links">
                <a href="https://arxiv.org/abs/2408.12569" class="publication-link">
                    <i class="fas fa-file-pdf"></i>arXiv
                </a>
                <a href="https://github.com/facebookresearch/sapiens" class="publication-link">
                    <i class="fab fa-github"></i>Code
                </a>
                <a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" class="publication-link">
                    <i class="fas fa-globe"></i>Meta
                </a>
                <a href="https://rawalkhirodkar.github.io/sapiens" class="publication-link">
                    <i class="fas fa-chart-bar"></i>Results
                </a>
            </div>
            <div class="publication-links">
                <a href="https://huggingface.co/spaces/facebook/sapiens_pose" class="publication-link">
                    <i class="fas fa-user"></i>Demo-Pose
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_seg" class="publication-link">
                    <i class="fas fa-puzzle-piece"></i>Demo-Seg
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_depth" class="publication-link">
                    <i class="fas fa-cube"></i>Demo-Depth
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_normal" class="publication-link">
                    <i class="fas fa-vector-square"></i>Demo-Normal
                </a>
            </div>
        </div>
        """

        js_func = """
        function refresh() {
            const url = new URL(window.location);
            if (url.searchParams.get('__theme') !== 'dark') {
                url.searchParams.set('__theme', 'dark');
                window.location.href = url.href;
            }
        }
        """
        def process_image(image, model_name, kpt_threshold):
            result_image, keypoints = self.image_processor.process_image(image, model_name, kpt_threshold)

            with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode='w') as json_file:
                json.dump(keypoints, json_file)
                json_file_path = json_file.name

            return result_image, json_file_path
        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)

            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row():
                        model_name = gr.Dropdown(
                            label="Model Size",
                            choices=list(Config.CHECKPOINTS.keys()),
                            value="1b",
                        )
                        kpt_threshold = gr.Dropdown(
                            label="Min Keypoint Confidence",
                            choices=["0.1", "0.2", "0.3", "0.4", "0.5", "0.6", "0.7", "0.8", "0.9"],
                            value="0.3",
                        )
                    example_model = gr.Examples(
                        inputs=input_image,
                        examples_per_page=14,
                        examples=[
                            os.path.join(Config.ASSETS_DIR, "images", img)
                            for img in os.listdir(os.path.join(Config.ASSETS_DIR, "images"))
                        ],
                    )
                with gr.Column():
                    result_image = gr.Image(label="Pose-308 Result", type="pil", elem_classes="image-preview")
                    json_output = gr.File(label="Pose-308 Output (.json)")
                    run_button = gr.Button("Run")

            run_button.click(
                fn=process_image,
                inputs=[input_image, model_name, kpt_threshold],
                outputs=[result_image, json_output],
            )

        return demo
def main():
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    interface = GradioInterface()
    demo = interface.create_interface()
    demo.launch(share=False)


if __name__ == "__main__":
    main()