Spaces: Running on Zero

Add application file
- __pycache__/body_features.cpython-38.pyc +0 -0
- __pycache__/crop_face.cpython-38.pyc +0 -0
- __pycache__/crop_hands.cpython-38.pyc +0 -0
- __pycache__/dinov2_features.cpython-38.pyc +0 -0
- __pycache__/inference.cpython-38.pyc +0 -0
- __pycache__/kpe_mediapipe.cpython-38.pyc +0 -0
- __pycache__/shubert.cpython-38.pyc +0 -0
- app.py +536 -8
- attention.py +107 -0
- block.py +322 -0
- body_features.py +358 -0
- crop_face.py +415 -0
- crop_hands.py +445 -0
- dinov2_features.py +351 -0
- features.py +115 -0
- inference.py +738 -0
- kpe_mediapipe.py +408 -0
- shubert.py +479 -0
- shubert_inference.py +439 -0
__pycache__/body_features.cpython-38.pyc
ADDED
Binary file (10.2 kB)

__pycache__/crop_face.cpython-38.pyc
ADDED
Binary file (12.1 kB)

__pycache__/crop_hands.cpython-38.pyc
ADDED
Binary file (12.1 kB)

__pycache__/dinov2_features.cpython-38.pyc
ADDED
Binary file (10.6 kB)

__pycache__/inference.cpython-38.pyc
ADDED
Binary file (19.7 kB)

__pycache__/kpe_mediapipe.cpython-38.pyc
ADDED
Binary file (12 kB)

__pycache__/shubert.cpython-38.pyc
ADDED
Binary file (8.64 kB)
app.py
CHANGED
@@ -1,14 +1,542 @@

(This replaces the previous 14-line app.py. The removed lines are mostly truncated in this view; they included "import torch" and a stub function definition under the @spaces.GPU decorator.)

New contents of app.py:

import gradio as gr
import os
import tempfile
import huggingface_hub
import shutil
import logging
import traceback
from features import SHuBERTProcessor
import spaces

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set writable cache directories
def setup_cache_directories():
    """Set up cache directories with proper error handling"""
    try:
        cache_dirs = {
            'MPLCONFIGDIR': '/tmp/matplotlib',
            'TRANSFORMERS_CACHE': '/tmp/huggingface',
            'HF_HOME': '/tmp/huggingface',
            'FONTCONFIG_PATH': '/tmp/fontconfig',
            'TORCH_HOME': '/tmp/torch',  # PyTorch cache directory
        }

        for env_var, path in cache_dirs.items():
            os.environ[env_var] = path
            os.makedirs(path, exist_ok=True, mode=0o777)
            logger.info(f"Cache directory created: {env_var} = {path}")

        # Also set XDG_CACHE_HOME to override default .cache location
        os.environ['XDG_CACHE_HOME'] = '/tmp/cache'
        os.makedirs('/tmp/cache', exist_ok=True, mode=0o777)
        logger.info(f"Cache directory created: XDG_CACHE_HOME = /tmp/cache")

        # Clear any existing PyTorch Hub cache to avoid corruption issues
        torch_hub_dir = '/tmp/torch/hub'
        if os.path.exists(torch_hub_dir):
            shutil.rmtree(torch_hub_dir)
            logger.info("Cleared existing PyTorch Hub cache")
        os.makedirs(torch_hub_dir, exist_ok=True, mode=0o777)
        logger.info(f"Created clean PyTorch Hub cache directory: {torch_hub_dir}")

        # Copy updated DINOv2 files to torch cache after clearing
        # This ensures they're available when PyTorch Hub downloads the repo
        try:
            src_dir = os.path.dirname(os.path.abspath(__file__))
            target_dir = '/tmp/torch/hub/facebookresearch_dinov2_main/dinov2/layers'

            for filename in ['attention.py', 'block.py']:
                src_path = os.path.join(src_dir, filename)
                if os.path.exists(src_path):
                    # We'll copy these after the initial hub download
                    logger.info(f"Found {filename} in project directory - will copy after hub download")
                else:
                    logger.warning(f"Could not find {filename} in project directory")
        except Exception as e:
            logger.warning(f"Error preparing DINOv2 files: {e}")

        return True
    except Exception as e:
        logger.error(f"Error creating cache directories: {str(e)}")
        return False

# Configuration for Hugging Face Spaces
MODEL_REPO = "ShesterG/SHuBERT"
TOKEN = os.environ.get('HF_TOKEN')

def validate_environment():
    """Validate required environment variables and setup"""
    if not TOKEN:
        raise ValueError("HF_TOKEN environment variable not set. This is required to access private model repository.")

    # Check available disk space
    free_space = shutil.disk_usage('/').free / (1024*1024*1024)  # GB
    logger.info(f"Available disk space: {free_space:.2f} GB")

    if free_space < 2:  # Less than 2GB
        logger.warning("Low disk space available. This may cause issues.")

    return True

def download_models():
    """Download all required models from Hugging Face Hub with enhanced error handling"""
    logger.info("Starting model download process...")

    try:
        # Validate environment first
        validate_environment()

        logger.info("Downloading entire models folder...")

        # Download the entire models folder
        models_path = huggingface_hub.snapshot_download(
            repo_id=MODEL_REPO,
            allow_patterns="models/*",  # Download everything in models folder
            token=TOKEN,
            cache_dir=os.environ['TRANSFORMERS_CACHE']
        )

        # Build config dict with expected file paths
        config = {
            'yolov8_model_path': os.path.join(models_path, "models/yolov8n.pt"),
            'dino_face_model_path': os.path.join(models_path, "models/dinov2face.pth"),
            'dino_hands_model_path': os.path.join(models_path, "models/dinov2hand.pth"),
            'mediapipe_face_model_path': os.path.join(models_path, "models/face_landmarker_v2_with_blendshapes.task"),
            'mediapipe_hands_model_path': os.path.join(models_path, "models/hand_landmarker.task"),
            'shubert_model_path': os.path.join(models_path, "models/checkpoint_836_400000.pt"),
            'slt_model_config': os.path.join(models_path, "models/byt5_base/config.json"),
            'slt_model_checkpoint': os.path.join(models_path, "models/checkpoint-11625"),
            'slt_tokenizer_checkpoint': os.path.join(models_path, "models/byt5_base"),
            'temp_dir': 'temp'
        }

        # Verify all required files and folders exist
        logger.info("Verifying downloaded files...")
        missing_files = []

        for key, path in config.items():
            if key == 'temp_dir':  # Skip temp_dir check
                continue

            if not os.path.exists(path):
                missing_files.append(f"{key}: {path}")
                logger.error(f"Missing: {path}")
            else:
                logger.info(f"✓ Found: {path}")

        if missing_files:
            logger.error(f"Missing {len(missing_files)} required files/folders:")
            for missing in missing_files:
                logger.error(f"  - {missing}")
            raise FileNotFoundError(f"Required files not found: {missing_files}")

        logger.info("All models downloaded and verified successfully!")
        logger.info(f"Models root path: {models_path}")

        return config

    except Exception as e:
        logger.error(f"Error downloading models: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")

        # Additional debugging info
        try:
            cache_contents = os.listdir(os.environ['TRANSFORMERS_CACHE'])
            logger.info(f"Cache directory contents: {cache_contents}")
        except:
            logger.error("Cannot access cache directory")

        return None

def initialize_processor(config):
    """Initialize SHuBERT processor with error handling"""
    try:
        logger.info("Initializing SHuBERT processor...")
        processor = SHuBERTProcessor(config)
        logger.info("SHuBERT processor initialized successfully!")
        return processor
    except Exception as e:
        logger.error(f"Error initializing SHuBERT processor: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return None

# Initialize the application
def initialize_app():
    """Initialize the entire application with comprehensive error handling"""
    try:
        # Setup cache directories
        if not setup_cache_directories():
            raise RuntimeError("Failed to setup cache directories")

        # Download models
        config = download_models()
        if config is None:
            raise RuntimeError("Failed to download models")

        # Initialize processor
        processor = initialize_processor(config)
        if processor is None:
            raise RuntimeError("Failed to initialize SHuBERT processor")

        logger.info("Application initialized successfully!")
        return config, processor

    except Exception as e:
        error_msg = f"Application initialization failed: {str(e)}"
        logger.error(error_msg)
        logger.error(f"Full traceback: {traceback.format_exc()}")
        raise RuntimeError(error_msg)

# Global variables for application state
config = None
processor = None
initialization_error = None

try:
    config, processor = initialize_app()
except Exception as e:
    initialization_error = str(e)
    logger.error(f"Startup failed: {initialization_error}")

def copy_dinov2_files_if_needed():
    """Copy updated DINOv2 files after PyTorch Hub download if needed"""
    try:
        src_dir = os.path.dirname(os.path.abspath(__file__))
        target_dir = '/tmp/torch/hub/facebookresearch_dinov2_main/dinov2/layers'

        # Check if PyTorch Hub has downloaded the repository
        hub_main_dir = '/tmp/torch/hub/facebookresearch_dinov2_main'

        if os.path.exists(hub_main_dir):
            # Ensure the target directory exists
            os.makedirs(target_dir, exist_ok=True)

            files_copied = 0
            for filename in ['attention.py', 'block.py']:
                src_path = os.path.join(src_dir, filename)
                target_path = os.path.join(target_dir, filename)

                if os.path.exists(src_path):
                    # Always overwrite with our robust versions
                    shutil.copy2(src_path, target_path)
                    # Make sure it's readable
                    os.chmod(target_path, 0o644)
                    logger.info(f"Replaced {filename} with robust version (numpy/Python 3.8 compatible)")
                    files_copied += 1
                else:
                    logger.error(f"Source file not found: {src_path}")

            if files_copied > 0:
                # Clear Python's import cache to ensure new files are used
                import importlib
                import sys

                # Remove any cached imports of dinov2 modules
                modules_to_remove = [key for key in sys.modules.keys() if 'dinov2' in key]
                for module in modules_to_remove:
                    del sys.modules[module]
                    logger.info(f"Cleared cached import: {module}")

                logger.info(f"Successfully replaced {files_copied} DINOv2 files with robust versions")
                return True
        else:
            logger.info("PyTorch Hub repository not yet downloaded")
            return False

    except Exception as e:
        logger.error(f"Error copying DINOv2 files: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return False

@spaces.GPU
def process_video(video_file):
    """Process uploaded video file with enhanced error handling"""
    # Check if initialization was successful
    if initialization_error:
        return f"Application initialization failed: {initialization_error}\n\nPlease check the logs for more details."

    if processor is None:
        return "Error: Model not initialized properly. Please check the logs."

    if video_file is None:
        return "Please upload a video file."

    logger.info(f"=== Starting video processing ===")
    logger.info(f"Video file input: {video_file}")
    logger.info(f"Video file type: {type(video_file)}")

    try:
        # Create temp directory with proper permissions
        temp_dir = config['temp_dir']
        os.makedirs(temp_dir, exist_ok=True, mode=0o777)
        logger.info(f"Temp directory: {temp_dir}")

        # Generate unique filename to avoid conflicts
        import time
        timestamp = str(int(time.time() * 1000))
        file_extension = '.mp4'  # Default extension

        # Try to get original extension if available
        try:
            if hasattr(video_file, 'name') and video_file.name:
                file_extension = os.path.splitext(video_file.name)[1] or '.mp4'
            elif isinstance(video_file, str):
                file_extension = os.path.splitext(video_file)[1] or '.mp4'
        except:
            pass

        temp_video_path = os.path.join(temp_dir, f"video_{timestamp}{file_extension}")
        logger.info(f"Target temp video path: {temp_video_path}")

        # Handle Gradio file upload - video_file is typically a string path to temp file
        logger.info(f"Processing video file: {video_file} (type: {type(video_file)})")

        if isinstance(video_file, str):
            # Gradio provides a file path string
            source_path = video_file

            # Handle both absolute and relative paths
            if not os.path.isabs(source_path):
                # Try current working directory first
                abs_source_path = os.path.abspath(source_path)
                logger.info(f"Converting relative path {source_path} to absolute: {abs_source_path}")
                if os.path.exists(abs_source_path):
                    source_path = abs_source_path
                else:
                    # Try looking in common Gradio temp directories
                    possible_paths = [
                        source_path,
                        os.path.join('/tmp', os.path.basename(source_path)),
                        os.path.join('/tmp/gradio', os.path.basename(source_path)),
                        abs_source_path
                    ]

                    found_path = None
                    for path in possible_paths:
                        logger.info(f"Checking path: {path}")
                        if os.path.exists(path):
                            found_path = path
                            logger.info(f"Found file at: {path}")
                            break

                    if found_path:
                        source_path = found_path
                    else:
                        logger.error(f"Could not find source file in any expected location")
                        logger.error(f"Tried paths: {possible_paths}")
                        raise FileNotFoundError(f"Source video file not found in any expected location: {video_file}")

            logger.info(f"Final source file path: {source_path}")
            logger.info(f"Source file exists: {os.path.exists(source_path)}")

            if os.path.exists(source_path):
                try:
                    # Check source file permissions and size
                    stat_info = os.stat(source_path)
                    logger.info(f"Source file size: {stat_info.st_size} bytes, mode: {oct(stat_info.st_mode)}")

                    # Try to read the file content
                    with open(source_path, 'rb') as src:
                        content = src.read()
                    logger.info(f"Successfully read {len(content)} bytes from source")

                    # Write to destination (with a different name to avoid conflicts)
                    final_temp_path = os.path.join(temp_dir, f"processed_{timestamp}{file_extension}")
                    with open(final_temp_path, 'wb') as dst:
                        dst.write(content)
                    logger.info(f"Successfully wrote to destination: {final_temp_path}")

                    # Update temp_video_path to the final location
                    temp_video_path = final_temp_path

                except PermissionError as e:
                    logger.error(f"Permission error reading source file: {e}")
                    # Try alternative approach - use a completely different temp location
                    try:
                        import tempfile
                        # Create a new temporary file in system temp directory
                        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp:
                            alternative_temp_path = tmp.name

                        logger.info(f"Trying alternative temp path: {alternative_temp_path}")

                        # Try to copy using system copy command as fallback
                        import subprocess
                        result = subprocess.run(['cp', source_path, alternative_temp_path],
                                                capture_output=True, text=True)

                        if result.returncode == 0:
                            logger.info("Successfully copied using system cp command")
                            temp_video_path = alternative_temp_path
                        else:
                            logger.error(f"System cp failed: {result.stderr}")
                            raise PermissionError(f"Cannot read video file due to permission restrictions: {e}")

                    except Exception as e2:
                        logger.error(f"Alternative copy method also failed: {e2}")
                        raise PermissionError(f"Cannot read video file due to permission restrictions: {e}")
            else:
                raise FileNotFoundError(f"Source video file not found: {source_path}")

        elif hasattr(video_file, 'read'):
            # If it's a file-like object with read method
            try:
                content = video_file.read()
                with open(temp_video_path, 'wb') as f:
                    f.write(content)
                logger.info(f"Saved video from file object: {temp_video_path} ({len(content)} bytes)")
            except Exception as e:
                logger.error(f"Error reading from file object: {e}")
                raise ValueError(f"Cannot read from file object: {e}")
        else:
            # Handle other cases - try to extract file path or content
            logger.info(f"Attempting to handle unknown file type: {type(video_file)}")
            try:
                # Check if it has a name attribute (common for file objects)
                if hasattr(video_file, 'name'):
                    source_path = video_file.name
                    logger.info(f"Found name attribute: {source_path}")

                    if os.path.exists(source_path):
                        with open(source_path, 'rb') as src:
                            content = src.read()
                        with open(temp_video_path, 'wb') as dst:
                            dst.write(content)
                        logger.info(f"Successfully copied from name attribute")
                    else:
                        raise FileNotFoundError(f"File from name attribute not found: {source_path}")
                else:
                    logger.error(f"Unsupported video file type: {type(video_file)}")
                    raise ValueError(f"Unsupported video file type: {type(video_file)}")
            except Exception as e:
                logger.error(f"Failed to handle unknown file type: {e}")
                raise ValueError(f"Cannot process video file: {e}")

        # Set proper permissions on the saved file
        os.chmod(temp_video_path, 0o666)

        # Verify file exists and has content
        if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0:
            raise ValueError("Video file is empty or could not be saved")

        # Copy DINOv2 files if needed before processing
        # This needs to happen right after PyTorch Hub downloads but before model loading
        logger.info("Ensuring DINOv2 files are ready for processing...")
        copy_dinov2_files_if_needed()

        # Set up a monitoring patch for torch.hub.load to replace files immediately after download
        original_torch_hub_load = None
        try:
            import torch.hub
            original_torch_hub_load = torch.hub.load

            def patched_torch_hub_load(*args, **kwargs):
                logger.info(f"PyTorch Hub load called with: {args[0] if args else 'unknown'}")

                # Call the original function first
                result = original_torch_hub_load(*args, **kwargs)

                # If this was a DINOv2 call, immediately replace the files
                if args and 'dinov2' in str(args[0]):
                    logger.info("DINOv2 downloaded! Immediately replacing with robust versions...")

                    # Try multiple times to ensure files are replaced
                    import time
                    for attempt in range(5):
                        if copy_dinov2_files_if_needed():
                            logger.info("Successfully replaced DINOv2 files!")
                            break
                        else:
                            logger.info(f"Attempt {attempt + 1} failed, retrying in 1 second...")
                            time.sleep(1)

                return result

            # Temporarily patch torch.hub.load
            torch.hub.load = patched_torch_hub_load
            logger.info("Patched torch.hub.load to replace DINOv2 files after download")
        except Exception as e:
            logger.warning(f"Could not patch torch.hub.load: {e}")

        logger.info(f"Processing video: {temp_video_path}")
        try:
            output_text = processor.process_video(temp_video_path)
        finally:
            # Restore original function
            if original_torch_hub_load:
                try:
                    import torch.hub
                    torch.hub.load = original_torch_hub_load
                    logger.info("Restored original torch.hub.load")
                except:
                    pass

        logger.info(f"Video processed successfully. Output: {output_text[:100]}...")

        # Clean up temp file
        if os.path.exists(temp_video_path):
            os.remove(temp_video_path)
            logger.info("Temporary video file cleaned up")

        return output_text

    except Exception as e:
        logger.error(f"Error processing video: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return f"Error processing video: {str(e)}\n\nPlease check that your video is a valid ASL video under 10 seconds."

# Create Gradio interface
def create_interface():
    """Create the Gradio interface"""
    description = """
Upload an ASL* video to get an English translation. *Sign languages belonging to the same sign language family as ASL (e.g. Ghanaian Sign Language, as well as others listed in Table 7, Row 1 of https://aclanthology.org/2023.findings-emnlp.664.pdf) might also have non-trivial performance, although the model is trained only on ASL data.

This app uses TTIC's foundation model SHuBERT (introduced in an ACL 2025 paper, see http://shubert.pals.ttic.edu).

**Requirements:**
- We recommend that videos be under 60 seconds. Performance for longer videos has not been tested.
- The signer should be the main part of the video. Videos recorded from a phone camera, tablet, or personal computer should work well. Studio recordings where the signer is farther from the camera may not work as well.
- Supported formats: MP4, MOV

**Note:**
- Videos will be deleted after the output is generated.
- Inquiries or feedback? Please email us at [email protected]
"""

    if initialization_error:
        description += f"\n\n:warning: **Initialization Error:** {initialization_error}"

    return gr.Interface(
        fn=process_video,
        inputs=gr.Video(label="ASL Video (under 60 seconds)", format="mp4"),
        # inputs=gr.File(
        #     label="Upload ASL Video (under 60 seconds)",
        #     file_types=[".mp4", ".avi", ".mov", ".webm"],
        #     type="filepath"  # This tells Gradio to provide the file path directly
        # ),
        outputs=gr.Textbox(label="English Translation", lines=5),
        title="ASL Video to English Text Translation",
        description=description,
        article="",
        examples=[],
        allow_flagging="never"
    )

# Create the demo
demo = create_interface()

if __name__ == "__main__":
    # Launch with better configuration for Hugging Face Spaces
    logger.info("Launching Gradio interface...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
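The torch.hub.load handling in process_video above follows a general patch-then-restore pattern: wrap the original callable, run a post-download hook when the call targets DINOv2, and always restore the original in a finally block. The sketch below shows that pattern in isolation; it is not part of app.py, and the print-based hook is only a stand-in for copy_dinov2_files_if_needed.

import torch.hub

def make_patched_load(original_load, post_hook):
    def patched(*args, **kwargs):
        result = original_load(*args, **kwargs)   # download/load as usual
        if args and "dinov2" in str(args[0]):     # only react to DINOv2 repos
            post_hook()                           # e.g. replace attention.py / block.py
        return result
    return patched

original = torch.hub.load
torch.hub.load = make_patched_load(original, post_hook=lambda: print("hook ran"))
try:
    pass  # code that calls torch.hub.load("facebookresearch/dinov2", ...) would go here
finally:
    torch.hub.load = original  # always restore the unpatched function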
attention.py
ADDED
@@ -0,0 +1,107 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

import torch
from torch import nn, Tensor


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
    else:
        warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Attention)")


try:
    from typing import Optional
    from typing import Union
    FloatOrNone = Union[float, None]
except ImportError:
    FloatOrNone = float | None


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def init_weights(
        self, init_attn_std: FloatOrNone = None, init_proj_std: FloatOrNone = None, factor: float = 1.0
    ) -> None:
        init_attn_std = init_attn_std or (self.dim**-0.5)
        init_proj_std = init_proj_std or init_attn_std * factor
        nn.init.normal_(self.qkv.weight, std=init_attn_std)
        nn.init.normal_(self.proj.weight, std=init_proj_std)
        if self.qkv.bias is not None:
            nn.init.zeros_(self.qkv.bias)
        if self.proj.bias is not None:
            nn.init.zeros_(self.proj.bias)

    def forward(self, x: Tensor, is_causal: bool = False) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = torch.unbind(qkv, 2)
        q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
        x = nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=self.attn_drop if self.training else 0, is_causal=is_causal
        )
        x = x.transpose(1, 2).contiguous().view(B, N, C)
        x = self.proj_drop(self.proj(x))
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
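For readers unfamiliar with the fused attention call above, the following standalone sketch (not part of the commit) traces the tensor shapes in Attention.forward: (B, N, C) is projected to qkv, split into per-head tensors, passed through scaled_dot_product_attention (PyTorch 2.x), and merged back to (B, N, C). The dimensions are arbitrary example values.

import torch
import torch.nn.functional as F

B, N, C, num_heads = 2, 16, 64, 8
head_dim = C // num_heads
x = torch.randn(B, N, C)

qkv = torch.nn.Linear(C, C * 3)(x).reshape(B, N, 3, num_heads, head_dim)
q, k, v = torch.unbind(qkv, 2)                    # each (B, N, heads, head_dim)
q, k, v = [t.transpose(1, 2) for t in (q, k, v)]  # each (B, heads, N, head_dim)

out = F.scaled_dot_product_attention(q, k, v)     # fused attention kernel
out = out.transpose(1, 2).contiguous().view(B, N, C)
print(out.shape)  # torch.Size([2, 16, 64])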
block.py
ADDED
@@ -0,0 +1,322 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict, Optional
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp

try:
    from typing import Optional
    from typing import Union
    FloatOrNone = Union[float, None]
except ImportError:
    FloatOrNone = float | None

logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Block)")
    else:
        warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False

    warnings.warn("xFormers is not available (Block)")


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


class CausalAttentionBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        is_causal: bool = True,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        dropout_prob: float = 0.0,
    ):
        super().__init__()

        self.dim = dim
        self.is_causal = is_causal
        self.ls1 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity()
        self.attention_norm = norm_layer(dim)
        self.attention = Attention(dim, num_heads, attn_drop=dropout_prob, proj_drop=dropout_prob)

        self.ffn_norm = norm_layer(dim)
        ffn_hidden_dim = int(dim * ffn_ratio)
        self.feed_forward = Mlp(
            in_features=dim,
            hidden_features=ffn_hidden_dim,
            drop=dropout_prob,
            act_layer=act_layer,
        )

        self.ls2 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity()

    def init_weights(
        self,
        init_attn_std: FloatOrNone = None,
        init_proj_std: FloatOrNone = None,
        init_fc_std: FloatOrNone = None,
        factor: float = 1.0,
    ) -> None:
        init_attn_std = init_attn_std or (self.dim**-0.5)
        init_proj_std = init_proj_std or init_attn_std * factor
        init_fc_std = init_fc_std or (2 * self.dim) ** -0.5
        self.attention.init_weights(init_attn_std, init_proj_std)
        self.attention_norm.reset_parameters()
        nn.init.normal_(self.feed_forward.fc1.weight, std=init_fc_std)
        nn.init.normal_(self.feed_forward.fc2.weight, std=init_proj_std)
        self.ffn_norm.reset_parameters()

    def forward(
        self,
        x: torch.Tensor,
    ):
        x_attn = x + self.ls1(self.attention(self.attention_norm(x), self.is_causal))
        x_ffn = x_attn + self.ls2(self.feed_forward(self.ffn_norm(x_attn)))
        return x_ffn


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
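drop_add_residual_stochastic_depth above implements batch-level stochastic depth: the residual branch runs on a random subset of the batch and the result is scaled by b / subset_size so its expected contribution is unchanged. Below is a condensed standalone sketch of the same idea (not part of the commit; the toy residual_func is illustrative only).

import torch

def drop_add_residual(x, residual_func, sample_drop_ratio=0.3):
    b = x.shape[0]
    subset = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(b, device=x.device)[:subset]   # random subset of the batch
    residual = residual_func(x[brange]).flatten(1)         # branch runs only on that subset
    scale = b / subset                                     # rescale to keep the expectation
    out = torch.index_add(x.flatten(1), 0, brange, residual.to(x.dtype), alpha=scale)
    return out.view_as(x)

x = torch.randn(8, 16, 64)
y = drop_add_residual(x, residual_func=lambda t: 0.1 * t)
print(y.shape)  # torch.Size([8, 16, 64])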
body_features.py
ADDED
@@ -0,0 +1,358 @@

import cv2
import numpy as np
import os
import pickle
import gzip
from datetime import datetime
from pathlib import Path
import decord
import argparse
import json
import glob
import time
from typing import Dict, List, Optional, Tuple, Union, Any


class PoseProcessor:
    """
    A class for processing pose landmarks and converting them to normalized numpy arrays.
    """

    def __init__(self, pose_indices: Optional[List[int]] = None,
                 normalize_keypoints: bool = True, fill_missing_value: float = -9999.0):
        """
        Initialize the PoseProcessor.

        Args:
            pose_indices: List of pose landmark indices to extract.
                Default is [0,11,12,13,14,15,16] (nose, shoulders, elbows, wrists)
            normalize_keypoints: Whether to normalize keypoints to signing space
            fill_missing_value: Value to use for missing keypoints
        """
        self.pose_indices = pose_indices if pose_indices else [0, 11, 12, 13, 14, 15, 16]
        self.normalize_keypoints = normalize_keypoints
        self.fill_missing_value = fill_missing_value

        # Number of coordinates per keypoint (x, y)
        self.coords_per_keypoint = 2
        self.output_shape = (len(self.pose_indices), self.coords_per_keypoint)

    def normalize_pose_keypoints(self, pose_landmarks: List[List[float]]) -> List[List[float]]:
        """
        Normalize pose keypoints to signing space.

        Args:
            pose_landmarks: List of pose landmarks from MediaPipe

        Returns:
            List of normalized pose keypoints
        """
        # Extract relevant landmarks for normalization
        left_shoulder = np.array(pose_landmarks[11][:2])
        right_shoulder = np.array(pose_landmarks[12][:2])
        left_eye = np.array(pose_landmarks[2][:2])
        nose = np.array(pose_landmarks[0][:2])

        # Calculate head unit in normalized space
        head_unit = np.linalg.norm(right_shoulder - left_shoulder) / 2

        # Define signing space dimensions in normalized space
        signing_space_width = 6 * head_unit
        signing_space_height = 7 * head_unit

        # Calculate signing space bounding box in normalized space
        signing_space_top = left_eye[1] - 0.5 * head_unit
        signing_space_bottom = signing_space_top + signing_space_height
        signing_space_left = nose[0] - signing_space_width / 2
        signing_space_right = signing_space_left + signing_space_width

        # Create transformation matrix
        translation_matrix = np.array([[1, 0, -signing_space_left],
                                       [0, 1, -signing_space_top],
                                       [0, 0, 1]])
        scale_matrix = np.array([[1 / signing_space_width, 0, 0],
                                 [0, 1 / signing_space_height, 0],
                                 [0, 0, 1]])
        shift_matrix = np.array([[1, 0, -0.5],
                                 [0, 1, -0.5],
                                 [0, 0, 1]])
        transformation_matrix = shift_matrix @ scale_matrix @ translation_matrix

        # Apply transformation to pose keypoints
        normalized_keypoints = []
        for landmark in pose_landmarks:
            keypoint = np.array([landmark[0], landmark[1], 1])
            normalized_keypoint = transformation_matrix @ keypoint
            normalized_keypoints.append(normalized_keypoint[:2].tolist())

        return normalized_keypoints

    def process_frame_landmarks(self, frame_landmarks: Optional[Dict[str, Any]]) -> np.ndarray:
        """
        Process landmarks for a single frame.

        Args:
            frame_landmarks: Dictionary containing pose landmarks for one frame

        Returns:
            Numpy array of processed pose keypoints
        """
        if frame_landmarks is None or frame_landmarks.get('pose_landmarks') is None:
            # Return missing value array
            return np.full(self.output_shape, self.fill_missing_value).flatten()

        # Get pose landmarks
        pose_landmarks = frame_landmarks['pose_landmarks'][0]

        # Normalize keypoints if required
        if self.normalize_keypoints:
            # Take first 25 landmarks for normalization (MediaPipe pose has 33 total)
            normalized_landmarks = self.normalize_pose_keypoints(pose_landmarks[:25])
        else:
            normalized_landmarks = pose_landmarks

        # Extract only the specified indices
        selected_landmarks = [normalized_landmarks[i] for i in self.pose_indices]

        # Convert to numpy array and flatten
        frame_keypoints = np.array(selected_landmarks).flatten()

        return frame_keypoints

    def process_landmarks_sequence(self, landmarks_data: Dict[int, Any]) -> np.ndarray:
        """
        Process landmarks for an entire sequence (video).

        Args:
            landmarks_data: Dictionary containing landmarks for each frame

        Returns:
            Numpy array of shape (num_frames, num_keypoints * 2)
        """
        # Get number of frames
        if not landmarks_data:
            return np.array([])

        max_frame = max(landmarks_data.keys())
        num_frames = max_frame + 1

        video_pose_landmarks = []
        prev_pose = None

        for i in range(num_frames):
            frame_landmarks = landmarks_data.get(i, None)

            if frame_landmarks is None:
                # Use previous pose if available, otherwise use missing values
                if prev_pose is not None:
                    frame_keypoints = prev_pose
                else:
                    frame_keypoints = np.full(self.output_shape, self.fill_missing_value).flatten()
            else:
                # Process current frame
                frame_keypoints = self.process_frame_landmarks(frame_landmarks)
                if not np.all(frame_keypoints == self.fill_missing_value):
                    prev_pose = frame_keypoints

            video_pose_landmarks.append(frame_keypoints)

        # Convert to numpy array
        video_pose_landmarks = np.array(video_pose_landmarks)

        # Apply any post-processing (like the original code's wrist masking)
        # video_pose_landmarks = self._apply_post_processing(video_pose_landmarks)

        return video_pose_landmarks

    def _apply_post_processing(self, pose_array: np.ndarray) -> np.ndarray:
        """
        Apply post-processing to the pose array.

        Args:
            pose_array: Input pose array

        Returns:
            Post-processed pose array
        """
        # The original code fills left and right wrist with -9999
        # This corresponds to indices 15 and 16 in the original pose landmarks
        # In our selected indices [0,11,12,13,14,15,16], wrists are at positions 5 and 6
        # Each keypoint has 2 coordinates, so wrists are at positions 10-11 and 12-13

        # if len(self.pose_indices) >= 7 and 15 in self.pose_indices and 16 in self.pose_indices:
        #     # Find positions of wrists in our selected indices
        #     left_wrist_idx = self.pose_indices.index(15) * 2  # *2 because each keypoint has x,y
        #     right_wrist_idx = self.pose_indices.index(16) * 2

        #     # Fill wrist coordinates with missing value
        #     pose_array[:, left_wrist_idx:left_wrist_idx+2] = self.fill_missing_value
        #     pose_array[:, right_wrist_idx:right_wrist_idx+2] = self.fill_missing_value

        return pose_array

    def process_landmarks_from_file(self, pose_file_path: str) -> np.ndarray:
        """
        Process landmarks from a JSON file.

        Args:
            pose_file_path: Path to the pose landmarks JSON file

        Returns:
            Numpy array of processed pose keypoints
        """
        try:
            with open(pose_file_path, 'r') as f:
                landmarks_data = json.load(f)

            # Convert string keys to integers
            landmarks_data = {int(k): v for k, v in landmarks_data.items()}

            return self.process_landmarks_sequence(landmarks_data)

        except Exception as e:
            print(f"Error processing {pose_file_path}: {e}")
            return np.array([])

    def process_and_save_landmarks(self, landmarks_data: Dict[int, Any],
                                   output_path: str, filename: str) -> str:
        """
        Process landmarks and save to file.

        Args:
            landmarks_data: Dictionary containing landmarks for each frame
            output_path: Directory to save the processed landmarks
            filename: Name for the output file (without extension)

        Returns:
            Path to the saved file
        """
        # Process landmarks
        processed_landmarks = self.process_landmarks_sequence(landmarks_data)

        # Create output directory if it doesn't exist
        output_dir = Path(output_path)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save to file
        save_path = output_dir / f"{filename}.npy"
        np.save(save_path, processed_landmarks)

        return str(save_path)


# Convenience functions for backward compatibility
def process_pose_landmarks(landmarks_data: Dict[int, Any],
                           normalize: bool = True,
                           pose_indices: Optional[List[int]] = None) -> np.ndarray:
    """
    Convenience function to process pose landmarks.

    Args:
        landmarks_data: Dictionary containing landmarks for each frame
        normalize: Whether to normalize keypoints to signing space
        pose_indices: List of pose landmark indices to extract

    Returns:
        Numpy array of processed pose keypoints
    """
    processor = PoseProcessor(pose_indices=pose_indices, normalize_keypoints=normalize)
    return processor.process_landmarks_sequence(landmarks_data)


def keypoints_to_numpy(pose_file: str, pose_emb_path: str):
    """
    Original function for backward compatibility with command-line usage.
    """
    try:
        processor = PoseProcessor()
        processed_landmarks = processor.process_landmarks_from_file(pose_file)

        if processed_landmarks.size > 0:
            # Save the processed landmarks
            video_name = Path(pose_file).stem
            save_path = Path(pose_emb_path) / f"{video_name}.npy"
            save_path.parent.mkdir(parents=True, exist_ok=True)
            np.save(save_path, processed_landmarks)

    except Exception as e:
        print(f"Error processing {pose_file}: {e}")


# Utility functions for batch processing
def get_mp4_files(directory: str) -> List[str]:
    """Get all MP4 files in a directory."""
    if not os.path.exists(directory):
        raise FileNotFoundError(f'Directory not found: {directory}')

    mp4_files = glob.glob(os.path.join(directory, '*.mp4'))
    return [os.path.abspath(file) for file in mp4_files]


def load_file(filename: str):
    """Load a pickled and gzipped file."""
    with gzip.open(filename, "rb") as f:
        return pickle.load(f)


def is_string_in_file(file_path: str, target_string: str) -> bool:
    """Check if a string exists in a file."""
    try:
        with Path(file_path).open("r") as f:
            for line in f:
                if target_string in line:
                    return True
        return False
    except Exception as e:
        print(f"Error: {e}")
        return False


def main():
    """Main function for command-line usage."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--index', type=int, required=True,
                        help='index of the sub_list to work with')
    parser.add_argument('--files_list', type=str, required=True,
                        help='path to the pose file')
    parser.add_argument('--pose_features_path', type=str, required=True,
                        help='path to the pose features file')
    parser.add_argument('--batch_size', type=int, required=True,
                        help='batch size')
    parser.add_argument('--time_limit', type=int, required=True,
                        help='time limit')

    args = parser.parse_args()
    start_time = time.time()

    # Load files list
    fixed_list = load_file(args.files_list)

    # Initialize processor
    processor = PoseProcessor()

    # Process files in batches
    video_batches = [fixed_list[i:i + args.batch_size] for i in range(0, len(fixed_list), args.batch_size)]

    for pose_file in video_batches[args.index]:
        pose_file_path = Path(pose_file)
        output_path = Path(args.pose_features_path) / f"{pose_file_path.stem}.npy"

        if output_path.exists():
            print(f"Skipping {pose_file} - output already exists")

[The remainder of the body_features.py listing (358 lines in total) is cut off in this view.]
|
342 |
+
continue
|
343 |
+
|
344 |
+
current_time = time.time()
|
345 |
+
if current_time - start_time > args.time_limit:
|
346 |
+
print("Time limit reached. Stopping execution.")
|
347 |
+
break
|
348 |
+
|
349 |
+
try:
|
350 |
+
print(f"Processing {pose_file}")
|
351 |
+
keypoints_to_numpy(pose_file, args.pose_features_path)
|
352 |
+
print(f"Successfully processed {pose_file}")
|
353 |
+
except Exception as e:
|
354 |
+
print(f"Error processing {pose_file}: {e}")
|
355 |
+
|
356 |
+
|
357 |
+
if __name__ == "__main__":
|
358 |
+
main()
|
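
For readers driving this module from Python rather than the CLI, a minimal usage sketch follows; the JSON filename and the frame-indexed landmark layout are placeholder assumptions, not files shipped with this Space.

# Hypothetical usage sketch for body_features.py (paths are placeholders).
import json

from body_features import PoseProcessor

processor = PoseProcessor()  # default pose indices and normalization

# assumed: a JSON file mapping frame indices (string keys) to MediaPipe landmarks
with open("example_pose.json", "r") as f:
    landmarks = {int(k): v for k, v in json.load(f).items()}

pose_features = processor.process_landmarks_sequence(landmarks)
# one row per frame, two coordinates per selected keypoint (see the wrist-index comment above)
print(pose_features.shape)
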
crop_face.py
ADDED
@@ -0,0 +1,415 @@
import cv2
import numpy as np
import os
import pickle
import gzip
from datetime import datetime
from pathlib import Path
import decord
import argparse
import json
import time
from typing import Dict, Optional, Tuple, List, Union, Any


class FaceExtractor:
    """
    A class for extracting face regions from videos based on pose and face landmarks.
    Creates face frames with only eyes and mouth visible on grey background.
    """

    def __init__(self, output_size: Tuple[int, int] = (224, 224),
                 scale_factor: float = 1.2, grey_background_color: int = 128):
        """
        Initialize the FaceExtractor.

        Args:
            output_size: Size of the output face frames (width, height)
            scale_factor: Scale factor for bounding box expansion
            grey_background_color: Color value for grey background (0-255)
        """
        self.output_size = output_size
        self.scale_factor = scale_factor
        self.grey_background_color = grey_background_color

        # Face landmark indices for eyes and mouth
        self.left_eye_indices = [69, 168, 156, 118, 54]
        self.right_eye_indices = [168, 299, 347, 336, 301]
        self.mouth_indices = [164, 212, 432, 18]

    def resize_frame(self, frame: np.ndarray, frame_size: Tuple[int, int]) -> Optional[np.ndarray]:
        """Resize frame to specified size."""
        if frame is not None and frame.size > 0:
            return cv2.resize(frame, frame_size, interpolation=cv2.INTER_AREA)
        else:
            return None

    def calculate_bounding_box(self, landmarks: List[List[float]], indices: List[int],
                               image_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
        """Calculate bounding box for specific landmark indices."""
        x_coordinates = [landmarks[i][0] for i in indices]
        y_coordinates = [landmarks[i][1] for i in indices]

        left = min(x_coordinates)
        right = max(x_coordinates)
        top = min(y_coordinates)
        bottom = max(y_coordinates)

        return (int(left * image_shape[1]), int(top * image_shape[0]),
                int(right * image_shape[1]), int(bottom * image_shape[0]))

    def crop_and_paste(self, src: np.ndarray, dst: np.ndarray,
                       src_box: Tuple[int, int, int, int], dst_origin: Tuple[int, int]):
        """Crop region from source and paste to destination."""
        x1, y1, x2, y2 = src_box
        dx, dy = dst_origin
        crop = src[y1:y2, x1:x2]
        crop_height, crop_width = crop.shape[:2]
        dst[dy:dy+crop_height, dx:dx+crop_width] = crop

    def cues_on_grey_background(self, image: np.ndarray, facial_landmarks: List[List[float]]) -> np.ndarray:
        """
        Create face frame with only eyes and mouth visible on grey background.

        Args:
            image: Input image as numpy array
            facial_landmarks: Face landmarks from MediaPipe

        Returns:
            Face frame with eyes and mouth on grey background
        """
        image_shape = image.shape

        # Calculate bounding boxes for facial features
        left_eye_box = self.calculate_bounding_box(facial_landmarks, self.left_eye_indices, image_shape)
        right_eye_box = self.calculate_bounding_box(facial_landmarks, self.right_eye_indices, image_shape)
        mouth_box = self.calculate_bounding_box(facial_landmarks, self.mouth_indices, image_shape)

        # Calculate the overall bounding box
        min_x = min(left_eye_box[0], right_eye_box[0], mouth_box[0])
        min_y = min(left_eye_box[1], right_eye_box[1], mouth_box[1])
        max_x = max(left_eye_box[2], right_eye_box[2], mouth_box[2])
        max_y = max(left_eye_box[3], right_eye_box[3], mouth_box[3])

        # Add padding
        padding = 10
        min_x = max(0, min_x - padding)
        min_y = max(0, min_y - padding)
        max_x = min(image.shape[1], max_x + padding)
        max_y = min(image.shape[0], max_y + padding)

        # Make the crop a square by adjusting either width or height
        width = max_x - min_x
        height = max_y - min_y
        side_length = max(width, height)

        # Adjust to ensure square
        if width < side_length:
            extra = side_length - width
            min_x = max(0, min_x - extra // 2)
            max_x = min(image.shape[1], max_x + extra // 2)

        if height < side_length:
            extra = side_length - height
            min_y = max(0, min_y - extra // 2)
            max_y = min(image.shape[0], max_y + extra // 2)

        # Create grey background image
        grey_background = np.ones((side_length, side_length, 3), dtype=np.uint8) * self.grey_background_color

        # Crop and paste facial features onto grey background
        self.crop_and_paste(image, grey_background, left_eye_box, (left_eye_box[0]-min_x, left_eye_box[1]-min_y))
        self.crop_and_paste(image, grey_background, right_eye_box, (right_eye_box[0]-min_x, right_eye_box[1]-min_y))
        self.crop_and_paste(image, grey_background, mouth_box, (mouth_box[0]-min_x, mouth_box[1]-min_y))

        return grey_background

    def select_face(self, pose_landmarks: List[List[float]], face_landmarks: List[List[List[float]]]) -> List[List[float]]:
        """
        Select the face that is closest to the pose nose landmark.

        Args:
            pose_landmarks: Pose landmarks from MediaPipe
            face_landmarks: List of face landmarks from MediaPipe

        Returns:
            Selected face landmarks
        """
        nose_landmark_from_pose = pose_landmarks[0]  # Nose from pose
        nose_landmarks_from_face = [face_landmarks[i][0] for i in range(len(face_landmarks))]

        # Find closest face based on nose landmark
        distances = [np.linalg.norm(np.array(nose_landmark_from_pose) - np.array(nose_landmark))
                     for nose_landmark in nose_landmarks_from_face]
        closest_nose_index = np.argmin(distances)

        return face_landmarks[closest_nose_index]

    def extract_face_frames(self, video_input, landmarks_data: Dict[int, Any]) -> List[np.ndarray]:
        """
        Extract face frames from video based on landmarks.

        Args:
            video_input: Either a path to video file (str) or a decord.VideoReader object
            landmarks_data: Dictionary containing pose and face landmarks for each frame

        Returns:
            List of face frames as numpy arrays
        """
        # Handle different input types
        if isinstance(video_input, str):
            video_path = Path(video_input)
            if not video_path.exists():
                raise FileNotFoundError(f"Video file not found: {video_input}")
            video = decord.VideoReader(str(video_path))
        # elif hasattr(video_input, '__len__') and hasattr(video_input, '__getitem__'):
        else:
            video = video_input
        # else:
        #     raise TypeError("video_input must be either a file path (str) or a VideoReader object")

        face_frames = []
        prev_face_frame = None
        prev_landmarks = None

        for i in range(len(video)):
            # frame = video[i].asnumpy()
            frame = video[i]
            if hasattr(video, 'seek'):
                video.seek(0)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Get landmarks for this frame
            frame_landmarks = landmarks_data.get(i, None)

            # Handle missing landmarks
            if frame_landmarks is None:
                if prev_landmarks is not None:
                    frame_landmarks = prev_landmarks
                else:
                    # Use blank frame if no landmarks available
                    face_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
                    continue
            else:
                prev_landmarks = frame_landmarks

            # Check if pose landmarks exist
            if frame_landmarks.get('pose_landmarks') is None:
                if prev_face_frame is not None:
                    face_frames.append(prev_face_frame)
                else:
                    face_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
                continue

            # Process face if face landmarks exist
            if frame_landmarks.get('face_landmarks') is not None:
                # Select the face closest to the pose
                selected_face = self.select_face(
                    frame_landmarks['pose_landmarks'][0],
                    frame_landmarks['face_landmarks']
                )

                # Create face frame with cues on grey background
                face_frame = self.cues_on_grey_background(frame_rgb, selected_face)
                face_frame = self.resize_frame(face_frame, self.output_size)
                face_frames.append(face_frame)
                prev_face_frame = face_frame

            elif prev_face_frame is not None:
                face_frames.append(prev_face_frame)
            else:
                # Use blank frame if no face landmarks
                face_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))

        return face_frames

    def extract_and_save_face_video(self, video_input, landmarks_data: Dict[int, Any],
                                    output_dir: str, video_name: Optional[str] = None) -> str:
        """
        Extract face frames and save as video file.

        Args:
            video_input: Either a path to video file (str) or a decord.VideoReader object
            landmarks_data: Dictionary containing pose and face landmarks for each frame
            output_dir: Directory to save the face video
            video_name: Name for output video (auto-generated if not provided)

        Returns:
            Path to the saved face video
        """
        # Handle video input and get FPS
        if isinstance(video_input, str):
            video_path = Path(video_input)
            if not video_path.exists():
                raise FileNotFoundError(f"Video file not found: {video_input}")
            video = decord.VideoReader(str(video_path))
            if video_name is None:
                video_name = video_path.stem
        # elif hasattr(video_input, '__len__') and hasattr(video_input, '__getitem__'):
        else:
            video = video_input
            if video_name is None:
                video_name = "video"
        # else:
        #     raise TypeError("video_input must be either a file path (str) or a VideoReader object")

        fps = video.get_avg_fps() if hasattr(video, 'get_avg_fps') else 30.0

        # Create output directory
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Define output path
        face_video_path = output_path / f"{video_name}_face.mp4"

        # Remove existing file
        if face_video_path.exists():
            face_video_path.unlink()

        # Create video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(str(face_video_path), fourcc, fps, self.output_size)

        try:
            # Extract face frames
            face_frames = self.extract_face_frames(video, landmarks_data)

            # Write frames to video file
            for frame in face_frames:
                writer.write(frame)

        finally:
            # Clean up
            writer.release()
            del writer

        return str(face_video_path)


# Convenience function for backward compatibility
def extract_face_frames(video_input, landmarks_data: Dict[int, Any],
                        output_size: Tuple[int, int] = (224, 224)) -> List[np.ndarray]:
    """
    Convenience function to extract face frames from video.

    Args:
        video_input: Either a path to video file (str) or a decord.VideoReader object
        landmarks_data: Dictionary containing pose and face landmarks for each frame
        output_size: Size of the output face frames (width, height)

    Returns:
        List of face frames as numpy arrays
    """
    extractor = FaceExtractor(output_size=output_size)
    return extractor.extract_face_frames(video_input, landmarks_data)


def video_holistic(video_file: str, face_path: str, problem_file_path: str, pose_path: str):
    """
    Original function for backward compatibility with command-line usage.
    """
    try:
        video = decord.VideoReader(video_file)
        fps = video.get_avg_fps()

        video_name = Path(video_file).stem
        clip_face_path = Path(face_path) / f"{video_name}_face.mp4"
        landmark_json_path = Path(pose_path) / f"{video_name}_pose.json"

        # Load landmarks
        with open(landmark_json_path, 'r') as rd:
            landmarks_data = json.load(rd)

        # Convert string keys to integers
        landmarks_data = {int(k): v for k, v in landmarks_data.items()}

        # Extract face video
        extractor = FaceExtractor()
        extractor.extract_and_save_face_video(video, landmarks_data, face_path, video_name)

    except Exception as e:
        print(f"Error processing {video_file}: {e}")
        with open(problem_file_path, "a") as p:
            p.write(video_file + "\n")


# Utility functions for batch processing
def load_file(filename: str):
    """Load a pickled and gzipped file."""
    with gzip.open(filename, "rb") as f:
        return pickle.load(f)


def is_string_in_file(file_path: str, target_string: str) -> bool:
    """Check if a string exists in a file."""
    try:
        with Path(file_path).open("r") as f:
            for line in f:
                if target_string in line:
                    return True
            return False
    except Exception as e:
        print(f"Error: {e}")
        return False


def main():
    """Main function for command-line usage."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--index', type=int, required=True,
                        help='index of the sub_list to work with')
    parser.add_argument('--batch_size', type=int, required=True,
                        help='batch size')
    parser.add_argument('--time_limit', type=int, required=True,
                        help='time limit')
    parser.add_argument('--files_list', type=str, required=True,
                        help='files list')
    parser.add_argument('--problem_file_path', type=str, required=True,
                        help='problem file path')
    parser.add_argument('--pose_path', type=str, required=True,
                        help='pose path')
    parser.add_argument('--face_path', type=str, required=True,
                        help='face path')

    args = parser.parse_args()
    start_time = time.time()

    # Load files list
    fixed_list = load_file(args.files_list)

    # Create problem file if it doesn't exist
    if not os.path.exists(args.problem_file_path):
        with open(args.problem_file_path, "w") as f:
            f.write("")

    # Process videos in batches
    video_batches = [fixed_list[i:i + args.batch_size] for i in range(0, len(fixed_list), args.batch_size)]

    for video_file in video_batches[args.index]:
        current_time = time.time()
        if current_time - start_time > args.time_limit:
            print("Time limit reached. Stopping execution.")
            break

        video_name = Path(video_file).stem
        clip_face_path = Path(args.face_path) / f"{video_name}_face.mp4"

        if clip_face_path.exists():
            print(f"Skipping {video_file} - output already exists")
            continue
        elif is_string_in_file(args.problem_file_path, video_file):
            print(f"Skipping {video_file} - found in problem file")
            continue
        else:
            try:
                print(f"Processing {video_file}")
                video_holistic(video_file, args.face_path, args.problem_file_path, args.pose_path)
                print(f"Successfully processed {video_file}")
            except Exception as e:
                print(f"Error processing {video_file}: {e}")
                with open(args.problem_file_path, "a") as p:
                    p.write(video_file + "\n")


if __name__ == "__main__":
    main()
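
A minimal usage sketch of the FaceExtractor API above, passing frames as a numpy array the way features.py does; file and directory names are placeholders, not assets of this Space.

# Hypothetical usage sketch for crop_face.py (file names are placeholders).
import json

import decord

from crop_face import FaceExtractor

vr = decord.VideoReader("example.mp4")
frames = vr.get_batch(list(range(len(vr)))).asnumpy()  # frames as a numpy array, as in features.py

with open("example_pose.json", "r") as f:
    landmarks = {int(k): v for k, v in json.load(f).items()}

extractor = FaceExtractor(output_size=(224, 224))
face_video = extractor.extract_and_save_face_video(frames, landmarks, "face_crops", "example")
print(face_video)  # face_crops/example_face.mp4
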
crop_hands.py
ADDED
@@ -0,0 +1,445 @@
import cv2
import numpy as np
import os
import pickle
import gzip
from datetime import datetime
from pathlib import Path
import decord
import argparse
import json
import time
from typing import Dict, Optional, Tuple, List, Union, Any
import tempfile


class HandExtractor:
    """
    A class for extracting hand regions from videos based on pose landmarks.
    """

    def __init__(self, output_size: Tuple[int, int] = (224, 224),
                 scale_factor: float = 1.5, distance_threshold: float = 0.1):
        """
        Initialize the HandExtractor.

        Args:
            output_size: Size of the output hand frames (width, height)
            scale_factor: Scale factor for bounding box expansion
            distance_threshold: Distance threshold for hand-pose matching
        """
        self.output_size = output_size
        self.scale_factor = scale_factor
        self.distance_threshold = distance_threshold

    def resize_frame(self, frame: np.ndarray, frame_size: Tuple[int, int]) -> Optional[np.ndarray]:
        """Resize frame to specified size."""
        if frame is not None and frame.size > 0:
            return cv2.resize(frame, frame_size, interpolation=cv2.INTER_AREA)
        else:
            return None

    def crop_frame(self, image: np.ndarray, bounding_box: Tuple[int, int, int, int]) -> np.ndarray:
        """Crop frame using bounding box."""
        x, y, w, h = bounding_box
        cropped_frame = image[y:y + h, x:x + w]
        return cropped_frame

    def get_bounding_box(self, landmarks: List[List[float]], image_shape: Tuple[int, int, int],
                         scale_factor: float = 1.2) -> Tuple[int, int, int, int]:
        """Get bounding box from landmarks."""
        ih, iw, _ = image_shape
        landmarks_px = np.array([(int(l[0] * iw), int(l[1] * ih)) for l in landmarks])
        center_x, center_y = np.mean(landmarks_px, axis=0, dtype=int)
        xb, yb, wb, hb = cv2.boundingRect(landmarks_px)
        box_size = max(wb, hb)
        half_size = box_size // 2
        x = center_x - half_size
        y = center_y - half_size
        w = box_size
        h = box_size

        w_padding = int((scale_factor - 1) * w / 2)
        h_padding = int((scale_factor - 1) * h / 2)
        x -= w_padding
        y -= h_padding
        w += 2 * w_padding
        h += 2 * h_padding

        return x, y, w, h

    def adjust_bounding_box(self, bounding_box: Tuple[int, int, int, int],
                            image_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
        """Adjust bounding box to fit within image boundaries."""
        x, y, w, h = bounding_box
        ih, iw, _ = image_shape

        # Adjust x-coordinate if the bounding box extends beyond the image's right edge
        if x + w > iw:
            x = iw - w

        # Adjust y-coordinate if the bounding box extends beyond the image's bottom edge
        if y + h > ih:
            y = ih - h

        # Ensure bounding box's x and y coordinates are not negative
        x = max(x, 0)
        y = max(y, 0)

        return x, y, w, h

    def select_hands(self, pose_landmarks: List[List[float]], hand_landmarks: Optional[List[List[List[float]]]],
                     image_shape: Tuple[int, int, int]) -> Tuple[Optional[List[List[float]]], Optional[List[List[float]]]]:
        """
        Select left and right hands from detected hand landmarks based on pose wrist positions.

        Args:
            pose_landmarks: Pose landmarks from MediaPipe
            hand_landmarks: Hand landmarks from MediaPipe
            image_shape: Shape of the image (height, width, channels)

        Returns:
            Tuple of (left_hand_landmarks, right_hand_landmarks)
        """
        if hand_landmarks is None:
            return None, None

        # Get wrist landmarks from pose (indices 15 and 16 for left and right wrists)
        left_wrist_from_pose = pose_landmarks[15]
        right_wrist_from_pose = pose_landmarks[16]

        # Get wrist landmarks from hand detections (index 0 is wrist in hand landmarks)
        wrist_from_hand = [hand_landmarks[i][0] for i in range(len(hand_landmarks))]

        # Match right hand
        right_hand_landmarks = None
        if right_wrist_from_pose is not None:
            minimum_distance = 100
            best_hand_idx = 0
            for i in range(len(hand_landmarks)):
                distance = np.linalg.norm(np.array(right_wrist_from_pose[0:2]) - np.array(wrist_from_hand[i][0:2]))
                if distance < minimum_distance:
                    minimum_distance = distance
                    best_hand_idx = i

            if minimum_distance < self.distance_threshold:
                right_hand_landmarks = hand_landmarks[best_hand_idx]

        # Match left hand
        left_hand_landmarks = None
        if left_wrist_from_pose is not None:
            minimum_distance = 100
            best_hand_idx = 0
            for i in range(len(hand_landmarks)):
                distance = np.linalg.norm(np.array(left_wrist_from_pose[0:2]) - np.array(wrist_from_hand[i][0:2]))
                if distance < minimum_distance:
                    minimum_distance = distance
                    best_hand_idx = i

            if minimum_distance < self.distance_threshold:
                left_hand_landmarks = hand_landmarks[best_hand_idx]

        return left_hand_landmarks, right_hand_landmarks

    def extract_hand_frames(self, video_input, landmarks_data: Dict[int, Any]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Extract hand frames from video based on landmarks.

        Args:
            video_input: Either a path to video file (str) or a decord.VideoReader object
            landmarks_data: Dictionary containing pose and hand landmarks for each frame

        Returns:
            Tuple of (left_hand_frames, right_hand_frames) as lists of numpy arrays
        """
        # Handle different input types
        if isinstance(video_input, str):
            video_path = Path(video_input)
            if not video_path.exists():
                raise FileNotFoundError(f"Video file not found: {video_input}")
            video = decord.VideoReader(str(video_path))
        # elif hasattr(video_input, '__len__') and hasattr(video_input, '__getitem__'):
        else:
            video = video_input
        # else:
        #     raise TypeError("video_input must be either a file path (str) or a VideoReader object")

        left_hand_frames = []
        right_hand_frames = []

        prev_left_frame = None
        prev_right_frame = None
        prev_landmarks = None

        for i in range(len(video)):
            # frame = video[i].asnumpy()
            frame = video[i]
            if hasattr(video, 'seek'):
                video.seek(0)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Get landmarks for this frame
            frame_landmarks = landmarks_data.get(i, None)

            # Handle missing landmarks
            if frame_landmarks is None:
                if prev_landmarks is not None:
                    frame_landmarks = prev_landmarks
                else:
                    # Use blank frames if no landmarks available
                    left_hand_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
                    right_hand_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
                    continue
            else:
                prev_landmarks = frame_landmarks

            # Check if pose landmarks exist
            if frame_landmarks.get('pose_landmarks') is None:
                # Use previous frames or blank frames
                if prev_left_frame is not None:
                    left_hand_frames.append(prev_left_frame)
                else:
                    left_hand_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))

                if prev_right_frame is not None:
                    right_hand_frames.append(prev_right_frame)
                else:
                    right_hand_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
                continue

            # Select hands based on pose landmarks
            left_hand_landmarks, right_hand_landmarks = self.select_hands(
                frame_landmarks['pose_landmarks'][0],
                frame_landmarks.get('hand_landmarks'),
                frame_rgb.shape
            )

            # Process left hand
            if left_hand_landmarks is not None:
                left_box = self.get_bounding_box(left_hand_landmarks, frame_rgb.shape, self.scale_factor)
                left_box = self.adjust_bounding_box(left_box, frame_rgb.shape)
                left_frame = self.crop_frame(frame_rgb, left_box)
                left_frame = self.resize_frame(left_frame, self.output_size)
                left_hand_frames.append(left_frame)
                prev_left_frame = left_frame
            elif prev_left_frame is not None:
                left_hand_frames.append(prev_left_frame)
            else:
                left_hand_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))

            # Process right hand
            if right_hand_landmarks is not None:
                right_box = self.get_bounding_box(right_hand_landmarks, frame_rgb.shape, self.scale_factor)
                right_box = self.adjust_bounding_box(right_box, frame_rgb.shape)
                right_frame = self.crop_frame(frame_rgb, right_box)
                right_frame = self.resize_frame(right_frame, self.output_size)
                right_hand_frames.append(right_frame)
                prev_right_frame = right_frame
            elif prev_right_frame is not None:
                right_hand_frames.append(prev_right_frame)
            else:
                right_hand_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))

        return left_hand_frames, right_hand_frames

    def extract_and_save_hand_videos(self, video_input, landmarks_data: Dict[int, Any],
                                     output_dir: str, video_name: Optional[str] = None) -> Tuple[str, str]:
        """
        Extract hand frames and save as video files.

        Args:
            video_input: Either a path to video file (str) or a decord.VideoReader object
            landmarks_data: Dictionary containing pose and hand landmarks for each frame
            output_dir: Directory to save the hand videos
            video_name: Name for output videos (auto-generated if not provided)

        Returns:
            Tuple of (left_hand_video_path, right_hand_video_path)
        """
        # Handle video input and get FPS
        if isinstance(video_input, str):
            video_path = Path(video_input)
            if not video_path.exists():
                raise FileNotFoundError(f"Video file not found: {video_input}")
            video = decord.VideoReader(str(video_path))
            if video_name is None:
                video_name = video_path.stem
        # elif hasattr(video_input, '__len__') and hasattr(video_input, '__getitem__'):
        else:
            video = video_input
            if video_name is None:
                video_name = "video"
        # else:
        #     raise TypeError("video_input must be either a file path (str) or a VideoReader object")

        fps = video.get_avg_fps() if hasattr(video, 'get_avg_fps') else 30.0

        # Create output directory
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Define output paths
        left_hand_path = output_path / f"{video_name}_hand1.mp4"
        right_hand_path = output_path / f"{video_name}_hand2.mp4"

        # Remove existing files
        if left_hand_path.exists():
            left_hand_path.unlink()
        if right_hand_path.exists():
            right_hand_path.unlink()

        # Create video writers
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        left_writer = cv2.VideoWriter(str(left_hand_path), fourcc, fps, self.output_size)
        right_writer = cv2.VideoWriter(str(right_hand_path), fourcc, fps, self.output_size)

        try:
            # Extract hand frames
            left_frames, right_frames = self.extract_hand_frames(video, landmarks_data)

            # Write frames to video files
            for left_frame, right_frame in zip(left_frames, right_frames):
                left_writer.write(left_frame)
                right_writer.write(right_frame)

        finally:
            # Clean up
            left_writer.release()
            right_writer.release()
            del left_writer
            del right_writer

        return str(left_hand_path), str(right_hand_path)


# Convenience function for backward compatibility
def extract_hand_frames(video_input, landmarks_data: Dict[int, Any],
                        output_size: Tuple[int, int] = (224, 224)) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """
    Convenience function to extract hand frames from video.

    Args:
        video_input: Either a path to video file (str) or a decord.VideoReader object
        landmarks_data: Dictionary containing pose and hand landmarks for each frame
        output_size: Size of the output hand frames (width, height)

    Returns:
        Tuple of (left_hand_frames, right_hand_frames) as lists of numpy arrays
    """
    extractor = HandExtractor(output_size=output_size)
    return extractor.extract_hand_frames(video_input, landmarks_data)


def video_holistic(video_file: str, hand_path: str, problem_file_path: str, pose_path: str):
    """
    Original function for backward compatibility with command-line usage.
    """
    try:
        video = decord.VideoReader(video_file)
        fps = video.get_avg_fps()

        video_name = Path(video_file).stem
        clip_hand1_path = Path(hand_path) / f"{video_name}_hand1.mp4"
        clip_hand2_path = Path(hand_path) / f"{video_name}_hand2.mp4"
        landmark_json_path = Path(pose_path) / f"{video_name}_pose.json"

        # Load landmarks
        with open(landmark_json_path, 'r') as rd:
            landmarks_data = json.load(rd)

        # Convert string keys to integers
        landmarks_data = {int(k): v for k, v in landmarks_data.items()}

        # Extract hand videos
        extractor = HandExtractor()
        extractor.extract_and_save_hand_videos(video, landmarks_data, hand_path, video_name)

    except Exception as e:
        print(f"Error processing {video_file}: {e}")
        with open(problem_file_path, "a") as p:
            p.write(video_file + "\n")


# Utility functions for batch processing
def load_file(filename: str):
    """Load a pickled and gzipped file."""
    with gzip.open(filename, "rb") as f:
        return pickle.load(f)


def is_string_in_file(file_path: str, target_string: str) -> bool:
    """Check if a string exists in a file."""
    try:
        with Path(file_path).open("r") as f:
            for line in f:
                if target_string in line:
                    return True
            return False
    except Exception as e:
        print(f"Error: {e}")
        return False


def main():
    """Main function for command-line usage."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--index', type=int, required=True,
                        help='index of the sub_list to work with')
    parser.add_argument('--batch_size', type=int, required=True,
                        help='batch size')
    parser.add_argument('--time_limit', type=int, required=True,
                        help='time limit')
    parser.add_argument('--files_list', type=str, required=True,
                        help='files list')
    parser.add_argument('--problem_file_path', type=str, required=True,
                        help='problem file path')
    parser.add_argument('--pose_path', type=str, required=True,
                        help='pose path')
    parser.add_argument('--hand_path', type=str, required=True,
                        help='hand path')

    args = parser.parse_args()
    start_time = time.time()

    # Create directories if they do not exist
    Path(args.hand_path).mkdir(parents=True, exist_ok=True)

    # Load files list
    fixed_list = load_file(args.files_list)

    # Create problem file if it doesn't exist
    if not os.path.exists(args.problem_file_path):
        with open(args.problem_file_path, "w") as f:
            f.write("")

    # Process videos in batches
    video_batches = [fixed_list[i:i + args.batch_size] for i in range(0, len(fixed_list), args.batch_size)]

    for video_file in video_batches[args.index]:
        current_time = time.time()
        if current_time - start_time > args.time_limit:
            print("Time limit reached. Stopping execution.")
            break

        video_name = Path(video_file).stem
        clip_hand2_path = Path(args.hand_path) / f"{video_name}_hand2.mp4"

        if clip_hand2_path.exists():
            print(f"Skipping {video_file} - output already exists")
            continue
        elif is_string_in_file(args.problem_file_path, video_file):
            print(f"Skipping {video_file} - found in problem file")
            continue
        else:
            try:
                print(f"Processing {video_file}")
                video_holistic(video_file, args.hand_path, args.problem_file_path, args.pose_path)
                print(f"Successfully processed {video_file}")
            except Exception as e:
                print(f"Error processing {video_file}: {e}")
                with open(args.problem_file_path, "a") as p:
                    p.write(video_file + "\n")


if __name__ == "__main__":
    main()
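
The same pattern applies to the HandExtractor API above; a minimal sketch follows, again with placeholder file names and frames passed as a numpy array as in features.py.

# Hypothetical usage sketch for crop_hands.py (file names are placeholders).
import json

import decord

from crop_hands import HandExtractor

vr = decord.VideoReader("example.mp4")
frames = vr.get_batch(list(range(len(vr)))).asnumpy()

with open("example_pose.json", "r") as f:
    landmarks = {int(k): v for k, v in json.load(f).items()}

extractor = HandExtractor(output_size=(224, 224), scale_factor=1.5)
left_path, right_path = extractor.extract_and_save_hand_videos(frames, landmarks, "hand_crops", "example")
print(left_path, right_path)  # hand_crops/example_hand1.mp4 hand_crops/example_hand2.mp4
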
dinov2_features.py
ADDED
@@ -0,0 +1,351 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torchvision import transforms
|
4 |
+
from PIL import Image
|
5 |
+
import decord
|
6 |
+
from decord import VideoReader
|
7 |
+
from decord import cpu, gpu
|
8 |
+
import numpy as np
|
9 |
+
import os
|
10 |
+
import pickle
|
11 |
+
import gzip
|
12 |
+
from pathlib import Path
|
13 |
+
import argparse
|
14 |
+
import json
|
15 |
+
import csv
|
16 |
+
import glob
|
17 |
+
import time
|
18 |
+
from typing import List, Union, Optional, Tuple
|
19 |
+
|
20 |
+
|
21 |
+
class DINOEmbedder:
|
22 |
+
"""
|
23 |
+
A class for extracting DINOv2 embeddings from video frames or images.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, dino_model_path: str, batch_size: int = 128, device: Optional[str] = None):
|
27 |
+
"""
|
28 |
+
Initialize the DINOEmbedder.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
dino_model_path: Path to the fine-tuned DINOv2 model
|
32 |
+
batch_size: Batch size for processing frames
|
33 |
+
device: Device to use ('cuda' or 'cpu'). Auto-detected if None
|
34 |
+
"""
|
35 |
+
self.dino_model_path = dino_model_path
|
36 |
+
self.batch_size = batch_size
|
37 |
+
self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
38 |
+
|
39 |
+
# Initialize model
|
40 |
+
self.model = self._load_dino_model()
|
41 |
+
self.model.eval()
|
42 |
+
|
43 |
+
# Initialize transform
|
44 |
+
self.transform = transforms.Compose([
|
45 |
+
transforms.Resize((224, 224)),
|
46 |
+
transforms.ToTensor(),
|
47 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
48 |
+
])
|
49 |
+
|
50 |
+
print(f"DINOEmbedder initialized on device: {self.device}")
|
51 |
+
|
52 |
+
def _load_dino_model(self) -> nn.Module:
|
53 |
+
"""Load the fine-tuned DINOv2 model."""
|
54 |
+
# Load the original DINOv2 model with the correct architecture
|
55 |
+
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg', pretrained=False)
|
56 |
+
|
57 |
+
# Load fine-tuned weights
|
58 |
+
pretrained = torch.load(self.dino_model_path, map_location=self.device)
|
59 |
+
|
60 |
+
# Make correct state dict for loading
|
61 |
+
new_state_dict = {}
|
62 |
+
for key, value in pretrained['teacher'].items():
|
63 |
+
if 'dino_head' in key:
|
64 |
+
continue # Skip dino_head layers
|
65 |
+
else:
|
66 |
+
new_key = key.replace('backbone.', '')
|
67 |
+
new_state_dict[new_key] = value
|
68 |
+
|
69 |
+
# Change shape of pos_embed
|
70 |
+
pos_embed = nn.Parameter(torch.zeros(1, 257, 384))
|
71 |
+
model.pos_embed = pos_embed
|
72 |
+
|
73 |
+
# Load state dict
|
74 |
+
model.load_state_dict(new_state_dict, strict=True)
|
75 |
+
|
76 |
+
# Move model to device
|
77 |
+
model.to(self.device)
|
78 |
+
return model
|
79 |
+
|
80 |
+
def _preprocess_frame(self, frame: np.ndarray) -> torch.Tensor:
|
81 |
+
"""Preprocess a single frame."""
|
82 |
+
if isinstance(frame, np.ndarray):
|
83 |
+
image = Image.fromarray(frame)
|
84 |
+
else:
|
85 |
+
image = frame
|
86 |
+
|
87 |
+
tensor = self.transform(image)
|
88 |
+
# Ensure only RGB channels are considered
|
89 |
+
return tensor[:3]
|
90 |
+
|
91 |
+
def _preprocess_frames_batch(self, frames: List[np.ndarray]) -> torch.Tensor:
|
92 |
+
"""Preprocess a batch of frames."""
|
93 |
+
batch_tensors = torch.stack([self._preprocess_frame(frame) for frame in frames])
|
94 |
+
return batch_tensors.to(self.device)
|
95 |
+
|
96 |
+
def extract_embeddings_from_frames(self, frames: List[np.ndarray]) -> np.ndarray:
|
97 |
+
"""
|
98 |
+
Extract DINOv2 embeddings from a list of frames.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
frames: List of frames as numpy arrays
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
Numpy array of embeddings with shape (num_frames, embedding_dim)
|
105 |
+
"""
|
106 |
+
all_embeddings = []
|
107 |
+
|
108 |
+
# Process frames in batches
|
109 |
+
for idx in range(0, len(frames), self.batch_size):
|
110 |
+
batch_frames = frames[idx:idx + self.batch_size]
|
111 |
+
|
112 |
+
# Preprocess batch
|
113 |
+
batch_tensors = self._preprocess_frames_batch(batch_frames)
|
114 |
+
|
115 |
+
# Extract embeddings
|
116 |
+
with torch.no_grad():
|
117 |
+
batch_embeddings = self.model(batch_tensors).cpu().numpy()
|
118 |
+
|
119 |
+
all_embeddings.append(batch_embeddings)
|
120 |
+
|
121 |
+
# Concatenate all embeddings
|
122 |
+
embeddings = np.concatenate(all_embeddings, axis=0)
|
123 |
+
return embeddings
|
124 |
+
|
125 |
+
def extract_embeddings_from_video(self, video_input: Union[str, VideoReader],
|
126 |
+
target_size: Tuple[int, int] = (224, 224)) -> np.ndarray:
|
127 |
+
"""
|
128 |
+
Extract DINOv2 embeddings from a video.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
video_input: Either a path to video file (str) or a VideoReader object
|
132 |
+
target_size: Target size for video frames (width, height)
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
Numpy array of embeddings with shape (num_frames, embedding_dim)
|
136 |
+
"""
|
137 |
+
# Handle different input types
|
138 |
+
if isinstance(video_input, str):
|
139 |
+
video_path = Path(video_input)
|
140 |
+
if not video_path.exists():
|
141 |
+
raise FileNotFoundError(f"Video file not found: {video_input}")
|
142 |
+
try:
|
143 |
+
vr = VideoReader(str(video_path), width=target_size[0], height=target_size[1])
|
144 |
+
except Exception as e:
|
145 |
+
raise RuntimeError(f"Error loading video {video_input}: {e}")
|
146 |
+
# elif hasattr(video_input, 'get_batch'):
|
147 |
+
else:
|
148 |
+
vr = video_input
|
149 |
+
# else:
|
150 |
+
# raise TypeError("video_input must be either a file path (str) or a VideoReader object")
|
151 |
+
|
152 |
+
total_frames = len(vr)
|
153 |
+
all_embeddings = []
|
154 |
+
|
155 |
+
# Process video in batches
|
156 |
+
for idx in range(0, total_frames, self.batch_size):
|
157 |
+
batch_indices = range(idx, min(idx + self.batch_size, total_frames))
|
158 |
+
# batch_frames = vr.get_batch(batch_indices).asnumpy()
|
159 |
+
batch_frames = vr[batch_indices]
|
160 |
+
|
161 |
+
# Preprocess batch
|
162 |
+
batch_tensors = self._preprocess_frames_batch(batch_frames)
|
163 |
+
|
164 |
+
# Extract embeddings
|
165 |
+
with torch.no_grad():
|
166 |
+
batch_embeddings = self.model(batch_tensors).cpu().numpy()
|
167 |
+
|
168 |
+
all_embeddings.append(batch_embeddings)
|
169 |
+
|
170 |
+
# Concatenate all embeddings
|
171 |
+
embeddings = np.concatenate(all_embeddings, axis=0)
|
172 |
+
return embeddings
|
173 |
+
|
174 |
+
def extract_embeddings_from_video_and_save(self, video_path: str, output_folder: str) -> str:
|
175 |
+
"""
|
176 |
+
Extract embeddings from video and save to file.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
video_path: Path to the video file
|
180 |
+
output_folder: Folder to save the embeddings
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
Path to the saved embeddings file
|
184 |
+
"""
|
185 |
+
# Create output folder if it doesn't exist
|
186 |
+
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
187 |
+
|
188 |
+
# Extract embeddings
|
189 |
+
embeddings = self.extract_embeddings_from_video(video_path)
|
190 |
+
|
191 |
+
# Save embeddings
|
192 |
+
video_name = Path(video_path).stem
|
193 |
+
np_path = Path(output_folder) / f"{video_name}.npy"
|
194 |
+
np.save(np_path, embeddings)
|
195 |
+
|
196 |
+
return str(np_path)
|
197 |
+
|
198 |
+
def extract_embedding_from_single_image(self, image: Union[np.ndarray, Image.Image]) -> np.ndarray:
|
199 |
+
"""
|
200 |
+
Extract DINOv2 embedding from a single image.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
image: Image as numpy array or PIL Image
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
Numpy array of embedding with shape (1, embedding_dim)
|
207 |
+
"""
|
208 |
+
# Preprocess image
|
209 |
+
if isinstance(image, np.ndarray):
|
210 |
+
image = Image.fromarray(image)
|
211 |
+
|
212 |
+
tensor = self.transform(image).unsqueeze(0).to(self.device)
|
213 |
+
|
214 |
+
# Extract embedding
|
215 |
+
with torch.no_grad():
|
216 |
+
embedding = self.model(tensor).cpu().numpy()
|
217 |
+
|
218 |
+
return embedding
|
219 |
+
|
220 |
+
|
221 |
+
# Convenience functions for backward compatibility
|
222 |
+
def extract_embeddings_from_frames(frames: List[np.ndarray], dino_model_path: str,
|
223 |
+
batch_size: int = 128) -> np.ndarray:
|
224 |
+
"""
|
225 |
+
Convenience function to extract embeddings from frames.
|
226 |
+
|
227 |
+
Args:
|
228 |
+
frames: List of frames as numpy arrays
|
229 |
+
dino_model_path: Path to the fine-tuned DINOv2 model
|
230 |
+
batch_size: Batch size for processing
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
Numpy array of embeddings
|
234 |
+
"""
|
235 |
+
embedder = DINOEmbedder(dino_model_path, batch_size)
|
236 |
+
return embedder.extract_embeddings_from_frames(frames)
|
237 |
+
|
238 |
+
|
239 |
+
def extract_embeddings_from_video(video_path: str, dino_model_path: str,
|
240 |
+
batch_size: int = 128) -> np.ndarray:
|
241 |
+
"""
|
242 |
+
Convenience function to extract embeddings from video.
|
243 |
+
|
244 |
+
Args:
|
245 |
+
video_path: Path to the video file
|
246 |
+
        dino_model_path: Path to the fine-tuned DINOv2 model
        batch_size: Batch size for processing

    Returns:
        Numpy array of embeddings
    """
    embedder = DINOEmbedder(dino_model_path, batch_size)
    return embedder.extract_embeddings_from_video(video_path)


def video_to_embeddings(video_path: str, output_folder: str, dino_path: str, batch_size: int = 128):
    """
    Original function for backward compatibility with command-line usage.
    """
    try:
        embedder = DINOEmbedder(dino_path, batch_size)
        embedder.extract_embeddings_from_video_and_save(video_path, output_folder)
    except Exception as e:
        print(f'Error processing {video_path}: {e}')


# Utility functions for batch processing
def get_mp4_files(directory: str) -> List[str]:
    """Get all MP4 files in a directory."""
    if not os.path.exists(directory):
        raise FileNotFoundError(f'Directory not found: {directory}')

    mp4_files = glob.glob(os.path.join(directory, '*.mp4'))
    return [os.path.abspath(file) for file in mp4_files]


def load_file(filename: str):
    """Load a pickled and gzipped file."""
    with gzip.open(filename, "rb") as f:
        return pickle.load(f)


def is_string_in_file(file_path: str, target_string: str) -> bool:
    """Check if a string exists in a file."""
    try:
        with Path(file_path).open("r") as f:
            for line in f:
                if target_string in line:
                    return True
        return False
    except Exception as e:
        print(f"Error: {e}")
        return False


def main():
    """Main function for command-line usage."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--index', type=int, required=True,
                        help='index of the sub_list to work with')
    parser.add_argument('--time_limit', type=int, required=True,
                        help='time limit in seconds')
    parser.add_argument('--batch_size', type=int, required=True,
                        help='number of videos to process in this batch')
    parser.add_argument('--files_list', type=str, required=True,
                        help='path to the files list file')
    parser.add_argument('--output_folder', type=str, required=True,
                        help='path to the output folder')
    parser.add_argument('--dino_path', type=str, required=True,
                        help='path to the dino model')

    args = parser.parse_args()
    start_time = time.time()

    # Load files list
    fixed_list = load_file(args.files_list)

    # Create output folder if it doesn't exist
    if not os.path.exists(args.output_folder):
        os.makedirs(args.output_folder)

    # Initialize embedder
    embedder = DINOEmbedder(args.dino_path, batch_size=512)

    # Process videos in batches
    video_batches = [fixed_list[i:i + args.batch_size] for i in range(0, len(fixed_list), args.batch_size)]
    print(f"Total number of video batches: {len(video_batches)}")

    for video_path in video_batches[args.index]:
        current_time = time.time()
        if current_time - start_time > args.time_limit:
            print("Time limit reached. Stopping execution.")
            break

        video_name = Path(video_path).stem
        np_path = Path(args.output_folder) / f"{video_name}.npy"

        if np_path.exists():
            print(f"Skipping {video_path} - output already exists")
            continue
        else:
            try:
                print(f"Processing {video_path}")
                embedder.extract_embeddings_from_video_and_save(video_path, args.output_folder)
                print(f"Successfully processed {video_path}")
            except Exception as e:
                print(f"Error processing {video_path}: {e}")


if __name__ == "__main__":
    main()
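For orientation, a minimal usage sketch of the `video_to_embeddings` wrapper defined above. The clip and checkpoint paths are placeholders, and `DINOEmbedder` is assumed to be the class defined earlier in this file; the `.npy` naming follows the check performed in `main()`.

# Hypothetical paths - substitute a real clip and a fine-tuned DINOv2 checkpoint.
from dinov2_features import video_to_embeddings

video_to_embeddings(
    video_path="recordings/example_clip.mp4",
    output_folder="embeddings_out",
    dino_path="models/dinov2hand.pth",
    batch_size=128,
)
# Per the skip logic in main(), the output is expected at embeddings_out/example_clip.npy.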
features.py
ADDED
@@ -0,0 +1,115 @@
import os
import torch
import numpy as np
import decord
import torch.nn as nn
import json
import cv2
from kpe_mediapipe import video_holistic
from crop_hands import HandExtractor
from crop_face import FaceExtractor
from dinov2_features import extract_embeddings_from_frames
from body_features import process_pose_landmarks
# from shubert import SignHubertModel, SignHubertConfig
from inference import test
import subprocess


class SHuBERTProcessor:

    def __init__(self, config):
        self.config = config

    def process_video(self, video_path):

        # output_file = f"{output_path}/{os.path.basename(video_file)}"

        # # Target FPS is 12.5
        # cmd = [
        #     'ffmpeg',
        #     '-i', video_path,
        #     '-filter:v', 'fps=15',
        #     '-c:v', 'libx264',
        #     '-preset', 'medium',  # Balance between speed and quality
        #     '-crf', '23',  # Quality level (lower is better)
        #     '-y',  # Overwrite output file if it exists
        #     video_path
        # ]

        # try:
        #     subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        #     print(f"Saved to {video_path} at 15 fps")
        # except subprocess.CalledProcessError as e:
        #     print(f"Error reading video {video_path}: {e}")

        # Step 1: Subsample frames to approximately the target fps
        signer_video = decord.VideoReader(video_path)

        signer_video_fps = signer_video.get_avg_fps()
        target_fps = 12
        stride = max(1, int(round(signer_video_fps / target_fps)))
        index_list = list(range(0, len(signer_video), stride))
        signer_video = signer_video.get_batch(index_list)
        signer_video = signer_video.asnumpy()

        # Step 2: Extract pose using kpe_mediapipe
        landmarks = video_holistic(
            video_input=signer_video,
            face_model_path=self.config['mediapipe_face_model_path'],
            hand_model_path=self.config['mediapipe_hands_model_path'],
        )

        # Step 3: Extract stream features
        hand_extractor = HandExtractor()
        left_hand_frames, right_hand_frames = hand_extractor.extract_hand_frames(signer_video, landmarks)
        left_hand_embeddings = extract_embeddings_from_frames(left_hand_frames, self.config['dino_hands_model_path'])
        right_hand_embeddings = extract_embeddings_from_frames(right_hand_frames, self.config['dino_hands_model_path'])
        del left_hand_frames, right_hand_frames

        face_extractor = FaceExtractor()
        face_frames = face_extractor.extract_face_frames(signer_video, landmarks)
        face_embeddings = extract_embeddings_from_frames(face_frames, self.config['dino_face_model_path'])
        del face_frames, signer_video

        pose_embeddings = process_pose_landmarks(landmarks)
        del landmarks

        output_text = test(face_embeddings,
                           left_hand_embeddings,
                           right_hand_embeddings,
                           pose_embeddings,
                           self.config['slt_model_config'],
                           self.config['slt_model_checkpoint'],
                           self.config['slt_tokenizer_checkpoint'],
                           self.config['temp_dir'])

        return output_text

if __name__ == "__main__":
    config = {
        'yolov8_model_path': '/share/data/pals/shester/inference/models/yolov8n.pt',
        'dino_face_model_path': '/share/data/pals/shester/inference/models/dinov2face.pth',
        'dino_hands_model_path': '/share/data/pals/shester/inference/models/dinov2hand.pth',
        'mediapipe_face_model_path': '/share/data/pals/shester/inference/models/face_landmarker_v2_with_blendshapes.task',
        'mediapipe_hands_model_path': '/share/data/pals/shester/inference/models/hand_landmarker.task',
        'shubert_model_path': '/share/data/pals/shester/inference/models/checkpoint_836_400000.pt',
        'temp_dir': '/share/data/pals/shester/inference',
        'slt_model_config': '/share/data/pals/shester/inference/models/byt5_base/config.json',
        'slt_model_checkpoint': '/share/data/pals/shester/inference/models/checkpoint-11625',
        'slt_tokenizer_checkpoint': '/share/data/pals/shester/inference/models/byt5_base',
    }

    # input_clip = "/share/data/pals/shester/datasets/openasl/clips_bbox/J-0KHhPS_m4.029676-029733.mp4"
    # input_clip = "/share/data/pals/shester/inference/recordings/sabrin30fps.mp4"
    input_clip = "/share/data/pals/shester/inference/recordings/sabrina30fps.mp4"
    processor = SHuBERTProcessor(config)
    output_text = processor.process_video(input_clip)
    print(f"The English translation is: {output_text}")

# /home-nfs/shesterg/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/attention.py
# /home-nfs/shesterg/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/block.py
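As a concrete check on the subsampling in Step 1 of `process_video`, a small sketch with plain arithmetic (no decord; the 30 fps value and frame count are illustrative). It shows how the stride is derived and why a 30 fps input ends up at roughly 15 fps even though `target_fps` is 12:

source_fps = 30.0                                      # e.g. sabrina30fps.mp4
target_fps = 12
stride = max(1, int(round(source_fps / target_fps)))   # round(2.5) -> 2, so stride = 2
index_list = list(range(0, 150, stride))               # 150 source frames -> 75 kept frames
effective_fps = source_fps / stride                    # 15.0
print(stride, len(index_list), effective_fps)

Because the stride is rounded to an integer, the effective frame rate only approximates `target_fps`.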
inference.py
ADDED
@@ -0,0 +1,738 @@
import os
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import warnings

from transformers import (
    ByT5Tokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

from transformers.models.t5 import T5Config
from transformers.models.t5.modeling_t5 import *

from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from torch.nn import CrossEntropyLoss

from collections.abc import Mapping
from dataclasses import dataclass
from random import randint
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from transformers.utils import PaddingStrategy

from shubert import SignHubertModel, SignHubertConfig

class SignHubertAdapter(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # Adjust intermediate_dim based on number of channels
        intermediate_dim_shubert = 1024

        self.signhubert = SignHubertModel(SignHubertConfig(
            channels=channels,
            intermediate_dim=intermediate_dim_shubert
        ))

    def forward(self, x):
        features = self.signhubert.extract_features(x, padding_mask=None, kmeans_labels=None, mask=False)

        # Extract layer outputs
        layer_outputs = []
        for layer in features['layer_results']:
            layer_output = layer[-1]  # Shape: [B, T, D]
            layer_outputs.append(layer_output)

        # Stack the outputs from all layers
        stacked_features = torch.stack(layer_outputs, dim=1)  # Shape: [B, L, T, D]
        return stacked_features

class LinearAdapter(nn.Module):
    def __init__(self, face_dim, hand_dim, pose_dim, representations_dim, out_dim, extraction_layer, channels):
        super().__init__()

        self.signhubert_adapter = SignHubertAdapter(channels)
        self.layer_weights = nn.Parameter(torch.ones(12))  # Learnable weights for each layer
        self.final_layer = nn.Linear(representations_dim, out_dim)
        self.extraction_layer = extraction_layer

    def forward(self, face_features, left_hand_features, right_hand_features, body_posture_features):
        dtype = torch.float32
        face_features = face_features.to(dtype=dtype)
        left_hand_features = left_hand_features.to(dtype=dtype)
        right_hand_features = right_hand_features.to(dtype=dtype)
        body_posture_features = body_posture_features.to(dtype=dtype)

        batch_size, seq_len = face_features.shape[:2]
        dummy_labels = torch.zeros((seq_len, 1), dtype=dtype, device=face_features.device)

        source = []
        for i in range(batch_size):
            source.append({
                "face": face_features[i],
                "left_hand": left_hand_features[i],
                "right_hand": right_hand_features[i],
                "body_posture": body_posture_features[i],
                "label_face": dummy_labels,
                "label_left_hand": dummy_labels,
                "label_right_hand": dummy_labels,
                "label_body_posture": dummy_labels
            })

        # Get representations from SignHubert
        representations_features = self.signhubert_adapter(source)  # [T, L, B, D]
        representations_features = representations_features.permute(2, 1, 0, 3)  # [B, L, T, D]

        if self.extraction_layer == 0:
            normalized_weights = self.layer_weights
            weighted_representations = representations_features * normalized_weights.view(1, -1, 1, 1)
            representations_for_downstream_task = torch.sum(weighted_representations, dim=1)
        else:
            representations_for_downstream_task = representations_features[:, self.extraction_layer-1, :, :]

        final_output = self.final_layer(representations_for_downstream_task)

        return final_output

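As a side note on the `extraction_layer == 0` branch above: the learnable `layer_weights` simply form a weighted sum over the 12 stacked SHuBERT layer outputs. A minimal self-contained sketch of that operation on random tensors (the shapes are illustrative, not taken from the checkpoint):

import torch

B, L, T, D = 2, 12, 50, 768                 # batch, layers, time steps, hidden dim (illustrative)
layer_outputs = torch.randn(B, L, T, D)
layer_weights = torch.nn.Parameter(torch.ones(L))

# Broadcast one scalar weight per layer, then sum over the layer axis.
weighted = layer_outputs * layer_weights.view(1, -1, 1, 1)   # [B, L, T, D]
pooled = weighted.sum(dim=1)                                  # [B, T, D]
print(pooled.shape)                                           # torch.Size([2, 50, 768])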
class SignLanguageByT5Config(T5Config):
    def __init__(
        self,
        representations_dim=768,
        adapter="linear",
        finetune_signhubert=False,
        face_dim=384,
        hand_dim=384,
        pose_dim=14,
        extraction_layer=0,  # 0 = learnable weighted sum over all layers (see LinearAdapter)
        channels="face,left_hand,right_hand,body_posture",
        **kwargs
    ):
        self.representations_dim = representations_dim
        self.adapter = adapter
        self.finetune_signhubert = finetune_signhubert
        self.face_dim = face_dim
        self.hand_dim = hand_dim
        self.pose_dim = pose_dim
        self.extraction_layer = extraction_layer
        self.channels = channels
        super().__init__(**kwargs)

class SignLanguageByT5Encoder(T5PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Initialize the adapter based on the configuration
        if config.adapter == "linear":
            self.adapter = LinearAdapter(
                config.face_dim,
                config.hand_dim,
                config.pose_dim,
                config.representations_dim,
                config.d_model,
                config.extraction_layer,
                config.channels
            )
        else:
            raise NotImplementedError("Adapter type not implemented.")

        self.is_decoder = config.is_decoder

        # Define the encoder blocks
        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel settings
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    def parallelize(self, device_map=None):
        warnings.warn(
            "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
            " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
            " 'block.1': 1, ...}",
            FutureWarning,
        )
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # Load onto devices
        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)

        # Set embed_tokens to first layer
        self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Set final layer norm to last device
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)

    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        for i in range(len(self.block)):
            self.block[i] = self.block[i].to("cpu")
        self.embed_tokens = self.embed_tokens.to("cpu")
        self.final_layer_norm = self.final_layer_norm.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        face_features=None,
        left_hand_features=None,
        right_hand_features=None,
        pose_features=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Set default values if not provided
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Use the adapter to convert representation features into embeddings
        inputs_embeds = self.adapter(face_features, left_hand_features, right_hand_features, pose_features)

        input_shape = inputs_embeds.shape[:2]
        batch_size, seq_length = input_shape

        mask_seq_length = seq_length

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        # Initialize past_key_values if not provided
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # Extend attention mask
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = self.dropout(inputs_embeds)

        # Iterate over each encoder block
        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            layer_head_mask = head_mask[i]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=None,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                encoder_decoder_position_bias=None,
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_head_mask,
                past_key_value=past_key_value,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                present_key_value_states = present_key_value_states + (layer_outputs[1],)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, present_key_value_states, all_hidden_states, all_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=None,
        )

class SignLanguageByT5ForConditionalGeneration(T5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [
        "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
    _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.model_dim = config.d_model

        # Initialize the decoder embedding
        self.decoder_emb = nn.Embedding(config.vocab_size, config.d_model)

        # Initialize the encoder with the custom SignLanguageByT5Encoder
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = SignLanguageByT5Encoder(encoder_config)

        # Initialize the decoder
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.decoder_emb)

        # Initialize the language modeling head
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel settings
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        warnings.warn(
            "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
            " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
            " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
            " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.decoder_emb

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past_key_values, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past_key_values is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past_key_values

        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
                raise ValueError(
                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
                )
            if len(reordered_layer_past_states) != len(layer_past_states):
                raise ValueError(
                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
                )

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past

    def forward(
        self,
        face_features=None,
        left_hand_features=None,
        right_hand_features=None,
        pose_features=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        labels=None,  # Keep this for training compatibility
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Set default values if not provided
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Prepare head masks if needed
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if encoder outputs are not provided
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                face_features=face_features,
                left_hand_features=left_hand_features,
                right_hand_features=right_hand_features,
                pose_features=pose_features,
                attention_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutputWithPastAndCrossAttentions):
            encoder_outputs = BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Prepare decoder inputs
        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = self._shift_right(labels)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Scale sequence output if embeddings are tied
        if self.config.tie_word_embeddings:
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        # Compute language modeling logits
        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def generate(
        self,
        face_features=None,
        left_hand_features=None,
        right_hand_features=None,
        pose_features=None,
        attention_mask=None,
        **kwargs
    ):
        """
        Generate method to handle sign language features and generate output sequences.
        """
        # Compute encoder outputs using sign language features
        encoder_outputs = self.encoder(
            face_features=face_features,
            left_hand_features=left_hand_features,
            right_hand_features=right_hand_features,
            pose_features=pose_features,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Pass encoder outputs to the decoder
        kwargs["encoder_outputs"] = encoder_outputs

        # Generate sequences using the parent class's generate method
        return super().generate(
            attention_mask=attention_mask,
            **kwargs
        )

@dataclass
class SignLanguageT5Collator:
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors

        face_embeds = [feature["face_features"] for feature in features]
        left_hand_embeds = [feature["left_hand_features"] for feature in features]
        right_hand_embeds = [feature["right_hand_features"] for feature in features]
        pose_embeds = [feature["pose_features"] for feature in features]

        # Padding
        max_len = max([emb.shape[0] for emb in face_embeds])

        def pad_embeds(embeds):
            padded_embeds = []
            for emb in embeds:
                if emb.dim() == 3:  # For 3D tensors (pose features)
                    pad_len = max_len - emb.shape[1]  # padding the second dimension (T)
                    emb_pad = torch.nn.functional.pad(emb, (0, 0, 0, pad_len, 0, 0), value=0)
                else:  # For 2D tensors (face, hand features)
                    pad_len = max_len - emb.shape[0]
                    emb_pad = torch.nn.functional.pad(emb, (0, 0, 0, pad_len), value=0)
                padded_embeds.append(emb_pad)
            return padded_embeds

        padded_face_embeds = pad_embeds(face_embeds)
        padded_left_hand_embeds = pad_embeds(left_hand_embeds)
        padded_right_hand_embeds = pad_embeds(right_hand_embeds)
        padded_pose_embeds = pad_embeds(pose_embeds)

        batch = {}
        batch["face_features"] = torch.stack(padded_face_embeds, dim=0)
        batch["left_hand_features"] = torch.stack(padded_left_hand_embeds, dim=0)
        batch["right_hand_features"] = torch.stack(padded_right_hand_embeds, dim=0)
        batch["pose_features"] = torch.stack(padded_pose_embeds, dim=0)

        # For inference, we don't need decoder_input_ids - the model.generate() will handle this
        # Remove the decoder_input_ids requirement
        return batch

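To illustrate what `pad_embeds` does in the 2D branch of the collator, a short sketch on toy tensors (sizes are illustrative): `F.pad` with spec `(0, 0, 0, pad_len)` leaves the feature dimension untouched and appends `pad_len` zero rows along the time axis.

import torch
import torch.nn.functional as F

# Two "clips" of different length with a 4-dim feature per frame (illustrative sizes).
a = torch.randn(3, 4)   # T = 3
b = torch.randn(5, 4)   # T = 5
max_len = max(a.shape[0], b.shape[0])

# Same pad spec as the 2D branch above: (last dim: 0, 0) then (time dim: 0, pad_len).
a_pad = F.pad(a, (0, 0, 0, max_len - a.shape[0]), value=0)
batch = torch.stack([a_pad, b], dim=0)
print(batch.shape)      # torch.Size([2, 5, 4]) - zero rows appended to the shorter clip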
class TranslationFeatures(torch.utils.data.Dataset):
    def __init__(self, face_embeddings, left_hand_embeddings, right_hand_embeddings, body_posture_embeddings):
        self.face_embeddings = face_embeddings
        self.left_hand_embeddings = left_hand_embeddings
        self.right_hand_embeddings = right_hand_embeddings
        self.body_posture_embeddings = body_posture_embeddings

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return {
            "face_features": torch.tensor(self.face_embeddings),
            "left_hand_features": torch.tensor(self.left_hand_embeddings),
            "right_hand_features": torch.tensor(self.right_hand_embeddings),
            "pose_features": torch.tensor(self.body_posture_embeddings),
        }

def generate_text_from_features(
    face_embeddings: np.ndarray,
    left_hand_embeddings: np.ndarray,
    right_hand_embeddings: np.ndarray,
    body_posture_embeddings: np.ndarray,
    model_config: str,
    model_checkpoint: str,
    tokenizer_checkpoint: str,
    output_dir: str,
    generation_max_length: int = 2048,
    generation_num_beams: int = 5,
):
    """
    Direct inference function that generates text from sign language features.
    """
    # Load model and tokenizer
    config = SignLanguageByT5Config.from_pretrained(model_config)
    model = SignLanguageByT5ForConditionalGeneration.from_pretrained(
        model_checkpoint,
        # config=config,
        cache_dir=os.path.join(output_dir, "cache"),
    )
    tokenizer = ByT5Tokenizer.from_pretrained(tokenizer_checkpoint)

    # Move model to appropriate device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Convert inputs to tensors and move to device
    face_tensor = torch.tensor(face_embeddings, dtype=torch.float32).unsqueeze(0).to(device)
    left_hand_tensor = torch.tensor(left_hand_embeddings, dtype=torch.float32).unsqueeze(0).to(device)
    right_hand_tensor = torch.tensor(right_hand_embeddings, dtype=torch.float32).unsqueeze(0).to(device)
    pose_tensor = torch.tensor(body_posture_embeddings, dtype=torch.float32).unsqueeze(0).to(device)

    # Generate text
    with torch.no_grad():
        generated_ids = model.generate(
            face_features=face_tensor,
            left_hand_features=left_hand_tensor,
            right_hand_features=right_hand_tensor,
            pose_features=pose_tensor,
            max_length=generation_max_length,
            num_beams=generation_num_beams,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode generated text
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text

def test(
    face_embeddings: np.ndarray,
    left_hand_embeddings: np.ndarray,
    right_hand_embeddings: np.ndarray,
    body_posture_embeddings: np.ndarray,
    model_config: str,
    model_checkpoint: str,
    tokenizer_checkpoint: str,
    output_dir: str,
):
    """
    Test function for inference - generates text from sign language features.
    This is a simpler wrapper around the direct inference function.
    """
    return generate_text_from_features(
        face_embeddings=face_embeddings,
        left_hand_embeddings=left_hand_embeddings,
        right_hand_embeddings=right_hand_embeddings,
        body_posture_embeddings=body_posture_embeddings,
        model_config=model_config,
        model_checkpoint=model_checkpoint,
        tokenizer_checkpoint=tokenizer_checkpoint,
        output_dir=output_dir,
    )
kpe_mediapipe.py
ADDED
@@ -0,0 +1,408 @@
1 |
+
import mediapipe as mp
|
2 |
+
from mediapipe.tasks import python
|
3 |
+
from mediapipe.tasks.python import vision
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import json
|
7 |
+
from pathlib import Path
|
8 |
+
import decord
|
9 |
+
from typing import Dict, Optional, Tuple, Any
|
10 |
+
|
11 |
+
|
12 |
+
class HolisticDetector:
|
13 |
+
"""
|
14 |
+
A class for detecting face, hand, and pose landmarks in videos using MediaPipe.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, face_model_path: str, hand_model_path: str,
|
18 |
+
min_detection_confidence: float = 0.1,
|
19 |
+
min_hand_detection_confidence: float = 0.05,
|
20 |
+
max_faces: int = 6, max_hands: int = 6):
|
21 |
+
"""
|
22 |
+
Initialize the HolisticDetector with model paths and configuration.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
face_model_path: Path to the face detection model
|
26 |
+
hand_model_path: Path to the hand detection model
|
27 |
+
min_detection_confidence: Minimum confidence for pose detection
|
28 |
+
min_hand_detection_confidence: Minimum confidence for hand detection
|
29 |
+
max_faces: Maximum number of faces to detect
|
30 |
+
max_hands: Maximum number of hands to detect
|
31 |
+
"""
|
32 |
+
self.face_model_path = face_model_path
|
33 |
+
self.hand_model_path = hand_model_path
|
34 |
+
self.min_detection_confidence = min_detection_confidence
|
35 |
+
self.min_hand_detection_confidence = min_hand_detection_confidence
|
36 |
+
self.max_faces = max_faces
|
37 |
+
self.max_hands = max_hands
|
38 |
+
|
39 |
+
self._initialize_detectors()
|
40 |
+
|
41 |
+
def _initialize_detectors(self):
|
42 |
+
"""Initialize the MediaPipe detectors."""
|
43 |
+
# Initialize face detector
|
44 |
+
base_options_face = python.BaseOptions(model_asset_path=self.face_model_path)
|
45 |
+
options_face = vision.FaceLandmarkerOptions(
|
46 |
+
base_options=base_options_face,
|
47 |
+
output_face_blendshapes=True,
|
48 |
+
output_facial_transformation_matrixes=True,
|
49 |
+
num_faces=self.max_faces
|
50 |
+
)
|
51 |
+
self.face_detector = vision.FaceLandmarker.create_from_options(options_face)
|
52 |
+
|
53 |
+
# Initialize hand detector
|
54 |
+
base_options_hand = python.BaseOptions(model_asset_path=self.hand_model_path)
|
55 |
+
options_hand = vision.HandLandmarkerOptions(
|
56 |
+
base_options=base_options_hand,
|
57 |
+
num_hands=self.max_hands,
|
58 |
+
min_hand_detection_confidence=self.min_hand_detection_confidence
|
59 |
+
)
|
60 |
+
self.hand_detector = vision.HandLandmarker.create_from_options(options_hand)
|
61 |
+
|
62 |
+
# Initialize holistic model for pose
|
63 |
+
self.mp_holistic = mp.solutions.holistic.Holistic(
|
64 |
+
min_detection_confidence=self.min_detection_confidence
|
65 |
+
)
|
66 |
+
|
67 |
+
def detect_frame_landmarks(self, image: np.ndarray) -> Tuple[Dict[str, int], Dict[str, Any]]:
|
68 |
+
"""
|
69 |
+
Detect landmarks in a single frame.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
image: Input image as numpy array
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
Tuple of (bounding_boxes_count, landmarks_data)
|
76 |
+
"""
|
77 |
+
results = self.mp_holistic.process(image)
|
78 |
+
|
79 |
+
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
|
80 |
+
face_prediction = self.face_detector.detect(mp_image)
|
81 |
+
hand_prediction = self.hand_detector.detect(mp_image)
|
82 |
+
|
83 |
+
bounding_boxes = {}
|
84 |
+
landmarks_data = {}
|
85 |
+
|
86 |
+
# Process face landmarks
|
87 |
+
if face_prediction.face_landmarks:
|
88 |
+
bounding_boxes['#face'] = len(face_prediction.face_landmarks)
|
89 |
+
landmarks_data['face_landmarks'] = []
|
90 |
+
for face in face_prediction.face_landmarks:
|
91 |
+
landmarks_face = [[landmark.x, landmark.y, landmark.z] for landmark in face]
|
92 |
+
landmarks_data['face_landmarks'].append(landmarks_face)
|
93 |
+
else:
|
94 |
+
bounding_boxes['#face'] = 0
|
95 |
+
landmarks_data['face_landmarks'] = None
|
96 |
+
|
97 |
+
# Process hand landmarks
|
98 |
+
if hand_prediction.hand_landmarks:
|
99 |
+
bounding_boxes['#hands'] = len(hand_prediction.hand_landmarks)
|
100 |
+
landmarks_data['hand_landmarks'] = []
|
101 |
+
for hand in hand_prediction.hand_landmarks:
|
102 |
+
landmarks_hand = [[landmark.x, landmark.y, landmark.z] for landmark in hand]
|
103 |
+
landmarks_data['hand_landmarks'].append(landmarks_hand)
|
104 |
+
else:
|
105 |
+
bounding_boxes['#hands'] = 0
|
106 |
+
landmarks_data['hand_landmarks'] = None
|
107 |
+
|
108 |
+
# Process pose landmarks
|
109 |
+
if results.pose_landmarks:
|
110 |
+
bounding_boxes['#pose'] = 1
|
111 |
+
landmarks_data['pose_landmarks'] = []
|
112 |
+
pose_landmarks = [[landmark.x, landmark.y, landmark.z] for landmark in results.pose_landmarks.landmark]
|
113 |
+
landmarks_data['pose_landmarks'].append(pose_landmarks)
|
114 |
+
else:
|
115 |
+
bounding_boxes['#pose'] = 0
|
116 |
+
landmarks_data['pose_landmarks'] = None
|
117 |
+
|
118 |
+
return bounding_boxes, landmarks_data
|
119 |
+
|
120 |
+
def process_video(self, video_input, save_results: bool = False,
|
121 |
+
output_dir: Optional[str] = None, video_name: Optional[str] = None) -> Dict[int, Any]:
|
122 |
+
"""
|
123 |
+
Process a video and extract landmarks from all frames.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
video_input: Either a path to video file (str) or a decord.VideoReader object
|
127 |
+
save_results: Whether to save results to files
|
128 |
+
output_dir: Directory to save results (required if save_results=True)
|
129 |
+
video_name: Name for output files (required if save_results=True and video_input is VideoReader)
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
Dictionary containing landmarks for each frame
|
133 |
+
|
134 |
+
Raises:
|
135 |
+
FileNotFoundError: If video file doesn't exist
|
136 |
+
ValueError: If save_results=True but output_dir is None, or if video_name is None when needed
|
137 |
+
TypeError: If video_input is neither string nor VideoReader
|
138 |
+
"""
|
139 |
+
if save_results and output_dir is None:
|
140 |
+
raise ValueError("output_dir must be provided when save_results=True")
|
141 |
+
|
142 |
+
# Handle different input types
|
143 |
+
if isinstance(video_input, str):
|
144 |
+
# Input is a file path
|
145 |
+
video_path = Path(video_input)
|
146 |
+
if not video_path.exists():
|
147 |
+
raise FileNotFoundError(f"Video file not found: {video_input}")
|
148 |
+
|
149 |
+
try:
|
150 |
+
video = decord.VideoReader(str(video_path))
|
151 |
+
except Exception as e:
|
152 |
+
raise RuntimeError(f"Error loading video {video_input}: {e}")
|
153 |
+
|
154 |
+
file_name = video_path.stem
|
155 |
+
|
156 |
+
# elif hasattr(video_input, '__len__') and hasattr(video_input, '__getitem__'):
|
157 |
+
else:
|
158 |
+
# Input is a VideoReader object or similar
|
159 |
+
video = video_input
|
160 |
+
if save_results and video_name is None:
|
161 |
+
raise ValueError("video_name must be provided when save_results=True and video_input is a VideoReader object")
|
162 |
+
file_name = video_name or "video"
|
163 |
+
|
164 |
+
# else:
|
165 |
+
# raise TypeError("video_input must be either a file path (str) or a VideoReader object")
|
166 |
+
|
167 |
+
result_dict = {}
|
168 |
+
stats = {}
|
169 |
+
|
170 |
+
# Process each frame
|
171 |
+
for i in range(len(video)):
|
172 |
+
try:
|
173 |
+
# frame_rgb = video[i].asnumpy()
|
174 |
+
frame_rgb = video[i]
|
175 |
+
if hasattr(video, 'seek'):
|
176 |
+
video.seek(0)
|
177 |
+
bounding_boxes, landmarks = self.detect_frame_landmarks(frame_rgb)
|
178 |
+
result_dict[i] = landmarks
|
179 |
+
stats[i] = bounding_boxes
|
180 |
+
except Exception as e:
|
181 |
+
print(f"Error processing frame {i}: {e}")
|
182 |
+
result_dict[i] = None
|
183 |
+
stats[i] = {'#face': 0, '#hands': 0, '#pose': 0}
|
184 |
+
|
185 |
+
# Save results if requested
|
186 |
+
if save_results:
|
187 |
+
self._save_results(file_name, result_dict, stats, output_dir)
|
188 |
+
|
189 |
+
return result_dict
|
190 |
+
|
191 |
+
def process_video_frames(self, frames: list, save_results: bool = False,
|
192 |
+
output_dir: Optional[str] = None, video_name: str = "video") -> Dict[int, Any]:
|
193 |
+
"""
|
194 |
+
Process a list of frames and extract landmarks.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
frames: List of frame images as numpy arrays
|
198 |
+
save_results: Whether to save results to files
|
199 |
+
output_dir: Directory to save results (required if save_results=True)
|
200 |
+
video_name: Name for output files
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
Dictionary containing landmarks for each frame
|
204 |
+
"""
|
205 |
+
if save_results and output_dir is None:
|
206 |
+
raise ValueError("output_dir must be provided when save_results=True")
|
207 |
+
|
208 |
+
result_dict = {}
|
209 |
+
stats = {}
|
210 |
+
|
211 |
+
# Process each frame
|
212 |
+
for i, frame in enumerate(frames):
|
213 |
+
try:
|
214 |
+
bounding_boxes, landmarks = self.detect_frame_landmarks(frame)
|
215 |
+
result_dict[i] = landmarks
|
216 |
+
stats[i] = bounding_boxes
|
217 |
+
except Exception as e:
|
218 |
+
print(f"Error processing frame {i}: {e}")
|
219 |
+
result_dict[i] = None
|
220 |
+
stats[i] = {'#face': 0, '#hands': 0, '#pose': 0}
|
221 |
+
|
222 |
+
# Save results if requested
|
223 |
+
if save_results:
|
224 |
+
self._save_results(video_name, result_dict, stats, output_dir)
|
225 |
+
|
226 |
+
return result_dict
|
227 |
+
|
228 |
+
def _save_results(self, video_name: str, landmarks_data: Dict, stats_data: Dict, output_dir: str):
|
229 |
+
"""Save landmarks and stats to JSON files."""
|
230 |
+
output_path = Path(output_dir)
|
231 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
232 |
+
|
233 |
+
# Save landmarks
|
234 |
+
landmarks_file = output_path / f"{video_name}_pose.json"
|
235 |
+
with open(landmarks_file, 'w') as f:
|
236 |
+
json.dump(landmarks_data, f)
|
237 |
+
|
238 |
+
# Save stats
|
239 |
+
stats_file = output_path / f"{video_name}_stats.json"
|
240 |
+
with open(stats_file, 'w') as f:
|
241 |
+
json.dump(stats_data, f)
|
242 |
+
|
243 |
+
def compute_video_stats(self, landmarks_data: Dict) -> Dict[str, Any]:
|
244 |
+
"""
|
245 |
+
Compute statistics from landmarks data.
|
246 |
+
|
247 |
+
Args:
|
248 |
+
landmarks_data: Dictionary containing landmarks for each frame
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
Dictionary containing frame-by-frame stats and maximums
|
252 |
+
"""
|
253 |
+
stats = {}
|
254 |
+
max_counts = {'#face': 0, '#hands': 0, '#pose': 0}
|
255 |
+
|
256 |
+
for frame, landmarks in landmarks_data.items():
|
257 |
+
if landmarks is None:
|
258 |
+
presence = {'#face': 0, '#hands': 0, '#pose': 0}
|
259 |
+
else:
|
260 |
+
presence = {
|
261 |
+
'#face': len(landmarks.get('face_landmarks', [])) if landmarks.get('face_landmarks') else 0,
|
262 |
+
'#hands': len(landmarks.get('hand_landmarks', [])) if landmarks.get('hand_landmarks') else 0,
|
263 |
+
'#pose': len(landmarks.get('pose_landmarks', [])) if landmarks.get('pose_landmarks') else 0
|
264 |
+
}
|
265 |
+
stats[frame] = presence
|
266 |
+
|
267 |
+
# Update max counts
|
268 |
+
for key in max_counts:
|
269 |
+
max_counts[key] = max(max_counts[key], presence[key])
|
270 |
+
|
271 |
+
stats['max'] = max_counts
|
272 |
+
return stats
|
273 |
+
|
274 |
+
|
275 |
+
# Convenience function for backward compatibility and simple usage
|
276 |
+
def video_holistic(video_input, face_model_path: str, hand_model_path: str,
|
277 |
+
save_results: bool = False, output_dir: Optional[str] = None,
|
278 |
+
video_name: Optional[str] = None) -> Dict[int, Any]:
|
279 |
+
"""
|
280 |
+
    Convenience function to process a video and extract holistic landmarks.

    Args:
        video_input: Either a path to a video file (str) or a decord.VideoReader object
        face_model_path: Path to the face detection model
        hand_model_path: Path to the hand detection model
        save_results: Whether to save results to files
        output_dir: Directory to save results
        video_name: Name for output files (required if save_results=True and video_input is a VideoReader)

    Returns:
        Dictionary containing landmarks for each frame
    """
    detector = HolisticDetector(face_model_path, hand_model_path)
    return detector.process_video(video_input, save_results, output_dir, video_name)


# Utility functions for batch processing
def load_file(filename: str):
    """Load a pickled and gzipped file."""
    import pickle
    import gzip
    with gzip.open(filename, "rb") as f:
        return pickle.load(f)


def is_string_in_file(file_path: str, target_string: str) -> bool:
    """Check if a string exists in a file."""
    try:
        with Path(file_path).open("r") as f:
            for line in f:
                if target_string in line:
                    return True
        return False
    except Exception as e:
        print(f"Error: {e}")
        return False


def main():
    """Main function for command-line usage."""
    import argparse
    import time
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--index', type=int, required=True,
                        help='index of the sub_list to work with')
    parser.add_argument('--batch_size', type=int, required=True,
                        help='batch size')
    parser.add_argument('--pose_path', type=str, required=True,
                        help='path to where the pose data will be saved')
    parser.add_argument('--stats_path', type=str, required=True,
                        help='path to where the stats data will be saved')
    parser.add_argument('--time_limit', type=int, required=True,
                        help='time limit')
    parser.add_argument('--files_list', type=str, required=True,
                        help='files list')
    parser.add_argument('--problem_file_path', type=str, required=True,
                        help='problem file path')
    parser.add_argument('--face_model_path', type=str, required=True,
                        help='face model path')
    parser.add_argument('--hand_model_path', type=str, required=True,
                        help='hand model path')

    args = parser.parse_args()

    start_time = time.time()

    # Initialize detector
    detector = HolisticDetector(args.face_model_path, args.hand_model_path)

    # Load the files list
    fixed_list = load_file(args.files_list)

    # Create folders if they do not exist
    Path(args.pose_path).mkdir(parents=True, exist_ok=True)
    Path(args.stats_path).mkdir(parents=True, exist_ok=True)

    # Create the problem file if it doesn't exist
    if not os.path.exists(args.problem_file_path):
        with open(args.problem_file_path, 'w') as f:
            pass

    # Process videos in batches
    video_batches = [fixed_list[i:i + args.batch_size] for i in range(0, len(fixed_list), args.batch_size)]

    for video_file in video_batches[args.index]:
        current_time = time.time()
        if current_time - start_time > args.time_limit:
            print("Time limit reached. Stopping execution.")
            break

        # Check if output files already exist
        video_name = Path(video_file).stem
        landmark_json_path = Path(args.pose_path) / f"{video_name}_pose.json"
        stats_json_path = Path(args.stats_path) / f"{video_name}_stats.json"

        if landmark_json_path.exists() and stats_json_path.exists():
            print(f"Skipping {video_file} - output files already exist")
            continue
        elif is_string_in_file(args.problem_file_path, video_file):
            print(f"Skipping {video_file} - found in problem file")
            continue
        else:
            try:
                print(f"Processing {video_file}")
                result_dict = detector.process_video(
                    video_file_path=video_file,
                    save_results=True,
                    output_dir=args.pose_path
                )

                # Also save stats separately for compatibility
                stats = detector.compute_video_stats(result_dict)
                with open(stats_json_path, 'w') as f:
                    json.dump(stats, f)

                print(f"Successfully processed {video_file}")

            except Exception as e:
                print(f"Error processing {video_file}: {e}")
                # Add to the problem file
                with open(args.problem_file_path, "a") as p:
                    p.write(video_file + "\n")


if __name__ == "__main__":
    main()
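For reference, a minimal usage sketch (not part of the commit) of the detector API added above. Model and video paths are illustrative placeholders, and the call mirrors the keyword form used in main():

# Usage sketch: run the holistic keypoint extractor on a single clip.
from kpe_mediapipe import HolisticDetector

detector = HolisticDetector("models/face_landmarker.task", "models/hand_landmarker.task")  # placeholder model paths
landmarks = detector.process_video(
    video_file_path="example_clip.mp4",  # placeholder video path
    save_results=False,
)
print(f"Extracted landmarks for {len(landmarks)} frames")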
shubert.py
ADDED
@@ -0,0 +1,479 @@
import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

import numpy as np
import random
import os
import sys

from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
    Wav2Vec2Config,
    TransformerEncoder,
)

# Debug print to show where Wav2Vec2Config is defined
print(f"Wav2Vec2Config is imported from: {Wav2Vec2Config.__module__}")
print(f"Full path: {sys.modules[Wav2Vec2Config.__module__].__file__}")

from fairseq.modules import (
    LayerNorm,
)

logger = logging.getLogger(__name__)


@dataclass
class SignHubertConfig(Wav2Vec2Config):
    # pos_conv_kernel: int = field(default=32)
    conv_pos: int = field(default=32)
    discrete: bool = field(default=False)
    codebook_size: int = field(default=256)
    channels_embed_dim: int = field(default=384)
    channels_pose_embed_dim: int = field(default=14)
    intermediate_dim: int = field(default=1024)  # This will be overridden if needed
    mask_strategy: str = field(default="random")
    channels: str = field(default="face,left_hand,right_hand,body_posture")


@register_model("signhubert_onlyhands", dataclass=SignHubertConfig)
class SignHubertModel(BaseFairseqModel):
    def __init__(self, cfg: SignHubertConfig):
        super().__init__()
        self.cfg = cfg
        # print(cfg)
        self.discrete = cfg.discrete  # since it's hubert this will always be discrete anyways

        self.embed = cfg.encoder_embed_dim  # whether it is small (384), base (768), large, etc.
        self.channel_embed = cfg.channels_embed_dim  # embedding dimension for face, left_hand and right_hand (default: 384)
        self.channel_pose_embed = cfg.channels_pose_embed_dim  # embedding dimension for pose (default: 14)
        self.intermediate_dim = cfg.intermediate_dim  # intermediate dimension before the projection layer to encoder_embed_dim (default: 1024)

        self.channels = cfg.channels.split(",")

        self.post_extract_proj = nn.Linear(cfg.intermediate_dim, cfg.encoder_embed_dim)  # 4 channels concatenated

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_strategy = cfg.mask_strategy
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(1, 1, 1, cfg.intermediate_dim // len(self.channels)).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.channel_embed * len(self.channels))

        if "face" in self.channels:
            self.layer_norm_face = LayerNorm(self.channel_embed)
            self.face_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // len(self.channels))
        if "left_hand" in self.channels:
            self.layer_norm_lhand = LayerNorm(self.channel_embed)
            self.left_hand_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // len(self.channels))
        if "right_hand" in self.channels:
            self.layer_norm_rhand = LayerNorm(self.channel_embed)
            self.right_hand_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // len(self.channels))
        if "body_posture" in self.channels:
            self.layer_norm_body = LayerNorm(self.channel_pose_embed)
            self.body_posture_proj = nn.Linear(self.channel_pose_embed, cfg.intermediate_dim // len(self.channels))

        self.codebook_size = cfg.codebook_size  # number of codebook vectors

        self.heads = []
        for i in range(len(self.channels)):
            self.heads.append(nn.Linear(cfg.encoder_embed_dim, cfg.codebook_size))

        self.heads = torch.nn.ModuleList(self.heads)

        # self.heads = torch.nn.ModuleList([
        #     nn.Linear(cfg.encoder_embed_dim, cfg.codebook_size),
        #     nn.Linear(cfg.encoder_embed_dim, cfg.codebook_size),
        #     nn.Linear(cfg.encoder_embed_dim, cfg.codebook_size),
        # ]
        # )

        # # Define separate linear layers for each channel
        # self.face_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // 4)
        # self.left_hand_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // 4)
        # self.right_hand_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // 4)
        # self.body_posture_proj = nn.Linear(self.channel_pose_embed, cfg.intermediate_dim // 4)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)
        return state

    @classmethod
    def build_model(cls, cfg: SignHubertConfig, task=None):
        """Build a new model instance."""
        return cls(cfg)

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C, D = x.shape

        # Initialize a mask tensor with ones (same shape as x)
        mask = torch.ones_like(x)

        # channel masking
        if self.mask_prob > 0 and self.mask_strategy == "channel":
            if mask_indices is None:
                mask_indices = torch.zeros_like(x[:, :, :, 0], dtype=bool)
                num_channels_to_mask = int(C * self.mask_prob)
                num_channels_to_mask = max(1, num_channels_to_mask)

                for i in range(B):
                    channels_to_mask = np.random.choice(C, num_channels_to_mask, replace=False)
                    mask_indices[i, :, channels_to_mask] = True

            mask[mask_indices.unsqueeze(-1).expand(-1, -1, -1, D)] = 0

        # gloss/time masking
        elif self.mask_prob > 0 and self.mask_strategy == "gloss":
            if mask_indices is None:
                mask_indices_channel = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices_channel = torch.from_numpy(mask_indices_channel).to(x.device)

                # Apply the same mask to all channels
                mask_indices = mask_indices_channel.unsqueeze(2).expand(-1, -1, C)
            mask_indices = mask_indices.unsqueeze(3).expand(-1, -1, -1, D)
            mask[mask_indices] = 0

        # random masking
        elif self.mask_prob > 0 and self.mask_strategy == "random":
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T * C),  # Note: T*C instead of T
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
                mask_indices = mask_indices.view(B, T, C)
            mask_indices = mask_indices.unsqueeze(3).expand(-1, -1, -1, D)
            mask[mask_indices] = 0
        else:
            raise ValueError(f"unknown mask strategy {self.mask_strategy}")

        # Apply the mask to x and return the masked tensor with the same shape as x
        # x = x * mask
        x = x * mask + self.mask_emb * (1 - mask)

        return x, mask
        # mask is a tensor of shape BxTx4x256 where 0 means the value is masked and 1 means the value is not masked

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
        kmeans_labels=None,
    ):
        channels_to_use = []
        for c in self.channels:
            if c in source[0]:
                channels_to_use.append(c)

        for c in channels_to_use:
            if c == "face":
                face_features_list = []
                label_face_features_list = []
            elif c == "left_hand":
                left_hand_features_list = []
                label_left_hand_features_list = []
            elif c == "right_hand":
                right_hand_features_list = []
                label_right_hand_features_list = []
            elif c == "body_posture":
                body_posture_features_list = []
                label_body_posture_features_list = []

        # # source is a list of dictionaries with keys "face", "left_hand", "right_hand", "body_posture"
        # face_features_list = []
        # left_hand_features_list = []
        # right_hand_features_list = []
        # body_posture_features_list = []
        # label_face_features_list = []
        # label_left_hand_features_list = []
        # label_right_hand_features_list = []
        # label_body_posture_features_list = []

        # for sample in source:
        #     face_features_list.append(sample["face"])  # Tx384
        #     left_hand_features_list.append(sample["left_hand"])  # Tx384
        #     right_hand_features_list.append(sample["right_hand"])  # Tx384
        #     body_posture_features_list.append(sample["body_posture"])  # Tx14
        #     label_face_features_list.append(sample["label_face"])  # Tx1
        #     label_left_hand_features_list.append(sample["label_left_hand"])  # Tx1
        #     label_right_hand_features_list.append(sample["label_right_hand"])  # Tx1
        #     label_body_posture_features_list.append(sample["label_body_posture"])  # Tx1

        for sample in source:
            for c in channels_to_use:
                if c == "face":
                    face_features_list.append(sample["face"])  # Tx384
                    label_face_features_list.append(sample["label_face"])  # Tx1
                elif c == "left_hand":
                    left_hand_features_list.append(sample["left_hand"])  # Tx384
                    label_left_hand_features_list.append(sample["label_left_hand"])  # Tx1
                elif c == "right_hand":
                    right_hand_features_list.append(sample["right_hand"])  # Tx384
                    label_right_hand_features_list.append(sample["label_right_hand"])  # Tx1
                elif c == "body_posture":
                    body_posture_features_list.append(sample["body_posture"])  # Tx14
                    label_body_posture_features_list.append(sample["label_body_posture"])  # Tx1

        # face_features = torch.stack(face_features_list)  # BxTx384
        # left_hand_features = torch.stack(left_hand_features_list)  # BxTx384
        # right_hand_features = torch.stack(right_hand_features_list)  # BxTx384
        # body_posture_features = torch.stack(body_posture_features_list)  # BxTx14
        # face_labels = torch.stack(label_face_features_list)  # BxTx1
        # left_hand_labels = torch.stack(label_left_hand_features_list)  # BxTx1
        # right_hand_labels = torch.stack(label_right_hand_features_list)  # BxTx1
        # body_posture_labels = torch.stack(label_body_posture_features_list)  # BxTx1

        # # Apply layer normalization to each part
        # face_features = self.layer_norm_face(face_features)  # BxTx384
        # left_hand_features = self.layer_norm_lhand(left_hand_features)  # BxTx384
        # right_hand_features = self.layer_norm_rhand(right_hand_features)  # BxTx384
        # body_posture_features = self.layer_norm_body(body_posture_features)  # BxTx14

        # # Apply separate linear projections for each channel
        # face_features = self.face_proj(face_features)  # BxTx256
        # left_hand_features = self.left_hand_proj(left_hand_features)  # BxTx256
        # right_hand_features = self.right_hand_proj(right_hand_features)  # BxTx256
        # body_posture_features = self.body_posture_proj(body_posture_features)  # BxTx256

        features_list = []
        labels_list = []

        for c in channels_to_use:
            if c == "face":
                face_features = torch.stack(face_features_list)  # BxTx384
                face_labels = torch.stack(label_face_features_list)  # BxTx1
                face_features = self.layer_norm_face(face_features)  # BxTx384
                face_features = self.face_proj(face_features)  # BxTx256
                features_list.append(face_features)
                labels_list.append(face_labels)
            elif c == "left_hand":
                left_hand_features = torch.stack(left_hand_features_list)  # BxTx384
                left_hand_labels = torch.stack(label_left_hand_features_list)  # BxTx1
                left_hand_features = self.layer_norm_lhand(left_hand_features)  # BxTx384
                left_hand_features = self.left_hand_proj(left_hand_features)  # BxTx256
                features_list.append(left_hand_features)
                labels_list.append(left_hand_labels)
            elif c == "right_hand":
                right_hand_features = torch.stack(right_hand_features_list)  # BxTx384
                right_hand_labels = torch.stack(label_right_hand_features_list)  # BxTx1
                right_hand_features = self.layer_norm_rhand(right_hand_features)  # BxTx384
                right_hand_features = self.right_hand_proj(right_hand_features)  # BxTx256
                features_list.append(right_hand_features)
                labels_list.append(right_hand_labels)
            elif c == "body_posture":
                body_posture_features = torch.stack(body_posture_features_list)  # BxTx14
                body_posture_labels = torch.stack(label_body_posture_features_list)  # BxTx1
                body_posture_features = self.layer_norm_body(body_posture_features)  # BxTx14
                body_posture_features = self.body_posture_proj(body_posture_features)  # BxTx256
                features_list.append(body_posture_features)
                labels_list.append(body_posture_labels)

        # concatenate the projected features to have dimension BxTxCxD where C=4 and D=256
        # features = torch.stack(
        #     [
        #         face_features,
        #         left_hand_features,
        #         right_hand_features,
        #         body_posture_features
        #     ],
        #     dim=2)  # BxTx4x256

        features = torch.stack(features_list, dim=2)  # BxTx4x256

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
            # mask_indices is a tensor of shape BxTx4x256 where 0 means the value is masked and 1 means the value is not masked
        else:
            x = features
            mask_indices = None

        x = self.dropout_input(x)  # BxTx4x256

        x = x.view(x.size(0), x.size(1), -1)  # BxTx1024
        if self.post_extract_proj is not None:
            x = self.post_extract_proj(x)  # BxTx768

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "layer_results": layer_results,
            }

        result = {
            "losses": {},
        }

        # use linear heads to compute the discrete prediction for each channel and make it into a single tensor of shape BxTxCxcodebook_size
        predictions = []
        for i, head in enumerate(self.heads):
            channel_pred = head(x)  # BxTxcodebook_size
            predictions.append(channel_pred)
        predictions = torch.stack(predictions, dim=2)  # BxTx4xcodebook_size

        # labels = torch.stack(
        #     [
        #         face_labels,
        #         left_hand_labels,
        #         right_hand_labels,
        #         body_posture_labels
        #     ],
        #     dim=2)  # BxTx4x1

        labels = torch.stack(labels_list, dim=2)  # BxTx4x1
        # print(f"predictions shape: {predictions.shape} and labels shape: {labels.shape}")

        predictions_flat = predictions.view(-1, self.codebook_size)  # Shape: (B * T * C, codebook_size)
        labels_flat = labels.view(-1)  # Shape: (B * T * C)

        # Ensure labels are of correct shape
        labels_flat = labels_flat.squeeze(-1)  # Remove the last dimension if it's size 1

        # Correct the mask_indices to match the shape of predictions_flat
        mask_indices_reduced = mask_indices.any(dim=-1)  # Reduce mask to (B, T, C) by collapsing the last dimension
        mask_indices_flat = mask_indices_reduced.view(-1)  # Flatten to match the shape of (B * T * C)

        # Calculate the loss only for the masked positions (where mask_indices_flat is zero)
        masked_loss = F.cross_entropy(
            predictions_flat[mask_indices_flat == 0],
            labels_flat[mask_indices_flat == 0],
            reduction='none'
        )

        # Store the result
        result['losses']['kmeans_loss'] = masked_loss

        if "sample_size" not in result:
            result['sample_size'] = masked_loss.numel()

        return result

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(
        self, source, padding_mask, kmeans_labels, mask=False, layer=None
    ):
        res = self.forward(
            source,
            padding_mask,
            mask=mask,
            features_only=True,
            layer=layer,
            kmeans_labels=kmeans_labels,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
        self.heads = None
        self.final_proj = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )
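For reference, a minimal sketch (not part of the commit) of how the model above can be exercised end to end. It assumes fairseq is installed (the file imports it at module level), uses the default SignHubertConfig sizes, and feeds random tensors shaped like a single 16-frame clip:

# Usage sketch: mask-free feature extraction on dummy inputs.
import torch
from shubert import SignHubertConfig, SignHubertModel

cfg = SignHubertConfig()
model = SignHubertModel(cfg).eval()

T = 16  # frames
sample = {
    "face": torch.randn(T, cfg.channels_embed_dim),
    "left_hand": torch.randn(T, cfg.channels_embed_dim),
    "right_hand": torch.randn(T, cfg.channels_embed_dim),
    "body_posture": torch.randn(T, cfg.channels_pose_embed_dim),
    # Dummy labels; only needed to match the expected input format.
    "label_face": torch.zeros(T, 1),
    "label_left_hand": torch.zeros(T, 1),
    "label_right_hand": torch.zeros(T, 1),
    "label_body_posture": torch.zeros(T, 1),
}

with torch.no_grad():
    out = model.extract_features([sample], padding_mask=None, kmeans_labels=None, mask=False)

print(out["x"].shape)             # final encoder output for the one-clip batch
print(len(out["layer_results"]))  # one entry per transformer layer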
shubert_inference.py
ADDED
@@ -0,0 +1,439 @@
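The file below wraps SHuBERT feature extraction (SHubertProcessor) and ByT5-based text generation (SHuBERTTextGenerator). A minimal usage sketch (not part of the commit) follows; checkpoint and .npy paths are placeholders:

# Usage sketch for the classes defined in this file.
import numpy as np
from shubert_inference import SHubertProcessor, SHuBERTTextGenerator

processor = SHubertProcessor("checkpoints/shubert.pt")  # placeholder checkpoint path
features = processor.process_embeddings_from_files(
    "feats/clip_face.npy", "feats/clip_left_hand.npy",
    "feats/clip_right_hand.npy", "feats/clip_pose.npy",
)
print(features.shape)  # (num_layers, num_frames, feature_dim)

generator = SHuBERTTextGenerator("checkpoints/shubert.pt")
text = generator.generate_text(
    np.load("feats/clip_face.npy"), np.load("feats/clip_left_hand.npy"),
    np.load("feats/clip_right_hand.npy"), np.load("feats/clip_pose.npy"),
)
print(text)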
import torch
import numpy as np
import csv
import os
from tqdm import tqdm
import argparse
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union, Any
from examples.shubert.models.shubert import SHubertModel, SHubertConfig
from transformers import ByT5Tokenizer, ByT5ForConditionalGeneration


class SHubertProcessor:
    """
    A class for processing multi-modal embeddings through the SHubert model.
    """

    def __init__(self, checkpoint_path: str, device: Optional[str] = None):
        """
        Initialize the SHubertProcessor.

        Args:
            checkpoint_path: Path to the SHubert model checkpoint
            device: Device to use ('cuda' or 'cpu'). Auto-detected if None
        """
        self.checkpoint_path = checkpoint_path
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the model
        self.model = self._load_model()

        print(f"SHubertProcessor initialized on device: {self.device}")

    def _load_model(self) -> SHubertModel:
        """Load the SHubert model from checkpoint."""
        # Initialize configuration
        cfg = SHubertConfig()

        # Initialize the model
        model = SHubertModel(cfg)

        # Load the checkpoint
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)

        # Extract state dict
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint

        # Load the state dictionary into the model
        model.load_state_dict(state_dict, strict=False)

        model.eval()
        model.to(self.device)
        return model

    def process_embeddings(self, face_embeddings: np.ndarray,
                           left_hand_embeddings: np.ndarray,
                           right_hand_embeddings: np.ndarray,
                           pose_embeddings: np.ndarray) -> np.ndarray:
        """
        Process multi-modal embeddings through the SHubert model.

        Args:
            face_embeddings: Face embeddings array of shape (num_frames, embedding_dim)
            left_hand_embeddings: Left hand embeddings array of shape (num_frames, embedding_dim)
            right_hand_embeddings: Right hand embeddings array of shape (num_frames, embedding_dim)
            pose_embeddings: Pose embeddings array of shape (num_frames, pose_dim)

        Returns:
            Numpy array of SHubert features with shape (num_layers, num_frames, feature_dim)
        """
        # Convert to tensors and move to device
        face = torch.from_numpy(face_embeddings).float().to(self.device)
        left_hand = torch.from_numpy(left_hand_embeddings).float().to(self.device)
        right_hand = torch.from_numpy(right_hand_embeddings).float().to(self.device)
        body_posture = torch.from_numpy(pose_embeddings).float().to(self.device)

        length = face.shape[0]

        # Prepare input in the format expected by SHubert
        source = [{
            "face": face,
            "left_hand": left_hand,
            "right_hand": right_hand,
            "body_posture": body_posture,
            # Add dummy labels to match the expected input format
            "label_face": torch.zeros((length, 1)).to(self.device),
            "label_left_hand": torch.zeros((length, 1)).to(self.device),
            "label_right_hand": torch.zeros((length, 1)).to(self.device),
            "label_body_posture": torch.zeros((length, 1)).to(self.device)
        }]

        # Extract features
        with torch.no_grad():
            result = self.model.extract_features(source, padding_mask=None, kmeans_labels=None, mask=False)

        # Extract layer outputs
        layer_outputs = []
        for layer in result['layer_results']:
            # layer_output has shape [T, B, D]
            # Since batch size B is 1, we can squeeze it
            layer_output = layer[-1]
            layer_output = layer_output.squeeze(1)  # Shape: [T, D]
            layer_outputs.append(layer_output.cpu().numpy())  # Convert to NumPy array

        # Stack the outputs from all layers to get an array of shape [L, T, D]
        features = np.stack(layer_outputs, axis=0)  # Shape: [L, T, D]
        return features

    def process_embeddings_from_files(self, face_path: str, left_hand_path: str,
                                      right_hand_path: str, pose_path: str) -> np.ndarray:
        """
        Process embeddings loaded from files.

        Args:
            face_path: Path to face embeddings .npy file
            left_hand_path: Path to left hand embeddings .npy file
            right_hand_path: Path to right hand embeddings .npy file
            pose_path: Path to pose embeddings .npy file

        Returns:
            Numpy array of SHubert features with shape (num_layers, num_frames, feature_dim)
        """
        # Load numpy arrays
        face_embeddings = np.load(face_path)
        left_hand_embeddings = np.load(left_hand_path)
        right_hand_embeddings = np.load(right_hand_path)
        pose_embeddings = np.load(pose_path)

        return self.process_embeddings(face_embeddings, left_hand_embeddings,
                                       right_hand_embeddings, pose_embeddings)

    def process_and_save_embeddings(self, face_embeddings: np.ndarray,
                                    left_hand_embeddings: np.ndarray,
                                    right_hand_embeddings: np.ndarray,
                                    pose_embeddings: np.ndarray,
                                    output_path: str) -> str:
        """
        Process embeddings and save to file.

        Args:
            face_embeddings: Face embeddings array
            left_hand_embeddings: Left hand embeddings array
            right_hand_embeddings: Right hand embeddings array
            pose_embeddings: Pose embeddings array
            output_path: Path to save the output file

        Returns:
            Path to the saved file
        """
        # Process embeddings
        features = self.process_embeddings(face_embeddings, left_hand_embeddings,
                                           right_hand_embeddings, pose_embeddings)

        # Create output directory if it doesn't exist
        output_dir = Path(output_path).parent
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save features
        np.save(output_path, features)

        return str(output_path)

    def process_from_files_and_save(self, face_path: str, left_hand_path: str,
                                    right_hand_path: str, pose_path: str,
                                    output_path: str) -> str:
        """
        Process embeddings from files and save results.

        Args:
            face_path: Path to face embeddings .npy file
            left_hand_path: Path to left hand embeddings .npy file
            right_hand_path: Path to right hand embeddings .npy file
            pose_path: Path to pose embeddings .npy file
            output_path: Path to save the output file

        Returns:
            Path to the saved file
        """
        # Process embeddings
        features = self.process_embeddings_from_files(face_path, left_hand_path,
                                                      right_hand_path, pose_path)

        # Create output directory if it doesn't exist
        output_dir = Path(output_path).parent
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save features
        np.save(output_path, features)

        return str(output_path)


class SHuBERTTextGenerator:
    """
    A class that combines SHuBERT feature extraction with ByT5 text generation.
    """

    def __init__(self, shubert_checkpoint: str, byt5_model_name: str = "google/byt5-base",
                 device: Optional[str] = None):
        """
        Initialize with SHuBERT and ByT5 models.

        Args:
            shubert_checkpoint: Path to SHuBERT model checkpoint
            byt5_model_name: Name of ByT5 model (default: "google/byt5-base")
            device: Device to use ('cuda' or 'cpu')
        """
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize SHuBERT processor
        self.shubert_processor = SHubertProcessor(shubert_checkpoint, self.device)

        # Initialize ByT5 model
        self.tokenizer = ByT5Tokenizer.from_pretrained(byt5_model_name)
        self.model = ByT5ForConditionalGeneration.from_pretrained(byt5_model_name).to(self.device)

    def generate_text(self, face_embeddings: np.ndarray,
                      left_hand_embeddings: np.ndarray,
                      right_hand_embeddings: np.ndarray,
                      pose_embeddings: np.ndarray,
                      max_length: int = 1024,
                      num_beams: int = 5) -> str:
        """
        Generate text from multi-modal embeddings.

        Args:
            face_embeddings: Face embeddings array
            left_hand_embeddings: Left hand embeddings array
            right_hand_embeddings: Right hand embeddings array
            pose_embeddings: Pose embeddings array
            max_length: Maximum length of generated text
            num_beams: Number of beams for beam search

        Returns:
            Generated text string
        """
        # Get SHuBERT features
        features = self.shubert_processor.process_embeddings(
            face_embeddings, left_hand_embeddings, right_hand_embeddings, pose_embeddings)

        # Select features from a specific layer (default: last layer)
        features = features[-1]  # Shape: [T, D]

        # Convert to tensor and add batch dimension
        features = torch.from_numpy(features).float().unsqueeze(0).to(self.device)

        # Generate text
        generated_ids = self.model.generate(
            inputs_embeds=features,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True
        )

        # Decode generated tokens to text
        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)


def generate_text_from_features(face_embeddings: np.ndarray,
                                left_hand_embeddings: np.ndarray,
                                right_hand_embeddings: np.ndarray,
                                pose_embeddings: np.ndarray,
                                shubert_checkpoint: str,
                                byt5_model_name: str = "google/byt5-base",
                                max_length: int = 1024,
                                num_beams: int = 5) -> str:
    """
    Convenience function to generate text from features.
    """
    generator = SHuBERTTextGenerator(shubert_checkpoint, byt5_model_name)
    return generator.generate_text(
        face_embeddings, left_hand_embeddings, right_hand_embeddings, pose_embeddings,
        max_length=max_length, num_beams=num_beams
    )


# Convenience functions for backward compatibility
def process_shubert_embeddings(face_embeddings: np.ndarray,
                               left_hand_embeddings: np.ndarray,
                               right_hand_embeddings: np.ndarray,
                               pose_embeddings: np.ndarray,
                               checkpoint_path: str) -> np.ndarray:
    """
    Convenience function to process embeddings through SHubert.

    Args:
        face_embeddings: Face embeddings array
        left_hand_embeddings: Left hand embeddings array
        right_hand_embeddings: Right hand embeddings array
        pose_embeddings: Pose embeddings array
        checkpoint_path: Path to the SHubert model checkpoint

    Returns:
        Numpy array of SHubert features
    """
    processor = SHubertProcessor(checkpoint_path)
    return processor.process_embeddings(face_embeddings, left_hand_embeddings,
                                        right_hand_embeddings, pose_embeddings)


def process_sample(model: SHubertModel, face_path: str, left_hand_path: str,
                   right_hand_path: str, body_posture_path: str) -> np.ndarray:
    """
    Original function for backward compatibility with command-line usage.
    """
    # Load numpy arrays
    face_np = np.load(face_path)
    left_hand_np = np.load(left_hand_path)
    right_hand_np = np.load(right_hand_path)
    body_posture_np = np.load(body_posture_path)

    face = torch.from_numpy(face_np).float().cuda()
    left_hand = torch.from_numpy(left_hand_np).float().cuda()
    right_hand = torch.from_numpy(right_hand_np).float().cuda()
    body_posture = torch.from_numpy(body_posture_np).float().cuda()

    length = face.shape[0]

    # Prepare input
    source = [{
        "face": face,
        "left_hand": left_hand,
        "right_hand": right_hand,
        "body_posture": body_posture,
        # Add dummy labels to match the expected input format
        "label_face": torch.zeros((length, 1)).cuda(),
        "label_left_hand": torch.zeros((length, 1)).cuda(),
        "label_right_hand": torch.zeros((length, 1)).cuda(),
        "label_body_posture": torch.zeros((length, 1)).cuda()
    }]

    # Extract features
    with torch.no_grad():
        result = model.extract_features(source, padding_mask=None, kmeans_labels=None, mask=False)

    # Extract layer outputs
    layer_outputs = []
    for layer in result['layer_results']:
        # layer_output has shape [T, B, D]
        # Since batch size B is 1, we can squeeze it
        layer_output = layer[-1]
        layer_output = layer_output.squeeze(1)  # Shape: [T, D]
        layer_outputs.append(layer_output.cpu().numpy())  # Convert to NumPy array

    # Stack the outputs from all layers to get an array of shape [L, T, D]
    features = np.stack(layer_outputs, axis=0)  # Shape: [L, T, D]
    return features


def load_model(checkpoint_path: str) -> SHubertModel:
    """
    Original function for backward compatibility with command-line usage.
    """
    cfg = SHubertConfig()

    # Initialize the model
    model = SHubertModel(cfg)

    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path)

    # If the checkpoint is saved with a 'model' key
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    # Load the state dictionary into the model
    model.load_state_dict(state_dict, strict=False)

    model.eval()
    model.cuda()  # Move to GPU if available
    return model


def main(csv_list: List[List[str]], checkpoint_path: str, output_dir: str, index: int):
    """
    Original main function for backward compatibility with command-line usage.
    """
    model = load_model(checkpoint_path)

    os.makedirs(output_dir, exist_ok=True)

    for row in csv_list:
        cues_list = row[0].split('\t')
        face_path, left_hand_path, right_hand_path, body_posture_path = cues_list[0], cues_list[1], cues_list[2], cues_list[3]

        output_filename = f"{os.path.basename(face_path).rsplit('.', 1)[0].rsplit('_', 1)[0]}.npy"
        output_path = os.path.join(output_dir, output_filename)

        # Check if the output file already exists
        if os.path.exists(output_path):
            print(f"Skipping {output_path} as it already exists")
            continue

        # Process the sample
        features = process_sample(model, face_path, left_hand_path, right_hand_path, body_posture_path)

        np.save(output_path, features)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--index', type=int, required=True,
                        help='index of the sub_list to work with')
    parser.add_argument('--csv_path', type=str, required=True,
                        help='path to the CSV file')
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='path to the checkpoint file')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='directory to save output files')
    parser.add_argument('--batch_size', type=int, required=True,
                        help='batch size for processing')

    args = parser.parse_args()
    index = args.index
    csv_path = args.csv_path
    checkpoint_path = args.checkpoint_path
    output_dir = args.output_dir
    batch_size = int(args.batch_size)

    # Make the output dir
    os.makedirs(output_dir, exist_ok=True)

    # Load CSV data
    fixed_list = []
    with open(csv_path, 'r') as csvfile:
        reader = csv.reader(csvfile)
        for row in reader:
            fixed_list.append(row)

    # Process in batches
    video_batches = [fixed_list[i:i + batch_size] for i in range(0, len(fixed_list), batch_size)]

    csv_list = video_batches[index]
    main(csv_list, checkpoint_path, output_dir, index)