sapiens-pose / app.py
rawalkhirodkar's picture
Add better checkpoints
2d4ff18
import os
from typing import List
import spaces
import gradio as gr
import numpy as np
import torch
import json
import tempfile
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import cv2
from gradio.themes.utils import sizes
from classes_and_palettes import (
COCO_KPTS_COLORS,
COCO_WHOLEBODY_KPTS_COLORS,
GOLIATH_KPTS_COLORS,
GOLIATH_SKELETON_INFO,
GOLIATH_KEYPOINTS
)
import os
import sys
import subprocess
import importlib.util
def is_package_installed(package_name):
return importlib.util.find_spec(package_name) is not None
def find_wheel(package_path):
dist_dir = os.path.join(package_path, "dist")
if os.path.exists(dist_dir):
wheel_files = [f for f in os.listdir(dist_dir) if f.endswith('.whl')]
if wheel_files:
return os.path.join(dist_dir, wheel_files[0])
return None
def install_from_wheel(package_name, package_path):
wheel_file = find_wheel(package_path)
if wheel_file:
print(f"Installing {package_name} from wheel: {wheel_file}")
subprocess.check_call([sys.executable, "-m", "pip", "install", wheel_file])
else:
print(f"{package_name} wheel not found in {package_path}. Please build it first.")
sys.exit(1)
def install_local_packages():
packages = [
("mmengine", "./external/engine"),
("mmcv", "./external/cv"),
("mmdet", "./external/det")
]
for package_name, package_path in packages:
if not is_package_installed(package_name):
print(f"Installing {package_name}...")
install_from_wheel(package_name, package_path)
else:
print(f"{package_name} is already installed.")
# Run the installation at the start of your app
install_local_packages()
from detector_utils import (
adapt_mmdet_pipeline,
init_detector,
process_images_detector,
)
class Config:
ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'assets')
CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
CHECKPOINTS = {
"0.3b": "sapiens_0.3b_goliath_best_goliath_AP_573_torchscript.pt2",
"0.6b": "sapiens_0.6b_goliath_best_goliath_AP_609_torchscript.pt2",
"1b": "sapiens_1b_goliath_best_goliath_AP_639_torchscript.pt2",
}
DETECTION_CHECKPOINT = os.path.join(CHECKPOINTS_DIR, 'rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth')
DETECTION_CONFIG = os.path.join(ASSETS_DIR, 'rtmdet_m_640-8xb32_coco-person_no_nms.py')
class ModelManager:
@staticmethod
def load_model(checkpoint_name: str):
if checkpoint_name is None:
return None
checkpoint_path = os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name)
model = torch.jit.load(checkpoint_path)
model.eval()
model.to("cuda")
return model
@staticmethod
@torch.inference_mode()
def run_model(model, input_tensor):
return model(input_tensor)
class ImageProcessor:
def __init__(self):
self.transform = transforms.Compose([
transforms.Resize((1024, 768)),
transforms.ToTensor(),
transforms.Normalize(mean=[123.5/255, 116.5/255, 103.5/255],
std=[58.5/255, 57.0/255, 57.5/255])
])
self.detector = init_detector(
Config.DETECTION_CONFIG, Config.DETECTION_CHECKPOINT, device='cpu'
)
self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)
def detect_persons(self, image: Image.Image):
# Convert PIL Image to tensor
image = np.array(image)
image = np.expand_dims(image, axis=0)
# Perform person detection
bboxes_batch = process_images_detector(
image,
self.detector
)
bboxes = self.get_person_bboxes(bboxes_batch[0]) # Get bboxes for the first (and only) image
return bboxes
def get_person_bboxes(self, bboxes_batch, score_thr=0.3):
person_bboxes = []
for bbox in bboxes_batch:
if len(bbox) == 5: # [x1, y1, x2, y2, score]
if bbox[4] > score_thr:
person_bboxes.append(bbox)
elif len(bbox) == 4: # [x1, y1, x2, y2]
person_bboxes.append(bbox + [1.0]) # Add a default score of 1.0
return person_bboxes
@spaces.GPU
@torch.inference_mode()
def estimate_pose(self, image: Image.Image, bboxes: List[List[float]], model_name: str, kpt_threshold: float):
pose_model = ModelManager.load_model(Config.CHECKPOINTS[model_name])
result_image = image.copy()
all_keypoints = [] # List to store keypoints for all persons
for bbox in bboxes:
cropped_img = self.crop_image(result_image, bbox)
input_tensor = self.transform(cropped_img).unsqueeze(0).to("cuda")
heatmaps = ModelManager.run_model(pose_model, input_tensor)
keypoints = self.heatmaps_to_keypoints(heatmaps[0].cpu().numpy(), bbox)
all_keypoints.append(keypoints) # Collect keypoints
result_image = self.draw_keypoints(result_image, keypoints, bbox, kpt_threshold)
return result_image, all_keypoints
def process_image(self, image: Image.Image, model_name: str, kpt_threshold: str):
bboxes = self.detect_persons(image)
result_image, keypoints = self.estimate_pose(image, bboxes, model_name, float(kpt_threshold))
return result_image, keypoints
def crop_image(self, image, bbox):
if len(bbox) == 4:
x1, y1, x2, y2 = map(int, bbox)
elif len(bbox) >= 5:
x1, y1, x2, y2, _ = map(int, bbox[:5])
else:
raise ValueError(f"Unexpected bbox format: {bbox}")
crop = image.crop((x1, y1, x2, y2))
return crop
@staticmethod
def heatmaps_to_keypoints(heatmaps, bbox):
num_joints = heatmaps.shape[0] # Should be 308
keypoints = {}
x1, y1, x2, y2 = map(int, bbox[:4])
bbox_width = x2 - x1
bbox_height = y2 - y1
for i, name in enumerate(GOLIATH_KEYPOINTS):
if i < num_joints:
heatmap = heatmaps[i]
y, x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
conf = heatmap[y, x]
# Convert coordinates to image frame
x_image = x * bbox_width / 192 + x1
y_image = y * bbox_height / 256 + y1
keypoints[name] = (float(x_image), float(y_image), float(conf))
return keypoints
@staticmethod
def draw_keypoints(image, keypoints, bbox, kpt_threshold):
image = np.array(image)
# Handle both 4 and 5-element bounding boxes
if len(bbox) == 4:
x1, y1, x2, y2 = map(int, bbox)
elif len(bbox) >= 5:
x1, y1, x2, y2, _ = map(int, bbox[:5])
else:
raise ValueError(f"Unexpected bbox format: {bbox}")
# Calculate adaptive radius and thickness based on bounding box size
bbox_width = x2 - x1
bbox_height = y2 - y1
bbox_size = np.sqrt(bbox_width * bbox_height)
radius = max(1, int(bbox_size * 0.006)) # minimum 1 pixel
thickness = max(1, int(bbox_size * 0.006)) # minimum 1 pixel
bbox_thickness = max(1, thickness//4)
cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), bbox_thickness)
# Draw keypoints
for i, (name, (x, y, conf)) in enumerate(keypoints.items()):
if conf > kpt_threshold and i < len(GOLIATH_KPTS_COLORS):
x_coord = int(x)
y_coord = int(y)
color = GOLIATH_KPTS_COLORS[i]
cv2.circle(image, (x_coord, y_coord), radius, color, -1)
# Draw skeleton
for _, link_info in GOLIATH_SKELETON_INFO.items():
pt1_name, pt2_name = link_info['link']
color = link_info['color']
if pt1_name in keypoints and pt2_name in keypoints:
pt1 = keypoints[pt1_name]
pt2 = keypoints[pt2_name]
if pt1[2] > kpt_threshold and pt2[2] > kpt_threshold:
x1_coord = int(pt1[0])
y1_coord = int(pt1[1])
x2_coord = int(pt2[0])
y2_coord = int(pt2[1])
cv2.line(image, (x1_coord, y1_coord), (x2_coord, y2_coord), color, thickness=thickness)
return Image.fromarray(image)
class GradioInterface:
def __init__(self):
self.image_processor = ImageProcessor()
def create_interface(self):
app_styles = """
<style>
/* Global Styles */
body, #root {
font-family: Helvetica, Arial, sans-serif;
background-color: #1a1a1a;
color: #fafafa;
}
/* Header Styles */
.app-header {
background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
padding: 24px;
border-radius: 8px;
margin-bottom: 24px;
text-align: center;
}
.app-title {
font-size: 48px;
margin: 0;
color: #fafafa;
}
.app-subtitle {
font-size: 24px;
margin: 8px 0 16px;
color: #fafafa;
}
.app-description {
font-size: 16px;
line-height: 1.6;
opacity: 0.8;
margin-bottom: 24px;
}
/* Button Styles */
.publication-links {
display: flex;
justify-content: center;
flex-wrap: wrap;
gap: 8px;
margin-bottom: 16px;
}
.publication-link {
display: inline-flex;
align-items: center;
padding: 8px 16px;
background-color: #333;
color: #fff !important;
text-decoration: none !important;
border-radius: 20px;
font-size: 14px;
transition: background-color 0.3s;
}
.publication-link:hover {
background-color: #555;
}
.publication-link i {
margin-right: 8px;
}
/* Content Styles */
.content-container {
background-color: #2a2a2a;
border-radius: 8px;
padding: 24px;
margin-bottom: 24px;
}
/* Image Styles */
.image-preview img {
max-width: 512px;
max-height: 512px;
margin: 0 auto;
border-radius: 4px;
display: block;
object-fit: contain;
}
/* Control Styles */
.control-panel {
background-color: #333;
padding: 16px;
border-radius: 8px;
margin-top: 16px;
}
/* Gradio Component Overrides */
.gr-button {
background-color: #4a4a4a;
color: #fff;
border: none;
border-radius: 4px;
padding: 8px 16px;
cursor: pointer;
transition: background-color 0.3s;
}
.gr-button:hover {
background-color: #5a5a5a;
}
.gr-input, .gr-dropdown {
background-color: #3a3a3a;
color: #fff;
border: 1px solid #4a4a4a;
border-radius: 4px;
padding: 8px;
}
.gr-form {
background-color: transparent;
}
.gr-panel {
border: none;
background-color: transparent;
}
/* Override any conflicting styles from Bulma */
.button.is-normal.is-rounded.is-dark {
color: #fff !important;
text-decoration: none !important;
}
</style>
"""
header_html = f"""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
{app_styles}
<div class="app-header">
<h1 class="app-title">Sapiens: Pose Estimation</h1>
<h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
<p class="app-description">
Meta presents Sapiens, foundation models for human tasks pretrained on 300 million human images.
This demo showcases the finetuned pose estimation model. <br>
</p>
<div class="publication-links">
<a href="https://arxiv.org/abs/2408.12569" class="publication-link">
<i class="fas fa-file-pdf"></i>arXiv
</a>
<a href="https://github.com/facebookresearch/sapiens" class="publication-link">
<i class="fab fa-github"></i>Code
</a>
<a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" class="publication-link">
<i class="fas fa-globe"></i>Meta
</a>
<a href="https://rawalkhirodkar.github.io/sapiens" class="publication-link">
<i class="fas fa-chart-bar"></i>Results
</a>
</div>
<div class="publication-links">
<a href="https://huggingface.co/spaces/facebook/sapiens_pose" class="publication-link">
<i class="fas fa-user"></i>Demo-Pose
</a>
<a href="https://huggingface.co/spaces/facebook/sapiens_seg" class="publication-link">
<i class="fas fa-puzzle-piece"></i>Demo-Seg
</a>
<a href="https://huggingface.co/spaces/facebook/sapiens_depth" class="publication-link">
<i class="fas fa-cube"></i>Demo-Depth
</a>
<a href="https://huggingface.co/spaces/facebook/sapiens_normal" class="publication-link">
<i class="fas fa-vector-square"></i>Demo-Normal
</a>
</div>
</div>
"""
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""
def process_image(image, model_name, kpt_threshold):
result_image, keypoints = self.image_processor.process_image(image, model_name, kpt_threshold)
with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode='w') as json_file:
json.dump(keypoints, json_file)
json_file_path = json_file.name
return result_image, json_file_path
with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
gr.HTML(header_html)
with gr.Row(elem_classes="content-container"):
with gr.Column():
input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
with gr.Row():
model_name = gr.Dropdown(
label="Model Size",
choices=list(Config.CHECKPOINTS.keys()),
value="1b",
)
kpt_threshold = gr.Dropdown(
label="Min Keypoint Confidence",
choices=["0.1", "0.2", "0.3", "0.4", "0.5", "0.6", "0.7", "0.8", "0.9"],
value="0.3",
)
example_model = gr.Examples(
inputs=input_image,
examples_per_page=14,
examples=[
os.path.join(Config.ASSETS_DIR, "images", img)
for img in os.listdir(os.path.join(Config.ASSETS_DIR, "images"))
],
)
with gr.Column():
result_image = gr.Image(label="Pose-308 Result", type="pil", elem_classes="image-preview")
json_output = gr.File(label="Pose-308 Output (.json)")
run_button = gr.Button("Run")
run_button.click(
fn=process_image,
inputs=[input_image, model_name, kpt_threshold],
outputs=[result_image, json_output],
)
return demo
def main():
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
interface = GradioInterface()
demo = interface.create_interface()
demo.launch(share=False)
if __name__ == "__main__":
main()