ShesterG committed on
Commit ceeabec · 1 Parent(s): 6854601

Add application file

__pycache__/body_features.cpython-38.pyc ADDED
Binary file (10.2 kB)
 
__pycache__/crop_face.cpython-38.pyc ADDED
Binary file (12.1 kB)
 
__pycache__/crop_hands.cpython-38.pyc ADDED
Binary file (12.1 kB)
 
__pycache__/dinov2_features.cpython-38.pyc ADDED
Binary file (10.6 kB)
 
__pycache__/inference.cpython-38.pyc ADDED
Binary file (19.7 kB)
 
__pycache__/kpe_mediapipe.cpython-38.pyc ADDED
Binary file (12 kB)
 
__pycache__/shubert.cpython-38.pyc ADDED
Binary file (8.64 kB)
 
app.py CHANGED
@@ -1,14 +1,542 @@
1
  import gradio as gr
 
2
  import spaces
3
- import torch
4
 
5
- zero = torch.Tensor([0]).cuda()
6
- print(zero.device) # <-- 'cpu' 🤔
 
7
 
8
  @spaces.GPU
9
- def greet(n):
10
- print(zero.device) # <-- 'cuda:0' 🤗
11
- return f"Hello {zero + n} Tensor"
12
 
13
- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
14
- demo.launch()
 
1
  import gradio as gr
2
+ import os
3
+ import tempfile
4
+ import huggingface_hub
5
+ import shutil
6
+ import logging
7
+ import traceback
8
+ from features import SHuBERTProcessor
9
  import spaces
 
10
 
11
+ # Set up logging
12
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
13
+ logger = logging.getLogger(__name__)
14
 
15
+ # Set writable cache directories
16
+ def setup_cache_directories():
17
+ """Set up cache directories with proper error handling"""
18
+ try:
19
+ cache_dirs = {
20
+ 'MPLCONFIGDIR': '/tmp/matplotlib',
21
+ 'TRANSFORMERS_CACHE': '/tmp/huggingface',
22
+ 'HF_HOME': '/tmp/huggingface',
23
+ 'FONTCONFIG_PATH': '/tmp/fontconfig',
24
+ 'TORCH_HOME': '/tmp/torch', # PyTorch cache directory
25
+ }
26
+
27
+ for env_var, path in cache_dirs.items():
28
+ os.environ[env_var] = path
29
+ os.makedirs(path, exist_ok=True, mode=0o777)
30
+ logger.info(f"Cache directory created: {env_var} = {path}")
31
+
32
+ # Also set XDG_CACHE_HOME to override default .cache location
33
+ os.environ['XDG_CACHE_HOME'] = '/tmp/cache'
34
+ os.makedirs('/tmp/cache', exist_ok=True, mode=0o777)
35
+ logger.info(f"Cache directory created: XDG_CACHE_HOME = /tmp/cache")
36
+
37
+ # Clear any existing PyTorch Hub cache to avoid corruption issues
38
+ torch_hub_dir = '/tmp/torch/hub'
39
+ if os.path.exists(torch_hub_dir):
40
+ shutil.rmtree(torch_hub_dir)
41
+ logger.info("Cleared existing PyTorch Hub cache")
42
+ os.makedirs(torch_hub_dir, exist_ok=True, mode=0o777)
43
+ logger.info(f"Created clean PyTorch Hub cache directory: {torch_hub_dir}")
44
+
45
+ # Copy updated DINOv2 files to torch cache after clearing
46
+ # This ensures they're available when PyTorch Hub downloads the repo
47
+ try:
48
+ src_dir = os.path.dirname(os.path.abspath(__file__))
49
+ target_dir = '/tmp/torch/hub/facebookresearch_dinov2_main/dinov2/layers'
50
+
51
+ for filename in ['attention.py', 'block.py']:
52
+ src_path = os.path.join(src_dir, filename)
53
+ if os.path.exists(src_path):
54
+ # We'll copy these after the initial hub download
55
+ logger.info(f"Found {filename} in project directory - will copy after hub download")
56
+ else:
57
+ logger.warning(f"Could not find {filename} in project directory")
58
+ except Exception as e:
59
+ logger.warning(f"Error preparing DINOv2 files: {e}")
60
+
61
+ return True
62
+ except Exception as e:
63
+ logger.error(f"Error creating cache directories: {str(e)}")
64
+ return False
65
+
66
+ # Configuration for Hugging Face Spaces
67
+ MODEL_REPO = "ShesterG/SHuBERT"
68
+ TOKEN = os.environ.get('HF_TOKEN')
69
+
70
+ def validate_environment():
71
+ """Validate required environment variables and setup"""
72
+ if not TOKEN:
73
+ raise ValueError("HF_TOKEN environment variable not set. This is required to access private model repository.")
74
+
75
+ # Check available disk space
76
+ free_space = shutil.disk_usage('/').free / (1024*1024*1024) # GB
77
+ logger.info(f"Available disk space: {free_space:.2f} GB")
78
+
79
+ if free_space < 2: # Less than 2GB
80
+ logger.warning("Low disk space available. This may cause issues.")
81
+
82
+ return True
83
+
84
+ def download_models():
85
+ """Download all required models from Hugging Face Hub with enhanced error handling"""
86
+ logger.info("Starting model download process...")
87
+
88
+ try:
89
+ # Validate environment first
90
+ validate_environment()
91
+
92
+
93
+
94
+ logger.info("Downloading entire models folder...")
95
+
96
+ # Download the entire models folder
97
+ models_path = huggingface_hub.snapshot_download(
98
+ repo_id=MODEL_REPO,
99
+ allow_patterns="models/*", # Download everything in models folder
100
+ token=TOKEN,
101
+ cache_dir=os.environ['TRANSFORMERS_CACHE']
102
+ )
103
+
104
+ # Build config dict with expected file paths
105
+ config = {
106
+ 'yolov8_model_path': os.path.join(models_path, "models/yolov8n.pt"),
107
+ 'dino_face_model_path': os.path.join(models_path, "models/dinov2face.pth"),
108
+ 'dino_hands_model_path': os.path.join(models_path, "models/dinov2hand.pth"),
109
+ 'mediapipe_face_model_path': os.path.join(models_path, "models/face_landmarker_v2_with_blendshapes.task"),
110
+ 'mediapipe_hands_model_path': os.path.join(models_path, "models/hand_landmarker.task"),
111
+ 'shubert_model_path': os.path.join(models_path, "models/checkpoint_836_400000.pt"),
112
+ 'slt_model_config': os.path.join(models_path, "models/byt5_base/config.json"),
113
+ 'slt_model_checkpoint': os.path.join(models_path, "models/checkpoint-11625"),
114
+ 'slt_tokenizer_checkpoint': os.path.join(models_path, "models/byt5_base"),
115
+ 'temp_dir': 'temp'
116
+ }
117
+
118
+ # Verify all required files and folders exist
119
+ logger.info("Verifying downloaded files...")
120
+ missing_files = []
121
+
122
+ for key, path in config.items():
123
+ if key == 'temp_dir': # Skip temp_dir check
124
+ continue
125
+
126
+ if not os.path.exists(path):
127
+ missing_files.append(f"{key}: {path}")
128
+ logger.error(f"Missing: {path}")
129
+ else:
130
+ logger.info(f"✓ Found: {path}")
131
+
132
+ if missing_files:
133
+ logger.error(f"Missing {len(missing_files)} required files/folders:")
134
+ for missing in missing_files:
135
+ logger.error(f" - {missing}")
136
+ raise FileNotFoundError(f"Required files not found: {missing_files}")
137
+
138
+ logger.info("All models downloaded and verified successfully!")
139
+ logger.info(f"Models root path: {models_path}")
140
+
141
+ return config
142
+
143
+ except Exception as e:
144
+ logger.error(f"Error downloading models: {str(e)}")
145
+ logger.error(f"Traceback: {traceback.format_exc()}")
146
+
147
+ # Additional debugging info
148
+ try:
149
+ cache_contents = os.listdir(os.environ['TRANSFORMERS_CACHE'])
150
+ logger.info(f"Cache directory contents: {cache_contents}")
151
+ except:
152
+ logger.error("Cannot access cache directory")
153
+
154
+ return None
155
+
156
+ def initialize_processor(config):
157
+ """Initialize SHuBERT processor with error handling"""
158
+ try:
159
+ logger.info("Initializing SHuBERT processor...")
160
+ processor = SHuBERTProcessor(config)
161
+ logger.info("SHuBERT processor initialized successfully!")
162
+ return processor
163
+ except Exception as e:
164
+ logger.error(f"Error initializing SHuBERT processor: {str(e)}")
165
+ logger.error(f"Traceback: {traceback.format_exc()}")
166
+ return None
167
+
168
+ # Initialize the application
169
+ def initialize_app():
170
+ """Initialize the entire application with comprehensive error handling"""
171
+ try:
172
+ # Setup cache directories
173
+ if not setup_cache_directories():
174
+ raise RuntimeError("Failed to setup cache directories")
175
+
176
+ # Download models
177
+ config = download_models()
178
+ if config is None:
179
+ raise RuntimeError("Failed to download models")
180
+
181
+ # Initialize processor
182
+ processor = initialize_processor(config)
183
+ if processor is None:
184
+ raise RuntimeError("Failed to initialize SHuBERT processor")
185
+
186
+ logger.info("Application initialized successfully!")
187
+ return config, processor
188
+
189
+ except Exception as e:
190
+ error_msg = f"Application initialization failed: {str(e)}"
191
+ logger.error(error_msg)
192
+ logger.error(f"Full traceback: {traceback.format_exc()}")
193
+ raise RuntimeError(error_msg)
194
+
195
+ # Global variables for application state
196
+ config = None
197
+ processor = None
198
+ initialization_error = None
199
+
200
+ try:
201
+ config, processor = initialize_app()
202
+ except Exception as e:
203
+ initialization_error = str(e)
204
+ logger.error(f"Startup failed: {initialization_error}")
205
+
206
+ def copy_dinov2_files_if_needed():
207
+ """Copy updated DINOv2 files after PyTorch Hub download if needed"""
208
+ try:
209
+ src_dir = os.path.dirname(os.path.abspath(__file__))
210
+ target_dir = '/tmp/torch/hub/facebookresearch_dinov2_main/dinov2/layers'
211
+
212
+ # Check if PyTorch Hub has downloaded the repository
213
+ hub_main_dir = '/tmp/torch/hub/facebookresearch_dinov2_main'
214
+
215
+ if os.path.exists(hub_main_dir):
216
+ # Ensure the target directory exists
217
+ os.makedirs(target_dir, exist_ok=True)
218
+
219
+ files_copied = 0
220
+ for filename in ['attention.py', 'block.py']:
221
+ src_path = os.path.join(src_dir, filename)
222
+ target_path = os.path.join(target_dir, filename)
223
+
224
+ if os.path.exists(src_path):
225
+ # Always overwrite with our robust versions
226
+ shutil.copy2(src_path, target_path)
227
+ # Make sure it's readable
228
+ os.chmod(target_path, 0o644)
229
+ logger.info(f"Replaced {filename} with robust version (numpy/Python 3.8 compatible)")
230
+ files_copied += 1
231
+ else:
232
+ logger.error(f"Source file not found: {src_path}")
233
+
234
+ if files_copied > 0:
235
+ # Clear Python's import cache to ensure new files are used
236
+ import importlib
237
+ import sys
238
+
239
+ # Remove any cached imports of dinov2 modules
240
+ modules_to_remove = [key for key in sys.modules.keys() if 'dinov2' in key]
241
+ for module in modules_to_remove:
242
+ del sys.modules[module]
243
+ logger.info(f"Cleared cached import: {module}")
244
+
245
+ logger.info(f"Successfully replaced {files_copied} DINOv2 files with robust versions")
246
+ return True
247
+ else:
248
+ logger.info("PyTorch Hub repository not yet downloaded")
249
+ return False
250
+
251
+ except Exception as e:
252
+ logger.error(f"Error copying DINOv2 files: {e}")
253
+ logger.error(f"Traceback: {traceback.format_exc()}")
254
+ return False
255
+
256
  @spaces.GPU
257
+ def process_video(video_file):
258
+ """Process uploaded video file with enhanced error handling"""
259
+ # Check if initialization was successful
260
+ if initialization_error:
261
+ return f"Application initialization failed: {initialization_error}\n\nPlease check the logs for more details."
262
+
263
+ if processor is None:
264
+ return "Error: Model not initialized properly. Please check the logs."
265
+
266
+ if video_file is None:
267
+ return "Please upload a video file."
268
+
269
+ logger.info(f"=== Starting video processing ===")
270
+ logger.info(f"Video file input: {video_file}")
271
+ logger.info(f"Video file type: {type(video_file)}")
272
+
273
+ try:
274
+ # Create temp directory with proper permissions
275
+ temp_dir = config['temp_dir']
276
+ os.makedirs(temp_dir, exist_ok=True, mode=0o777)
277
+ logger.info(f"Temp directory: {temp_dir}")
278
+
279
+ # Generate unique filename to avoid conflicts
280
+ import time
281
+ timestamp = str(int(time.time() * 1000))
282
+ file_extension = '.mp4' # Default extension
283
+
284
+ # Try to get original extension if available
285
+ try:
286
+ if hasattr(video_file, 'name') and video_file.name:
287
+ file_extension = os.path.splitext(video_file.name)[1] or '.mp4'
288
+ elif isinstance(video_file, str):
289
+ file_extension = os.path.splitext(video_file)[1] or '.mp4'
290
+ except:
291
+ pass
292
+
293
+ temp_video_path = os.path.join(temp_dir, f"video_{timestamp}{file_extension}")
294
+ logger.info(f"Target temp video path: {temp_video_path}")
295
+
296
+ # Handle Gradio file upload - video_file is typically a string path to temp file
297
+ logger.info(f"Processing video file: {video_file} (type: {type(video_file)})")
298
+
299
+ if isinstance(video_file, str):
300
+ # Gradio provides a file path string
301
+ source_path = video_file
302
+
303
+ # Handle both absolute and relative paths
304
+ if not os.path.isabs(source_path):
305
+ # Try current working directory first
306
+ abs_source_path = os.path.abspath(source_path)
307
+ logger.info(f"Converting relative path {source_path} to absolute: {abs_source_path}")
308
+ if os.path.exists(abs_source_path):
309
+ source_path = abs_source_path
310
+ else:
311
+ # Try looking in common Gradio temp directories
312
+ possible_paths = [
313
+ source_path,
314
+ os.path.join('/tmp', os.path.basename(source_path)),
315
+ os.path.join('/tmp/gradio', os.path.basename(source_path)),
316
+ abs_source_path
317
+ ]
318
+
319
+ found_path = None
320
+ for path in possible_paths:
321
+ logger.info(f"Checking path: {path}")
322
+ if os.path.exists(path):
323
+ found_path = path
324
+ logger.info(f"Found file at: {path}")
325
+ break
326
+
327
+ if found_path:
328
+ source_path = found_path
329
+ else:
330
+ logger.error(f"Could not find source file in any expected location")
331
+ logger.error(f"Tried paths: {possible_paths}")
332
+ raise FileNotFoundError(f"Source video file not found in any expected location: {video_file}")
333
+
334
+ logger.info(f"Final source file path: {source_path}")
335
+ logger.info(f"Source file exists: {os.path.exists(source_path)}")
336
+
337
+ if os.path.exists(source_path):
338
+ try:
339
+ # Check source file permissions and size
340
+ stat_info = os.stat(source_path)
341
+ logger.info(f"Source file size: {stat_info.st_size} bytes, mode: {oct(stat_info.st_mode)}")
342
+
343
+ # Try to read the file content
344
+ with open(source_path, 'rb') as src:
345
+ content = src.read()
346
+ logger.info(f"Successfully read {len(content)} bytes from source")
347
+
348
+ # Write to destination (with a different name to avoid conflicts)
349
+ final_temp_path = os.path.join(temp_dir, f"processed_{timestamp}{file_extension}")
350
+ with open(final_temp_path, 'wb') as dst:
351
+ dst.write(content)
352
+ logger.info(f"Successfully wrote to destination: {final_temp_path}")
353
+
354
+ # Update temp_video_path to the final location
355
+ temp_video_path = final_temp_path
356
+
357
+ except PermissionError as e:
358
+ logger.error(f"Permission error reading source file: {e}")
359
+ # Try alternative approach - use a completely different temp location
360
+ try:
361
+ import tempfile
362
+ # Create a new temporary file in system temp directory
363
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp:
364
+ alternative_temp_path = tmp.name
365
+
366
+ logger.info(f"Trying alternative temp path: {alternative_temp_path}")
367
+
368
+ # Try to copy using system copy command as fallback
369
+ import subprocess
370
+ result = subprocess.run(['cp', source_path, alternative_temp_path],
371
+ capture_output=True, text=True)
372
+
373
+ if result.returncode == 0:
374
+ logger.info("Successfully copied using system cp command")
375
+ temp_video_path = alternative_temp_path
376
+ else:
377
+ logger.error(f"System cp failed: {result.stderr}")
378
+ raise PermissionError(f"Cannot read video file due to permission restrictions: {e}")
379
+
380
+ except Exception as e2:
381
+ logger.error(f"Alternative copy method also failed: {e2}")
382
+ raise PermissionError(f"Cannot read video file due to permission restrictions: {e}")
383
+ else:
384
+ raise FileNotFoundError(f"Source video file not found: {source_path}")
385
+
386
+ elif hasattr(video_file, 'read'):
387
+ # If it's a file-like object with read method
388
+ try:
389
+ content = video_file.read()
390
+ with open(temp_video_path, 'wb') as f:
391
+ f.write(content)
392
+ logger.info(f"Saved video from file object: {temp_video_path} ({len(content)} bytes)")
393
+ except Exception as e:
394
+ logger.error(f"Error reading from file object: {e}")
395
+ raise ValueError(f"Cannot read from file object: {e}")
396
+ else:
397
+ # Handle other cases - try to extract file path or content
398
+ logger.info(f"Attempting to handle unknown file type: {type(video_file)}")
399
+ try:
400
+ # Check if it has a name attribute (common for file objects)
401
+ if hasattr(video_file, 'name'):
402
+ source_path = video_file.name
403
+ logger.info(f"Found name attribute: {source_path}")
404
+
405
+ if os.path.exists(source_path):
406
+ with open(source_path, 'rb') as src:
407
+ content = src.read()
408
+ with open(temp_video_path, 'wb') as dst:
409
+ dst.write(content)
410
+ logger.info(f"Successfully copied from name attribute")
411
+ else:
412
+ raise FileNotFoundError(f"File from name attribute not found: {source_path}")
413
+ else:
414
+ logger.error(f"Unsupported video file type: {type(video_file)}")
415
+ raise ValueError(f"Unsupported video file type: {type(video_file)}")
416
+ except Exception as e:
417
+ logger.error(f"Failed to handle unknown file type: {e}")
418
+ raise ValueError(f"Cannot process video file: {e}")
419
+
420
+ # Set proper permissions on the saved file
421
+ os.chmod(temp_video_path, 0o666)
422
+
423
+ # Verify file exists and has content
424
+ if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0:
425
+ raise ValueError("Video file is empty or could not be saved")
426
+
427
+ # Copy DINOv2 files if needed before processing
428
+ # This needs to happen right after PyTorch Hub downloads but before model loading
429
+ logger.info("Ensuring DINOv2 files are ready for processing...")
430
+ copy_dinov2_files_if_needed()
431
+
432
+ # Set up a monitoring patch for torch.hub.load to replace files immediately after download
433
+ original_torch_hub_load = None
434
+ try:
435
+ import torch.hub
436
+ original_torch_hub_load = torch.hub.load
437
+
438
+ def patched_torch_hub_load(*args, **kwargs):
439
+ logger.info(f"PyTorch Hub load called with: {args[0] if args else 'unknown'}")
440
+
441
+ # Call the original function first
442
+ result = original_torch_hub_load(*args, **kwargs)
443
+
444
+ # If this was a DINOv2 call, immediately replace the files
445
+ if args and 'dinov2' in str(args[0]):
446
+ logger.info("DINOv2 downloaded! Immediately replacing with robust versions...")
447
+
448
+ # Try multiple times to ensure files are replaced
449
+ import time
450
+ for attempt in range(5):
451
+ if copy_dinov2_files_if_needed():
452
+ logger.info("Successfully replaced DINOv2 files!")
453
+ break
454
+ else:
455
+ logger.info(f"Attempt {attempt + 1} failed, retrying in 1 second...")
456
+ time.sleep(1)
457
+
458
+ return result
459
+
460
+ # Temporarily patch torch.hub.load
461
+ torch.hub.load = patched_torch_hub_load
462
+ logger.info("Patched torch.hub.load to replace DINOv2 files after download")
463
+ except Exception as e:
464
+ logger.warning(f"Could not patch torch.hub.load: {e}")
465
+
466
+ logger.info(f"Processing video: {temp_video_path}")
467
+ try:
468
+ output_text = processor.process_video(temp_video_path)
469
+ finally:
470
+ # Restore original function
471
+ if original_torch_hub_load:
472
+ try:
473
+ import torch.hub
474
+ torch.hub.load = original_torch_hub_load
475
+ logger.info("Restored original torch.hub.load")
476
+ except:
477
+ pass
478
+
479
+ logger.info(f"Video processed successfully. Output: {output_text[:100]}...")
480
+
481
+ # Clean up temp file
482
+ if os.path.exists(temp_video_path):
483
+ os.remove(temp_video_path)
484
+ logger.info("Temporary video file cleaned up")
485
+
486
+ return output_text
487
+
488
+ except Exception as e:
489
+ logger.error(f"Error processing video: {str(e)}")
490
+ logger.error(f"Traceback: {traceback.format_exc()}")
491
+ return f"Error processing video: {str(e)}\n\nPlease check that your video is a valid ASL video under 10 seconds."
492
+
493
+ # Create Gradio interface
494
+ def create_interface():
495
+ """Create the Gradio interface"""
496
+ description = """
497
+ Upload an ASL* video to get an English translation. *Sign languages belonging to the same sign language family as ASL (e.g. Ghanaian Sign Language, as well as others listed in Table 7, Row 1 of https://aclanthology.org/2023.findings-emnlp.664.pdf) might also have non-trivial performance, although the model is trained only on ASL data.
498
+
499
+
500
+ This app uses TTIC's foundation model SHuBERT (introduced in an ACL 2025 paper, see http://shubert.pals.ttic.edu).
501
+
502
+ **Requirements:**
503
+ - We recommend that videos be under 60 seconds. Performance for longer videos has not been tested.
504
+ - The signer should be the main part of the video. Videos recorded from a phone camera, tablet, or personal computer should work well. Studio recordings where the signer is farther from the camera may not work as well.
505
+ - Supported formats: MP4, MOV
506
+
507
+ **Note:**
508
+ - Videos will be deleted after the output is generated.
509
+ - Inquiries or feedback? Please email us at [email protected]
510
+ """
511
+
512
+ if initialization_error:
513
+ description += f"\n\n:warning: **Initialization Error:** {initialization_error}"
514
+
515
+ return gr.Interface(
516
+ fn=process_video,
517
+ inputs=gr.Video(label="ASL Video (under 60 seconds)", format="mp4"),
518
+ # inputs=gr.File(
519
+ # label="Upload ASL Video (under 60 seconds)",
520
+ # file_types=[".mp4", ".avi", ".mov", ".webm"],
521
+ # type="filepath" # This tells Gradio to provide the file path directly
522
+ # ),
523
+ outputs=gr.Textbox(label="English Translation", lines=5),
524
+ title="ASL Video to English Text Translation",
525
+ description=description,
526
+ article="",
527
+ examples=[],
528
+ allow_flagging="never"
529
+ )
530
+
531
+ # Create the demo
532
+ demo = create_interface()
533
 
534
+ if __name__ == "__main__":
535
+ # Launch with better configuration for Hugging Face Spaces
536
+ logger.info("Launching Gradio interface...")
537
+ demo.launch(
538
+ server_name="0.0.0.0",
539
+ server_port=7860,
540
+ share=False,
541
+ show_error=True
542
+ )
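Stripped of the logging and retry loop, the torch.hub.load patch that process_video() installs reduces to the wrap-and-restore pattern sketched below. This is a sketch under assumptions, not the app's code: patch_hub_load and on_dinov2_loaded are illustrative names, and the callback stands in for copy_dinov2_files_if_needed().

import torch.hub

def patch_hub_load(on_dinov2_loaded):
    """Wrap torch.hub.load so a callback runs right after a DINOv2 hub load completes."""
    original_load = torch.hub.load

    def wrapped(*args, **kwargs):
        result = original_load(*args, **kwargs)   # normal hub download and model build
        if args and "dinov2" in str(args[0]):     # repo id, e.g. "facebookresearch/dinov2"
            on_dinov2_loaded()                    # stand-in for copy_dinov2_files_if_needed()
        return result

    torch.hub.load = wrapped
    return original_load                          # caller restores this in a finally block

# Usage mirrors process_video(): patch, run the pipeline, then always restore.
# original = patch_hub_load(copy_dinov2_files_if_needed)
# try:
#     ...  # processor.process_video() triggers torch.hub.load internally
# finally:
#     torch.hub.load = original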
attention.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
+ try:
37
+ from typing import Optional
38
+ from typing import Union
39
+ FloatOrNone = Union[float, None]
40
+ except ImportError:
41
+ FloatOrNone = float | None
42
+
43
+
44
+ class Attention(nn.Module):
45
+ def __init__(
46
+ self,
47
+ dim: int,
48
+ num_heads: int = 8,
49
+ qkv_bias: bool = False,
50
+ proj_bias: bool = True,
51
+ attn_drop: float = 0.0,
52
+ proj_drop: float = 0.0,
53
+ ) -> None:
54
+ super().__init__()
55
+ self.dim = dim
56
+ self.num_heads = num_heads
57
+ head_dim = dim // num_heads
58
+ self.scale = head_dim**-0.5
59
+
60
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
61
+ self.attn_drop = attn_drop
62
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
63
+ self.proj_drop = nn.Dropout(proj_drop)
64
+
65
+ def init_weights(
66
+ self, init_attn_std: FloatOrNone = None, init_proj_std: FloatOrNone = None, factor: float = 1.0
67
+ ) -> None:
68
+ init_attn_std = init_attn_std or (self.dim**-0.5)
69
+ init_proj_std = init_proj_std or init_attn_std * factor
70
+ nn.init.normal_(self.qkv.weight, std=init_attn_std)
71
+ nn.init.normal_(self.proj.weight, std=init_proj_std)
72
+ if self.qkv.bias is not None:
73
+ nn.init.zeros_(self.qkv.bias)
74
+ if self.proj.bias is not None:
75
+ nn.init.zeros_(self.proj.bias)
76
+
77
+ def forward(self, x: Tensor, is_causal: bool = False) -> Tensor:
78
+ B, N, C = x.shape
79
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
80
+ q, k, v = torch.unbind(qkv, 2)
81
+ q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
82
+ x = nn.functional.scaled_dot_product_attention(
83
+ q, k, v, attn_mask=None, dropout_p=self.attn_drop if self.training else 0, is_causal=is_causal
84
+ )
85
+ x = x.transpose(1, 2).contiguous().view(B, N, C)
86
+ x = self.proj_drop(self.proj(x))
87
+ return x
88
+
89
+
90
+ class MemEffAttention(Attention):
91
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
92
+ if not XFORMERS_AVAILABLE:
93
+ if attn_bias is not None:
94
+ raise AssertionError("xFormers is required for using nested tensors")
95
+ return super().forward(x)
96
+
97
+ B, N, C = x.shape
98
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
99
+
100
+ q, k, v = unbind(qkv, 2)
101
+
102
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
103
+ x = x.reshape([B, N, C])
104
+
105
+ x = self.proj(x)
106
+ x = self.proj_drop(x)
107
+ return x
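As a quick sanity check of the Attention module above (a minimal sketch, assuming PyTorch 2.x so nn.functional.scaled_dot_product_attention exists, and that attention.py is importable as a plain module):

import torch
from attention import Attention    # assumes attention.py above is on the import path

attn = Attention(dim=384, num_heads=6, qkv_bias=True)
x = torch.randn(2, 257, 384)       # (batch, tokens, dim), e.g. ViT-S/14: 256 patches + CLS token
out = attn(x)                      # forward always routes through scaled_dot_product_attention
assert out.shape == x.shape        # self-attention preserves the (B, N, C) shape

MemEffAttention only swaps in the xFormers memory_efficient_attention kernel and falls back to this same path when xFormers is not installed.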
block.py ADDED
@@ -0,0 +1,322 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict, Optional
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+ try:
24
+ from typing import Optional
25
+ from typing import Union
26
+ FloatOrNone = Union[float, None]
27
+ except ImportError:
28
+ FloatOrNone = float | None
29
+
30
+ logger = logging.getLogger("dinov2")
31
+
32
+
33
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
34
+ try:
35
+ if XFORMERS_ENABLED:
36
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
37
+
38
+ XFORMERS_AVAILABLE = True
39
+ warnings.warn("xFormers is available (Block)")
40
+ else:
41
+ warnings.warn("xFormers is disabled (Block)")
42
+ raise ImportError
43
+ except ImportError:
44
+ XFORMERS_AVAILABLE = False
45
+
46
+ warnings.warn("xFormers is not available (Block)")
47
+
48
+
49
+ class Block(nn.Module):
50
+ def __init__(
51
+ self,
52
+ dim: int,
53
+ num_heads: int,
54
+ mlp_ratio: float = 4.0,
55
+ qkv_bias: bool = False,
56
+ proj_bias: bool = True,
57
+ ffn_bias: bool = True,
58
+ drop: float = 0.0,
59
+ attn_drop: float = 0.0,
60
+ init_values=None,
61
+ drop_path: float = 0.0,
62
+ act_layer: Callable[..., nn.Module] = nn.GELU,
63
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
64
+ attn_class: Callable[..., nn.Module] = Attention,
65
+ ffn_layer: Callable[..., nn.Module] = Mlp,
66
+ ) -> None:
67
+ super().__init__()
68
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
69
+ self.norm1 = norm_layer(dim)
70
+ self.attn = attn_class(
71
+ dim,
72
+ num_heads=num_heads,
73
+ qkv_bias=qkv_bias,
74
+ proj_bias=proj_bias,
75
+ attn_drop=attn_drop,
76
+ proj_drop=drop,
77
+ )
78
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
79
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
80
+
81
+ self.norm2 = norm_layer(dim)
82
+ mlp_hidden_dim = int(dim * mlp_ratio)
83
+ self.mlp = ffn_layer(
84
+ in_features=dim,
85
+ hidden_features=mlp_hidden_dim,
86
+ act_layer=act_layer,
87
+ drop=drop,
88
+ bias=ffn_bias,
89
+ )
90
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
91
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
92
+
93
+ self.sample_drop_ratio = drop_path
94
+
95
+ def forward(self, x: Tensor) -> Tensor:
96
+ def attn_residual_func(x: Tensor) -> Tensor:
97
+ return self.ls1(self.attn(self.norm1(x)))
98
+
99
+ def ffn_residual_func(x: Tensor) -> Tensor:
100
+ return self.ls2(self.mlp(self.norm2(x)))
101
+
102
+ if self.training and self.sample_drop_ratio > 0.1:
103
+ # the overhead is compensated only for a drop path rate larger than 0.1
104
+ x = drop_add_residual_stochastic_depth(
105
+ x,
106
+ residual_func=attn_residual_func,
107
+ sample_drop_ratio=self.sample_drop_ratio,
108
+ )
109
+ x = drop_add_residual_stochastic_depth(
110
+ x,
111
+ residual_func=ffn_residual_func,
112
+ sample_drop_ratio=self.sample_drop_ratio,
113
+ )
114
+ elif self.training and self.sample_drop_ratio > 0.0:
115
+ x = x + self.drop_path1(attn_residual_func(x))
116
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
117
+ else:
118
+ x = x + attn_residual_func(x)
119
+ x = x + ffn_residual_func(x)
120
+ return x
121
+
122
+
123
+ class CausalAttentionBlock(nn.Module):
124
+ def __init__(
125
+ self,
126
+ dim: int,
127
+ num_heads: int,
128
+ ffn_ratio: float = 4.0,
129
+ ls_init_value: Optional[float] = None,
130
+ is_causal: bool = True,
131
+ act_layer: Callable = nn.GELU,
132
+ norm_layer: Callable = nn.LayerNorm,
133
+ dropout_prob: float = 0.0,
134
+ ):
135
+ super().__init__()
136
+
137
+ self.dim = dim
138
+ self.is_causal = is_causal
139
+ self.ls1 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity()
140
+ self.attention_norm = norm_layer(dim)
141
+ self.attention = Attention(dim, num_heads, attn_drop=dropout_prob, proj_drop=dropout_prob)
142
+
143
+ self.ffn_norm = norm_layer(dim)
144
+ ffn_hidden_dim = int(dim * ffn_ratio)
145
+ self.feed_forward = Mlp(
146
+ in_features=dim,
147
+ hidden_features=ffn_hidden_dim,
148
+ drop=dropout_prob,
149
+ act_layer=act_layer,
150
+ )
151
+
152
+ self.ls2 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity()
153
+
154
+ def init_weights(
155
+ self,
156
+ init_attn_std: FloatOrNone = None,
157
+ init_proj_std: FloatOrNone = None,
158
+ init_fc_std: FloatOrNone = None,
159
+ factor: float = 1.0,
160
+ ) -> None:
161
+ init_attn_std = init_attn_std or (self.dim**-0.5)
162
+ init_proj_std = init_proj_std or init_attn_std * factor
163
+ init_fc_std = init_fc_std or (2 * self.dim) ** -0.5
164
+ self.attention.init_weights(init_attn_std, init_proj_std)
165
+ self.attention_norm.reset_parameters()
166
+ nn.init.normal_(self.feed_forward.fc1.weight, std=init_fc_std)
167
+ nn.init.normal_(self.feed_forward.fc2.weight, std=init_proj_std)
168
+ self.ffn_norm.reset_parameters()
169
+
170
+ def forward(
171
+ self,
172
+ x: torch.Tensor,
173
+ ):
174
+ x_attn = x + self.ls1(self.attention(self.attention_norm(x), self.is_causal))
175
+ x_ffn = x_attn + self.ls2(self.feed_forward(self.ffn_norm(x_attn)))
176
+ return x_ffn
177
+
178
+
179
+ def drop_add_residual_stochastic_depth(
180
+ x: Tensor,
181
+ residual_func: Callable[[Tensor], Tensor],
182
+ sample_drop_ratio: float = 0.0,
183
+ ) -> Tensor:
184
+ # 1) extract subset using permutation
185
+ b, n, d = x.shape
186
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
187
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
188
+ x_subset = x[brange]
189
+
190
+ # 2) apply residual_func to get residual
191
+ residual = residual_func(x_subset)
192
+
193
+ x_flat = x.flatten(1)
194
+ residual = residual.flatten(1)
195
+
196
+ residual_scale_factor = b / sample_subset_size
197
+
198
+ # 3) add the residual
199
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
200
+ return x_plus_residual.view_as(x)
201
+
202
+
203
+ def get_branges_scales(x, sample_drop_ratio=0.0):
204
+ b, n, d = x.shape
205
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
206
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
207
+ residual_scale_factor = b / sample_subset_size
208
+ return brange, residual_scale_factor
209
+
210
+
211
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
212
+ if scaling_vector is None:
213
+ x_flat = x.flatten(1)
214
+ residual = residual.flatten(1)
215
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
216
+ else:
217
+ x_plus_residual = scaled_index_add(
218
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
219
+ )
220
+ return x_plus_residual
221
+
222
+
223
+ attn_bias_cache: Dict[Tuple, Any] = {}
224
+
225
+
226
+ def get_attn_bias_and_cat(x_list, branges=None):
227
+ """
228
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
229
+ """
230
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
231
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
232
+ if all_shapes not in attn_bias_cache.keys():
233
+ seqlens = []
234
+ for b, x in zip(batch_sizes, x_list):
235
+ for _ in range(b):
236
+ seqlens.append(x.shape[1])
237
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
238
+ attn_bias._batch_sizes = batch_sizes
239
+ attn_bias_cache[all_shapes] = attn_bias
240
+
241
+ if branges is not None:
242
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
243
+ else:
244
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
245
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
246
+
247
+ return attn_bias_cache[all_shapes], cat_tensors
248
+
249
+
250
+ def drop_add_residual_stochastic_depth_list(
251
+ x_list: List[Tensor],
252
+ residual_func: Callable[[Tensor, Any], Tensor],
253
+ sample_drop_ratio: float = 0.0,
254
+ scaling_vector=None,
255
+ ) -> Tensor:
256
+ # 1) generate random set of indices for dropping samples in the batch
257
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
258
+ branges = [s[0] for s in branges_scales]
259
+ residual_scale_factors = [s[1] for s in branges_scales]
260
+
261
+ # 2) get attention bias and index+concat the tensors
262
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
263
+
264
+ # 3) apply residual_func to get residual, and split the result
265
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
266
+
267
+ outputs = []
268
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
269
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
270
+ return outputs
271
+
272
+
273
+ class NestedTensorBlock(Block):
274
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
275
+ """
276
+ x_list contains a list of tensors to nest together and run
277
+ """
278
+ assert isinstance(self.attn, MemEffAttention)
279
+
280
+ if self.training and self.sample_drop_ratio > 0.0:
281
+
282
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
283
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
284
+
285
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
286
+ return self.mlp(self.norm2(x))
287
+
288
+ x_list = drop_add_residual_stochastic_depth_list(
289
+ x_list,
290
+ residual_func=attn_residual_func,
291
+ sample_drop_ratio=self.sample_drop_ratio,
292
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
293
+ )
294
+ x_list = drop_add_residual_stochastic_depth_list(
295
+ x_list,
296
+ residual_func=ffn_residual_func,
297
+ sample_drop_ratio=self.sample_drop_ratio,
298
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
299
+ )
300
+ return x_list
301
+ else:
302
+
303
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
304
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
305
+
306
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
307
+ return self.ls2(self.mlp(self.norm2(x)))
308
+
309
+ attn_bias, x = get_attn_bias_and_cat(x_list)
310
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
311
+ x = x + ffn_residual_func(x)
312
+ return attn_bias.split(x)
313
+
314
+ def forward(self, x_or_x_list):
315
+ if isinstance(x_or_x_list, Tensor):
316
+ return super().forward(x_or_x_list)
317
+ elif isinstance(x_or_x_list, list):
318
+ if not XFORMERS_AVAILABLE:
319
+ raise AssertionError("xFormers is required for using nested tensors")
320
+ return self.forward_nested(x_or_x_list)
321
+ else:
322
+ raise AssertionError
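For orientation, Block above is a standard pre-norm transformer block: an attention residual followed by an MLP residual, each optionally scaled by LayerScale and thinned by DropPath, with the batch-level stochastic-depth path only taken during training when drop_path exceeds 0.1. A minimal forward-pass sketch, assuming block.py sits inside the dinov2.layers package (as in the hub repo layout the app copies into) so its relative imports resolve:

import torch
from dinov2.layers.block import Block   # assumed package location, matching the hub repo layout

blk = Block(dim=384, num_heads=6, mlp_ratio=4.0, init_values=1e-5, drop_path=0.0)
blk.eval()                               # inference path: plain residual adds, no sample dropping
x = torch.randn(2, 257, 384)
with torch.no_grad():
    y = blk(x)                           # x + ls1(attn(norm1(x))), then + ls2(mlp(norm2(x)))
assert y.shape == x.shape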
body_features.py ADDED
@@ -0,0 +1,358 @@
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import pickle
5
+ import gzip
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ import decord
9
+ import argparse
10
+ import json
11
+ import glob
12
+ import time
13
+ from typing import Dict, List, Optional, Tuple, Union, Any
14
+
15
+
16
+ class PoseProcessor:
17
+ """
18
+ A class for processing pose landmarks and converting them to normalized numpy arrays.
19
+ """
20
+
21
+ def __init__(self, pose_indices: Optional[List[int]] = None,
22
+ normalize_keypoints: bool = True, fill_missing_value: float = -9999.0):
23
+ """
24
+ Initialize the PoseProcessor.
25
+
26
+ Args:
27
+ pose_indices: List of pose landmark indices to extract.
28
+ Default is [0,11,12,13,14,15,16] (nose, shoulders, elbows, wrists)
29
+ normalize_keypoints: Whether to normalize keypoints to signing space
30
+ fill_missing_value: Value to use for missing keypoints
31
+ """
32
+ self.pose_indices = pose_indices if pose_indices else [0, 11, 12, 13, 14, 15, 16]
33
+ self.normalize_keypoints = normalize_keypoints
34
+ self.fill_missing_value = fill_missing_value
35
+
36
+ # Number of coordinates per keypoint (x, y)
37
+ self.coords_per_keypoint = 2
38
+ self.output_shape = (len(self.pose_indices), self.coords_per_keypoint)
39
+
40
+ def normalize_pose_keypoints(self, pose_landmarks: List[List[float]]) -> List[List[float]]:
41
+ """
42
+ Normalize pose keypoints to signing space.
43
+
44
+ Args:
45
+ pose_landmarks: List of pose landmarks from MediaPipe
46
+
47
+ Returns:
48
+ List of normalized pose keypoints
49
+ """
50
+ # Extract relevant landmarks for normalization
51
+ left_shoulder = np.array(pose_landmarks[11][:2])
52
+ right_shoulder = np.array(pose_landmarks[12][:2])
53
+ left_eye = np.array(pose_landmarks[2][:2])
54
+ nose = np.array(pose_landmarks[0][:2])
55
+
56
+ # Calculate head unit in normalized space
57
+ head_unit = np.linalg.norm(right_shoulder - left_shoulder) / 2
58
+
59
+ # Define signing space dimensions in normalized space
60
+ signing_space_width = 6 * head_unit
61
+ signing_space_height = 7 * head_unit
62
+
63
+ # Calculate signing space bounding box in normalized space
64
+ signing_space_top = left_eye[1] - 0.5 * head_unit
65
+ signing_space_bottom = signing_space_top + signing_space_height
66
+ signing_space_left = nose[0] - signing_space_width / 2
67
+ signing_space_right = signing_space_left + signing_space_width
68
+
69
+ # Create transformation matrix
70
+ translation_matrix = np.array([[1, 0, -signing_space_left],
71
+ [0, 1, -signing_space_top],
72
+ [0, 0, 1]])
73
+ scale_matrix = np.array([[1 / signing_space_width, 0, 0],
74
+ [0, 1 / signing_space_height, 0],
75
+ [0, 0, 1]])
76
+ shift_matrix = np.array([[1, 0, -0.5],
77
+ [0, 1, -0.5],
78
+ [0, 0, 1]])
79
+ transformation_matrix = shift_matrix @ scale_matrix @ translation_matrix
80
+
81
+ # Apply transformation to pose keypoints
82
+ normalized_keypoints = []
83
+ for landmark in pose_landmarks:
84
+ keypoint = np.array([landmark[0], landmark[1], 1])
85
+ normalized_keypoint = transformation_matrix @ keypoint
86
+ normalized_keypoints.append(normalized_keypoint[:2].tolist())
87
+
88
+ return normalized_keypoints
89
+
90
+ def process_frame_landmarks(self, frame_landmarks: Optional[Dict[str, Any]]) -> np.ndarray:
91
+ """
92
+ Process landmarks for a single frame.
93
+
94
+ Args:
95
+ frame_landmarks: Dictionary containing pose landmarks for one frame
96
+
97
+ Returns:
98
+ Numpy array of processed pose keypoints
99
+ """
100
+ if frame_landmarks is None or frame_landmarks.get('pose_landmarks') is None:
101
+ # Return missing value array
102
+ return np.full(self.output_shape, self.fill_missing_value).flatten()
103
+
104
+ # Get pose landmarks
105
+ pose_landmarks = frame_landmarks['pose_landmarks'][0]
106
+
107
+ # Normalize keypoints if required
108
+ if self.normalize_keypoints:
109
+ # Take first 25 landmarks for normalization (MediaPipe pose has 33 total)
110
+ normalized_landmarks = self.normalize_pose_keypoints(pose_landmarks[:25])
111
+ else:
112
+ normalized_landmarks = pose_landmarks
113
+
114
+ # Extract only the specified indices
115
+ selected_landmarks = [normalized_landmarks[i] for i in self.pose_indices]
116
+
117
+ # Convert to numpy array and flatten
118
+ frame_keypoints = np.array(selected_landmarks).flatten()
119
+
120
+ return frame_keypoints
121
+
122
+ def process_landmarks_sequence(self, landmarks_data: Dict[int, Any]) -> np.ndarray:
123
+ """
124
+ Process landmarks for an entire sequence (video).
125
+
126
+ Args:
127
+ landmarks_data: Dictionary containing landmarks for each frame
128
+
129
+ Returns:
130
+ Numpy array of shape (num_frames, num_keypoints * 2)
131
+ """
132
+ # Get number of frames
133
+ if not landmarks_data:
134
+ return np.array([])
135
+
136
+ max_frame = max(landmarks_data.keys())
137
+ num_frames = max_frame + 1
138
+
139
+ video_pose_landmarks = []
140
+ prev_pose = None
141
+
142
+ for i in range(num_frames):
143
+ frame_landmarks = landmarks_data.get(i, None)
144
+
145
+ if frame_landmarks is None:
146
+ # Use previous pose if available, otherwise use missing values
147
+ if prev_pose is not None:
148
+ frame_keypoints = prev_pose
149
+ else:
150
+ frame_keypoints = np.full(self.output_shape, self.fill_missing_value).flatten()
151
+ else:
152
+ # Process current frame
153
+ frame_keypoints = self.process_frame_landmarks(frame_landmarks)
154
+ if not np.all(frame_keypoints == self.fill_missing_value):
155
+ prev_pose = frame_keypoints
156
+
157
+ video_pose_landmarks.append(frame_keypoints)
158
+
159
+ # Convert to numpy array
160
+ video_pose_landmarks = np.array(video_pose_landmarks)
161
+
162
+ # Apply any post-processing (like the original code's wrist masking)
163
+ # video_pose_landmarks = self._apply_post_processing(video_pose_landmarks)
164
+
165
+ return video_pose_landmarks
166
+
167
+ def _apply_post_processing(self, pose_array: np.ndarray) -> np.ndarray:
168
+ """
169
+ Apply post-processing to the pose array.
170
+
171
+ Args:
172
+ pose_array: Input pose array
173
+
174
+ Returns:
175
+ Post-processed pose array
176
+ """
177
+ # The original code fills left and right wrist with -9999
178
+ # This corresponds to indices 15 and 16 in the original pose landmarks
179
+ # In our selected indices [0,11,12,13,14,15,16], wrists are at positions 5 and 6
180
+ # Each keypoint has 2 coordinates, so wrists are at positions 10-11 and 12-13
181
+
182
+ # if len(self.pose_indices) >= 7 and 15 in self.pose_indices and 16 in self.pose_indices:
183
+ # # Find positions of wrists in our selected indices
184
+ # left_wrist_idx = self.pose_indices.index(15) * 2 # *2 because each keypoint has x,y
185
+ # right_wrist_idx = self.pose_indices.index(16) * 2
186
+
187
+ # # Fill wrist coordinates with missing value
188
+ # pose_array[:, left_wrist_idx:left_wrist_idx+2] = self.fill_missing_value
189
+ # pose_array[:, right_wrist_idx:right_wrist_idx+2] = self.fill_missing_value
190
+
191
+ return pose_array
192
+
193
+ def process_landmarks_from_file(self, pose_file_path: str) -> np.ndarray:
194
+ """
195
+ Process landmarks from a JSON file.
196
+
197
+ Args:
198
+ pose_file_path: Path to the pose landmarks JSON file
199
+
200
+ Returns:
201
+ Numpy array of processed pose keypoints
202
+ """
203
+ try:
204
+ with open(pose_file_path, 'r') as f:
205
+ landmarks_data = json.load(f)
206
+
207
+ # Convert string keys to integers
208
+ landmarks_data = {int(k): v for k, v in landmarks_data.items()}
209
+
210
+ return self.process_landmarks_sequence(landmarks_data)
211
+
212
+ except Exception as e:
213
+ print(f"Error processing {pose_file_path}: {e}")
214
+ return np.array([])
215
+
216
+ def process_and_save_landmarks(self, landmarks_data: Dict[int, Any],
217
+ output_path: str, filename: str) -> str:
218
+ """
219
+ Process landmarks and save to file.
220
+
221
+ Args:
222
+ landmarks_data: Dictionary containing landmarks for each frame
223
+ output_path: Directory to save the processed landmarks
224
+ filename: Name for the output file (without extension)
225
+
226
+ Returns:
227
+ Path to the saved file
228
+ """
229
+ # Process landmarks
230
+ processed_landmarks = self.process_landmarks_sequence(landmarks_data)
231
+
232
+ # Create output directory if it doesn't exist
233
+ output_dir = Path(output_path)
234
+ output_dir.mkdir(parents=True, exist_ok=True)
235
+
236
+ # Save to file
237
+ save_path = output_dir / f"{filename}.npy"
238
+ np.save(save_path, processed_landmarks)
239
+
240
+ return str(save_path)
241
+
242
+
243
+ # Convenience functions for backward compatibility
244
+ def process_pose_landmarks(landmarks_data: Dict[int, Any],
245
+ normalize: bool = True,
246
+ pose_indices: Optional[List[int]] = None) -> np.ndarray:
247
+ """
248
+ Convenience function to process pose landmarks.
249
+
250
+ Args:
251
+ landmarks_data: Dictionary containing landmarks for each frame
252
+ normalize: Whether to normalize keypoints to signing space
253
+ pose_indices: List of pose landmark indices to extract
254
+
255
+ Returns:
256
+ Numpy array of processed pose keypoints
257
+ """
258
+ processor = PoseProcessor(pose_indices=pose_indices, normalize_keypoints=normalize)
259
+ return processor.process_landmarks_sequence(landmarks_data)
260
+
261
+
262
+ def keypoints_to_numpy(pose_file: str, pose_emb_path: str):
263
+ """
264
+ Original function for backward compatibility with command-line usage.
265
+ """
266
+ try:
267
+ processor = PoseProcessor()
268
+ processed_landmarks = processor.process_landmarks_from_file(pose_file)
269
+
270
+ if processed_landmarks.size > 0:
271
+ # Save the processed landmarks
272
+ video_name = Path(pose_file).stem
273
+ save_path = Path(pose_emb_path) / f"{video_name}.npy"
274
+ save_path.parent.mkdir(parents=True, exist_ok=True)
275
+ np.save(save_path, processed_landmarks)
276
+
277
+ except Exception as e:
278
+ print(f"Error processing {pose_file}: {e}")
279
+
280
+
281
+ # Utility functions for batch processing
282
+ def get_mp4_files(directory: str) -> List[str]:
283
+ """Get all MP4 files in a directory."""
284
+ if not os.path.exists(directory):
285
+ raise FileNotFoundError(f'Directory not found: {directory}')
286
+
287
+ mp4_files = glob.glob(os.path.join(directory, '*.mp4'))
288
+ return [os.path.abspath(file) for file in mp4_files]
289
+
290
+
291
+ def load_file(filename: str):
292
+ """Load a pickled and gzipped file."""
293
+ with gzip.open(filename, "rb") as f:
294
+ return pickle.load(f)
295
+
296
+
297
+ def is_string_in_file(file_path: str, target_string: str) -> bool:
298
+ """Check if a string exists in a file."""
299
+ try:
300
+ with Path(file_path).open("r") as f:
301
+ for line in f:
302
+ if target_string in line:
303
+ return True
304
+ return False
305
+ except Exception as e:
306
+ print(f"Error: {e}")
307
+ return False
308
+
309
+
310
+ def main():
311
+ """Main function for command-line usage."""
312
+ parser = argparse.ArgumentParser()
313
+ parser.add_argument('--index', type=int, required=True,
314
+ help='index of the sub_list to work with')
315
+ parser.add_argument('--files_list', type=str, required=True,
316
+ help='path to the pose file')
317
+ parser.add_argument('--pose_features_path', type=str, required=True,
318
+ help='path to the pose features file')
319
+ parser.add_argument('--batch_size', type=int, required=True,
320
+ help='batch size')
321
+ parser.add_argument('--time_limit', type=int, required=True,
322
+ help='time limit')
323
+
324
+ args = parser.parse_args()
325
+ start_time = time.time()
326
+
327
+ # Load files list
328
+ fixed_list = load_file(args.files_list)
329
+
330
+ # Initialize processor
331
+ processor = PoseProcessor()
332
+
333
+ # Process files in batches
334
+ video_batches = [fixed_list[i:i + args.batch_size] for i in range(0, len(fixed_list), args.batch_size)]
335
+
336
+ for pose_file in video_batches[args.index]:
337
+ pose_file_path = Path(pose_file)
338
+ output_path = Path(args.pose_features_path) / f"{pose_file_path.stem}.npy"
339
+
340
+ if output_path.exists():
341
+ print(f"Skipping {pose_file} - output already exists")
342
+ continue
343
+
344
+ current_time = time.time()
345
+ if current_time - start_time > args.time_limit:
346
+ print("Time limit reached. Stopping execution.")
347
+ break
348
+
349
+ try:
350
+ print(f"Processing {pose_file}")
351
+ keypoints_to_numpy(pose_file, args.pose_features_path)
352
+ print(f"Successfully processed {pose_file}")
353
+ except Exception as e:
354
+ print(f"Error processing {pose_file}: {e}")
355
+
356
+
357
+ if __name__ == "__main__":
358
+ main()
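To make the output layout concrete: with the default indices, PoseProcessor reduces each frame to 7 upper-body keypoints (nose, shoulders, elbows, wrists) mapped into the signing-space frame, giving a (num_frames, 14) array. A small synthetic example follows (the landmark values are fabricated purely to exercise the shapes, and body_features.py is assumed to be importable):

from body_features import PoseProcessor

processor = PoseProcessor()    # defaults: pose_indices [0, 11, 12, 13, 14, 15, 16], normalization on

# Fabricate 33 MediaPipe-style pose landmarks (x, y, z, visibility) for frame 0;
# frame 1 has no detection, so the processor carries frame 0's keypoints forward.
fake_pose = [[0.5 + 0.01 * i, 0.4 + 0.01 * i, 0.0, 1.0] for i in range(33)]
landmarks = {0: {"pose_landmarks": [fake_pose]}, 1: None}

features = processor.process_landmarks_sequence(landmarks)
print(features.shape)          # (2, 14): 7 keypoints x (x, y) per frame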
crop_face.py ADDED
@@ -0,0 +1,415 @@
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import pickle
5
+ import gzip
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ import decord
9
+ import argparse
10
+ import json
11
+ import time
12
+ from typing import Dict, Optional, Tuple, List, Union, Any
13
+
14
+
15
+ class FaceExtractor:
16
+ """
17
+ A class for extracting face regions from videos based on pose and face landmarks.
18
+ Creates face frames with only eyes and mouth visible on grey background.
19
+ """
20
+
21
+ def __init__(self, output_size: Tuple[int, int] = (224, 224),
22
+ scale_factor: float = 1.2, grey_background_color: int = 128):
23
+ """
24
+ Initialize the FaceExtractor.
25
+
26
+ Args:
27
+ output_size: Size of the output face frames (width, height)
28
+ scale_factor: Scale factor for bounding box expansion
29
+ grey_background_color: Color value for grey background (0-255)
30
+ """
31
+ self.output_size = output_size
32
+ self.scale_factor = scale_factor
33
+ self.grey_background_color = grey_background_color
34
+
35
+ # Face landmark indices for eyes and mouth
36
+ self.left_eye_indices = [69, 168, 156, 118, 54]
37
+ self.right_eye_indices = [168, 299, 347, 336, 301]
38
+ self.mouth_indices = [164, 212, 432, 18]
39
+
40
+ def resize_frame(self, frame: np.ndarray, frame_size: Tuple[int, int]) -> Optional[np.ndarray]:
41
+ """Resize frame to specified size."""
42
+ if frame is not None and frame.size > 0:
43
+ return cv2.resize(frame, frame_size, interpolation=cv2.INTER_AREA)
44
+ else:
45
+ return None
46
+
47
+ def calculate_bounding_box(self, landmarks: List[List[float]], indices: List[int],
48
+ image_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
49
+ """Calculate bounding box for specific landmark indices."""
50
+ x_coordinates = [landmarks[i][0] for i in indices]
51
+ y_coordinates = [landmarks[i][1] for i in indices]
52
+
53
+ left = min(x_coordinates)
54
+ right = max(x_coordinates)
55
+ top = min(y_coordinates)
56
+ bottom = max(y_coordinates)
57
+
58
+ return (int(left * image_shape[1]), int(top * image_shape[0]),
59
+ int(right * image_shape[1]), int(bottom * image_shape[0]))
60
+
61
+ def crop_and_paste(self, src: np.ndarray, dst: np.ndarray,
62
+ src_box: Tuple[int, int, int, int], dst_origin: Tuple[int, int]):
63
+ """Crop region from source and paste to destination."""
64
+ x1, y1, x2, y2 = src_box
65
+ dx, dy = dst_origin
66
+ crop = src[y1:y2, x1:x2]
67
+ crop_height, crop_width = crop.shape[:2]
68
+ dst[dy:dy+crop_height, dx:dx+crop_width] = crop
69
+
70
+ def cues_on_grey_background(self, image: np.ndarray, facial_landmarks: List[List[float]]) -> np.ndarray:
71
+ """
72
+ Create face frame with only eyes and mouth visible on grey background.
73
+
74
+ Args:
75
+ image: Input image as numpy array
76
+ facial_landmarks: Face landmarks from MediaPipe
77
+
78
+ Returns:
79
+ Face frame with eyes and mouth on grey background
80
+ """
81
+ image_shape = image.shape
82
+
83
+ # Calculate bounding boxes for facial features
84
+ left_eye_box = self.calculate_bounding_box(facial_landmarks, self.left_eye_indices, image_shape)
85
+ right_eye_box = self.calculate_bounding_box(facial_landmarks, self.right_eye_indices, image_shape)
86
+ mouth_box = self.calculate_bounding_box(facial_landmarks, self.mouth_indices, image_shape)
87
+
88
+ # Calculate the overall bounding box
89
+ min_x = min(left_eye_box[0], right_eye_box[0], mouth_box[0])
90
+ min_y = min(left_eye_box[1], right_eye_box[1], mouth_box[1])
91
+ max_x = max(left_eye_box[2], right_eye_box[2], mouth_box[2])
92
+ max_y = max(left_eye_box[3], right_eye_box[3], mouth_box[3])
93
+
94
+ # Add padding
95
+ padding = 10
96
+ min_x = max(0, min_x - padding)
97
+ min_y = max(0, min_y - padding)
98
+ max_x = min(image.shape[1], max_x + padding)
99
+ max_y = min(image.shape[0], max_y + padding)
100
+
101
+ # Make the crop a square by adjusting either width or height
102
+ width = max_x - min_x
103
+ height = max_y - min_y
104
+ side_length = max(width, height)
105
+
106
+ # Adjust to ensure square
107
+ if width < side_length:
108
+ extra = side_length - width
109
+ min_x = max(0, min_x - extra // 2)
110
+ max_x = min(image.shape[1], max_x + extra // 2)
111
+
112
+ if height < side_length:
113
+ extra = side_length - height
114
+ min_y = max(0, min_y - extra // 2)
115
+ max_y = min(image.shape[0], max_y + extra // 2)
116
+
117
+ # Create grey background image
118
+ grey_background = np.ones((side_length, side_length, 3), dtype=np.uint8) * self.grey_background_color
119
+
120
+ # Crop and paste facial features onto grey background
121
+ self.crop_and_paste(image, grey_background, left_eye_box, (left_eye_box[0]-min_x, left_eye_box[1]-min_y))
122
+ self.crop_and_paste(image, grey_background, right_eye_box, (right_eye_box[0]-min_x, right_eye_box[1]-min_y))
123
+ self.crop_and_paste(image, grey_background, mouth_box, (mouth_box[0]-min_x, mouth_box[1]-min_y))
124
+
125
+ return grey_background
126
+
127
+ def select_face(self, pose_landmarks: List[List[float]], face_landmarks: List[List[List[float]]]) -> List[List[float]]:
128
+ """
129
+ Select the face that is closest to the pose nose landmark.
130
+
131
+ Args:
132
+ pose_landmarks: Pose landmarks from MediaPipe
133
+ face_landmarks: List of face landmarks from MediaPipe
134
+
135
+ Returns:
136
+ Selected face landmarks
137
+ """
138
+ nose_landmark_from_pose = pose_landmarks[0] # Nose from pose
139
+ nose_landmarks_from_face = [face_landmarks[i][0] for i in range(len(face_landmarks))]
140
+
141
+ # Find closest face based on nose landmark
142
+ distances = [np.linalg.norm(np.array(nose_landmark_from_pose) - np.array(nose_landmark))
143
+ for nose_landmark in nose_landmarks_from_face]
144
+ closest_nose_index = np.argmin(distances)
145
+
146
+ return face_landmarks[closest_nose_index]
147
+
148
+ def extract_face_frames(self, video_input, landmarks_data: Dict[int, Any]) -> List[np.ndarray]:
149
+ """
150
+ Extract face frames from video based on landmarks.
151
+
152
+ Args:
153
+ video_input: Either a path to video file (str) or a decord.VideoReader object
154
+ landmarks_data: Dictionary containing pose and face landmarks for each frame
155
+
156
+ Returns:
157
+ List of face frames as numpy arrays
158
+ """
159
+ # Handle different input types
160
+ if isinstance(video_input, str):
161
+ video_path = Path(video_input)
162
+ if not video_path.exists():
163
+ raise FileNotFoundError(f"Video file not found: {video_input}")
164
+ video = decord.VideoReader(str(video_path))
165
+ # elif hasattr(video_input, '__len__') and hasattr(video_input, '__getitem__'):
166
+ else:
167
+ video = video_input
168
+ # else:
169
+ # raise TypeError("video_input must be either a file path (str) or a VideoReader object")
170
+
171
+ face_frames = []
172
+ prev_face_frame = None
173
+ prev_landmarks = None
174
+
175
+ for i in range(len(video)):
176
+ # frame = video[i].asnumpy()
177
+ frame = video[i]
178
+ if hasattr(video, 'seek'):
179
+ video.seek(0)
180
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
181
+
182
+ # Get landmarks for this frame
183
+ frame_landmarks = landmarks_data.get(i, None)
184
+
185
+ # Handle missing landmarks
186
+ if frame_landmarks is None:
187
+ if prev_landmarks is not None:
188
+ frame_landmarks = prev_landmarks
189
+ else:
190
+ # Use blank frame if no landmarks available
191
+ face_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
192
+ continue
193
+ else:
194
+ prev_landmarks = frame_landmarks
195
+
196
+ # Check if pose landmarks exist
197
+ if frame_landmarks.get('pose_landmarks') is None:
198
+ if prev_face_frame is not None:
199
+ face_frames.append(prev_face_frame)
200
+ else:
201
+ face_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
202
+ continue
203
+
204
+ # Process face if face landmarks exist
205
+ if frame_landmarks.get('face_landmarks') is not None:
206
+ # Select the face closest to the pose
207
+ selected_face = self.select_face(
208
+ frame_landmarks['pose_landmarks'][0],
209
+ frame_landmarks['face_landmarks']
210
+ )
211
+
212
+ # Create face frame with cues on grey background
213
+ face_frame = self.cues_on_grey_background(frame_rgb, selected_face)
214
+ face_frame = self.resize_frame(face_frame, self.output_size)
215
+ face_frames.append(face_frame)
216
+ prev_face_frame = face_frame
217
+
218
+ elif prev_face_frame is not None:
219
+ face_frames.append(prev_face_frame)
220
+ else:
221
+ # Use blank frame if no face landmarks
222
+ face_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
223
+
224
+ return face_frames
225
+
226
+ def extract_and_save_face_video(self, video_input, landmarks_data: Dict[int, Any],
227
+ output_dir: str, video_name: Optional[str] = None) -> str:
228
+ """
229
+ Extract face frames and save as video file.
230
+
231
+ Args:
232
+ video_input: Either a path to video file (str) or a decord.VideoReader object
233
+ landmarks_data: Dictionary containing pose and face landmarks for each frame
234
+ output_dir: Directory to save the face video
235
+ video_name: Name for output video (auto-generated if not provided)
236
+
237
+ Returns:
238
+ Path to the saved face video
239
+ """
240
+ # Handle video input and get FPS
241
+ if isinstance(video_input, str):
242
+ video_path = Path(video_input)
243
+ if not video_path.exists():
244
+ raise FileNotFoundError(f"Video file not found: {video_input}")
245
+ video = decord.VideoReader(str(video_path))
246
+ if video_name is None:
247
+ video_name = video_path.stem
248
+ # elif hasattr(video_input, '__len__') and hasattr(video_input, '__getitem__'):
249
+ else:
250
+ video = video_input
251
+ if video_name is None:
252
+ video_name = "video"
253
+ # else:
254
+ # raise TypeError("video_input must be either a file path (str) or a VideoReader object")
255
+
256
+ fps = video.get_avg_fps() if hasattr(video, 'get_avg_fps') else 30.0
257
+
258
+ # Create output directory
259
+ output_path = Path(output_dir)
260
+ output_path.mkdir(parents=True, exist_ok=True)
261
+
262
+ # Define output path
263
+ face_video_path = output_path / f"{video_name}_face.mp4"
264
+
265
+ # Remove existing file
266
+ if face_video_path.exists():
267
+ face_video_path.unlink()
268
+
269
+ # Create video writer
270
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
271
+ writer = cv2.VideoWriter(str(face_video_path), fourcc, fps, self.output_size)
272
+
273
+ try:
274
+ # Extract face frames
275
+ face_frames = self.extract_face_frames(video, landmarks_data)
276
+
277
+ # Write frames to video file
278
+ for frame in face_frames:
279
+ writer.write(frame)
280
+
281
+ finally:
282
+ # Clean up
283
+ writer.release()
284
+ del writer
285
+
286
+ return str(face_video_path)
287
+
288
+
289
+ # Convenience function for backward compatibility
290
+ def extract_face_frames(video_input, landmarks_data: Dict[int, Any],
291
+ output_size: Tuple[int, int] = (224, 224)) -> List[np.ndarray]:
292
+ """
293
+ Convenience function to extract face frames from video.
294
+
295
+ Args:
296
+ video_input: Either a path to video file (str) or a decord.VideoReader object
297
+ landmarks_data: Dictionary containing pose and face landmarks for each frame
298
+ output_size: Size of the output face frames (width, height)
299
+
300
+ Returns:
301
+ List of face frames as numpy arrays
302
+ """
303
+ extractor = FaceExtractor(output_size=output_size)
304
+ return extractor.extract_face_frames(video_input, landmarks_data)
305
+
306
+
307
+ def video_holistic(video_file: str, face_path: str, problem_file_path: str, pose_path: str):
308
+ """
309
+ Original function for backward compatibility with command-line usage.
310
+ """
311
+ try:
312
+ video = decord.VideoReader(video_file)
313
+ fps = video.get_avg_fps()
314
+
315
+ video_name = Path(video_file).stem
316
+ clip_face_path = Path(face_path) / f"{video_name}_face.mp4"
317
+ landmark_json_path = Path(pose_path) / f"{video_name}_pose.json"
318
+
319
+ # Load landmarks
320
+ with open(landmark_json_path, 'r') as rd:
321
+ landmarks_data = json.load(rd)
322
+
323
+ # Convert string keys to integers
324
+ landmarks_data = {int(k): v for k, v in landmarks_data.items()}
325
+
326
+ # Extract face video
327
+ extractor = FaceExtractor()
328
+ extractor.extract_and_save_face_video(video, landmarks_data, face_path, video_name)
329
+
330
+ except Exception as e:
331
+ print(f"Error processing {video_file}: {e}")
332
+ with open(problem_file_path, "a") as p:
333
+ p.write(video_file + "\n")
334
+
335
+
336
+ # Utility functions for batch processing
337
+ def load_file(filename: str):
338
+ """Load a pickled and gzipped file."""
339
+ with gzip.open(filename, "rb") as f:
340
+ return pickle.load(f)
341
+
342
+
343
+ def is_string_in_file(file_path: str, target_string: str) -> bool:
344
+ """Check if a string exists in a file."""
345
+ try:
346
+ with Path(file_path).open("r") as f:
347
+ for line in f:
348
+ if target_string in line:
349
+ return True
350
+ return False
351
+ except Exception as e:
352
+ print(f"Error: {e}")
353
+ return False
354
+
355
+
356
+ def main():
357
+ """Main function for command-line usage."""
358
+ parser = argparse.ArgumentParser()
359
+ parser.add_argument('--index', type=int, required=True,
360
+ help='index of the sub_list to work with')
361
+ parser.add_argument('--batch_size', type=int, required=True,
362
+ help='batch size')
363
+ parser.add_argument('--time_limit', type=int, required=True,
364
+ help='time limit')
365
+ parser.add_argument('--files_list', type=str, required=True,
366
+ help='files list')
367
+ parser.add_argument('--problem_file_path', type=str, required=True,
368
+ help='problem file path')
369
+ parser.add_argument('--pose_path', type=str, required=True,
370
+ help='pose path')
371
+ parser.add_argument('--face_path', type=str, required=True,
372
+ help='face path')
373
+
374
+ args = parser.parse_args()
375
+ start_time = time.time()
376
+
377
+ # Load files list
378
+ fixed_list = load_file(args.files_list)
379
+
380
+ # Create problem file if it doesn't exist
381
+ if not os.path.exists(args.problem_file_path):
382
+ with open(args.problem_file_path, "w") as f:
383
+ f.write("")
384
+
385
+ # Process videos in batches
386
+ video_batches = [fixed_list[i:i + args.batch_size] for i in range(0, len(fixed_list), args.batch_size)]
387
+
388
+ for video_file in video_batches[args.index]:
389
+ current_time = time.time()
390
+ if current_time - start_time > args.time_limit:
391
+ print("Time limit reached. Stopping execution.")
392
+ break
393
+
394
+ video_name = Path(video_file).stem
395
+ clip_face_path = Path(args.face_path) / f"{video_name}_face.mp4"
396
+
397
+ if clip_face_path.exists():
398
+ print(f"Skipping {video_file} - output already exists")
399
+ continue
400
+ elif is_string_in_file(args.problem_file_path, video_file):
401
+ print(f"Skipping {video_file} - found in problem file")
402
+ continue
403
+ else:
404
+ try:
405
+ print(f"Processing {video_file}")
406
+ video_holistic(video_file, args.face_path, args.problem_file_path, args.pose_path)
407
+ print(f"Successfully processed {video_file}")
408
+ except Exception as e:
409
+ print(f"Error processing {video_file}: {e}")
410
+ with open(args.problem_file_path, "a") as p:
411
+ p.write(video_file + "\n")
412
+
413
+
414
+ if __name__ == "__main__":
415
+ main()
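
A minimal usage sketch for the FaceExtractor class added above. The file names (clip.mp4, clip_pose.json, out_dir) are illustrative; the landmark JSON is assumed to be the per-frame dictionary produced by the pose-extraction step, keyed by frame index.

    import json
    from crop_face import FaceExtractor

    # Per-frame pose/face landmarks, keys converted from strings to frame indices.
    with open("clip_pose.json") as f:
        landmarks = {int(k): v for k, v in json.load(f).items()}

    extractor = FaceExtractor(output_size=(224, 224))
    # One 224x224 eyes-and-mouth-on-grey frame per input frame.
    face_frames = extractor.extract_face_frames("clip.mp4", landmarks)
    # Or write the frames directly to out_dir/clip_face.mp4.
    face_video_path = extractor.extract_and_save_face_video("clip.mp4", landmarks, "out_dir", "clip")
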
crop_hands.py ADDED
@@ -0,0 +1,445 @@
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import pickle
5
+ import gzip
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ import decord
9
+ import argparse
10
+ import json
11
+ import time
12
+ from typing import Dict, Optional, Tuple, List, Union, Any
13
+ import tempfile
14
+
15
+
16
+ class HandExtractor:
17
+ """
18
+ A class for extracting hand regions from videos based on pose landmarks.
19
+ """
20
+
21
+ def __init__(self, output_size: Tuple[int, int] = (224, 224),
22
+ scale_factor: float = 1.5, distance_threshold: float = 0.1):
23
+ """
24
+ Initialize the HandExtractor.
25
+
26
+ Args:
27
+ output_size: Size of the output hand frames (width, height)
28
+ scale_factor: Scale factor for bounding box expansion
29
+ distance_threshold: Distance threshold for hand-pose matching
30
+ """
31
+ self.output_size = output_size
32
+ self.scale_factor = scale_factor
33
+ self.distance_threshold = distance_threshold
34
+
35
+ def resize_frame(self, frame: np.ndarray, frame_size: Tuple[int, int]) -> Optional[np.ndarray]:
36
+ """Resize frame to specified size."""
37
+ if frame is not None and frame.size > 0:
38
+ return cv2.resize(frame, frame_size, interpolation=cv2.INTER_AREA)
39
+ else:
40
+ return None
41
+
42
+ def crop_frame(self, image: np.ndarray, bounding_box: Tuple[int, int, int, int]) -> np.ndarray:
43
+ """Crop frame using bounding box."""
44
+ x, y, w, h = bounding_box
45
+ cropped_frame = image[y:y + h, x:x + w]
46
+ return cropped_frame
47
+
48
+ def get_bounding_box(self, landmarks: List[List[float]], image_shape: Tuple[int, int, int],
49
+ scale_factor: float = 1.2) -> Tuple[int, int, int, int]:
50
+ """Get bounding box from landmarks."""
51
+ ih, iw, _ = image_shape
52
+ landmarks_px = np.array([(int(l[0] * iw), int(l[1] * ih)) for l in landmarks])
53
+ center_x, center_y = np.mean(landmarks_px, axis=0, dtype=int)
54
+ xb, yb, wb, hb = cv2.boundingRect(landmarks_px)
55
+ box_size = max(wb, hb)
56
+ half_size = box_size // 2
57
+ x = center_x - half_size
58
+ y = center_y - half_size
59
+ w = box_size
60
+ h = box_size
61
+
62
+ w_padding = int((scale_factor - 1) * w / 2)
63
+ h_padding = int((scale_factor - 1) * h / 2)
64
+ x -= w_padding
65
+ y -= h_padding
66
+ w += 2 * w_padding
67
+ h += 2 * h_padding
68
+
69
+ return x, y, w, h
70
+
71
+ def adjust_bounding_box(self, bounding_box: Tuple[int, int, int, int],
72
+ image_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
73
+ """Adjust bounding box to fit within image boundaries."""
74
+ x, y, w, h = bounding_box
75
+ ih, iw, _ = image_shape
76
+
77
+ # Adjust x-coordinate if the bounding box extends beyond the image's right edge
78
+ if x + w > iw:
79
+ x = iw - w
80
+
81
+ # Adjust y-coordinate if the bounding box extends beyond the image's bottom edge
82
+ if y + h > ih:
83
+ y = ih - h
84
+
85
+ # Ensure bounding box's x and y coordinates are not negative
86
+ x = max(x, 0)
87
+ y = max(y, 0)
88
+
89
+ return x, y, w, h
90
+
91
+ def select_hands(self, pose_landmarks: List[List[float]], hand_landmarks: Optional[List[List[List[float]]]],
92
+ image_shape: Tuple[int, int, int]) -> Tuple[Optional[List[List[float]]], Optional[List[List[float]]]]:
93
+ """
94
+ Select left and right hands from detected hand landmarks based on pose wrist positions.
95
+
96
+ Args:
97
+ pose_landmarks: Pose landmarks from MediaPipe
98
+ hand_landmarks: Hand landmarks from MediaPipe
99
+ image_shape: Shape of the image (height, width, channels)
100
+
101
+ Returns:
102
+ Tuple of (left_hand_landmarks, right_hand_landmarks)
103
+ """
104
+ if hand_landmarks is None:
105
+ return None, None
106
+
107
+ # Get wrist landmarks from pose (indices 15 and 16 for left and right wrists)
108
+ left_wrist_from_pose = pose_landmarks[15]
109
+ right_wrist_from_pose = pose_landmarks[16]
110
+
111
+ # Get wrist landmarks from hand detections (index 0 is wrist in hand landmarks)
112
+ wrist_from_hand = [hand_landmarks[i][0] for i in range(len(hand_landmarks))]
113
+
114
+ # Match right hand
115
+ right_hand_landmarks = None
116
+ if right_wrist_from_pose is not None:
117
+ minimum_distance = 100
118
+ best_hand_idx = 0
119
+ for i in range(len(hand_landmarks)):
120
+ distance = np.linalg.norm(np.array(right_wrist_from_pose[0:2]) - np.array(wrist_from_hand[i][0:2]))
121
+ if distance < minimum_distance:
122
+ minimum_distance = distance
123
+ best_hand_idx = i
124
+
125
+ if minimum_distance < self.distance_threshold:
126
+ right_hand_landmarks = hand_landmarks[best_hand_idx]
127
+
128
+ # Match left hand
129
+ left_hand_landmarks = None
130
+ if left_wrist_from_pose is not None:
131
+ minimum_distance = 100
132
+ best_hand_idx = 0
133
+ for i in range(len(hand_landmarks)):
134
+ distance = np.linalg.norm(np.array(left_wrist_from_pose[0:2]) - np.array(wrist_from_hand[i][0:2]))
135
+ if distance < minimum_distance:
136
+ minimum_distance = distance
137
+ best_hand_idx = i
138
+
139
+ if minimum_distance < self.distance_threshold:
140
+ left_hand_landmarks = hand_landmarks[best_hand_idx]
141
+
142
+ return left_hand_landmarks, right_hand_landmarks
143
+
144
+ def extract_hand_frames(self, video_input, landmarks_data: Dict[int, Any]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
145
+ """
146
+ Extract hand frames from video based on landmarks.
147
+
148
+ Args:
149
+ video_input: Either a path to video file (str) or a decord.VideoReader object
150
+ landmarks_data: Dictionary containing pose and hand landmarks for each frame
151
+
152
+ Returns:
153
+ Tuple of (left_hand_frames, right_hand_frames) as lists of numpy arrays
154
+ """
155
+ # Handle different input types
156
+ if isinstance(video_input, str):
157
+ video_path = Path(video_input)
158
+ if not video_path.exists():
159
+ raise FileNotFoundError(f"Video file not found: {video_input}")
160
+ video = decord.VideoReader(str(video_path))
161
+ # elif hasattr(video_input, '__len__') and hasattr(video_input, '__getitem__'):
162
+ else:
163
+ video = video_input
164
+ # else:
165
+ # raise TypeError("video_input must be either a file path (str) or a VideoReader object")
166
+
167
+ left_hand_frames = []
168
+ right_hand_frames = []
169
+
170
+ prev_left_frame = None
171
+ prev_right_frame = None
172
+ prev_landmarks = None
173
+
174
+ for i in range(len(video)):
175
+ # frame = video[i].asnumpy()
176
+ frame = video[i]
177
+ if hasattr(video, 'seek'):
178
+ video.seek(0)
179
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
180
+
181
+ # Get landmarks for this frame
182
+ frame_landmarks = landmarks_data.get(i, None)
183
+
184
+ # Handle missing landmarks
185
+ if frame_landmarks is None:
186
+ if prev_landmarks is not None:
187
+ frame_landmarks = prev_landmarks
188
+ else:
189
+ # Use blank frames if no landmarks available
190
+ left_hand_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
191
+ right_hand_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
192
+ continue
193
+ else:
194
+ prev_landmarks = frame_landmarks
195
+
196
+ # Check if pose landmarks exist
197
+ if frame_landmarks.get('pose_landmarks') is None:
198
+ # Use previous frames or blank frames
199
+ if prev_left_frame is not None:
200
+ left_hand_frames.append(prev_left_frame)
201
+ else:
202
+ left_hand_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
203
+
204
+ if prev_right_frame is not None:
205
+ right_hand_frames.append(prev_right_frame)
206
+ else:
207
+ right_hand_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
208
+ continue
209
+
210
+ # Select hands based on pose landmarks
211
+ left_hand_landmarks, right_hand_landmarks = self.select_hands(
212
+ frame_landmarks['pose_landmarks'][0],
213
+ frame_landmarks.get('hand_landmarks'),
214
+ frame_rgb.shape
215
+ )
216
+
217
+ # Process left hand
218
+ if left_hand_landmarks is not None:
219
+ left_box = self.get_bounding_box(left_hand_landmarks, frame_rgb.shape, self.scale_factor)
220
+ left_box = self.adjust_bounding_box(left_box, frame_rgb.shape)
221
+ left_frame = self.crop_frame(frame_rgb, left_box)
222
+ left_frame = self.resize_frame(left_frame, self.output_size)
223
+ left_hand_frames.append(left_frame)
224
+ prev_left_frame = left_frame
225
+ elif prev_left_frame is not None:
226
+ left_hand_frames.append(prev_left_frame)
227
+ else:
228
+ left_hand_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
229
+
230
+ # Process right hand
231
+ if right_hand_landmarks is not None:
232
+ right_box = self.get_bounding_box(right_hand_landmarks, frame_rgb.shape, self.scale_factor)
233
+ right_box = self.adjust_bounding_box(right_box, frame_rgb.shape)
234
+ right_frame = self.crop_frame(frame_rgb, right_box)
235
+ right_frame = self.resize_frame(right_frame, self.output_size)
236
+ right_hand_frames.append(right_frame)
237
+ prev_right_frame = right_frame
238
+ elif prev_right_frame is not None:
239
+ right_hand_frames.append(prev_right_frame)
240
+ else:
241
+ right_hand_frames.append(np.zeros((*self.output_size, 3), dtype=np.uint8))
242
+
243
+ return left_hand_frames, right_hand_frames
244
+
245
+ def extract_and_save_hand_videos(self, video_input, landmarks_data: Dict[int, Any],
246
+ output_dir: str, video_name: Optional[str] = None) -> Tuple[str, str]:
247
+ """
248
+ Extract hand frames and save as video files.
249
+
250
+ Args:
251
+ video_input: Either a path to video file (str) or a decord.VideoReader object
252
+ landmarks_data: Dictionary containing pose and hand landmarks for each frame
253
+ output_dir: Directory to save the hand videos
254
+ video_name: Name for output videos (auto-generated if not provided)
255
+
256
+ Returns:
257
+ Tuple of (left_hand_video_path, right_hand_video_path)
258
+ """
259
+ # Handle video input and get FPS
260
+ if isinstance(video_input, str):
261
+ video_path = Path(video_input)
262
+ if not video_path.exists():
263
+ raise FileNotFoundError(f"Video file not found: {video_input}")
264
+ video = decord.VideoReader(str(video_path))
265
+ if video_name is None:
266
+ video_name = video_path.stem
267
+ # elif hasattr(video_input, '__len__') and hasattr(video_input, '__getitem__'):
268
+ else:
269
+ video = video_input
270
+ if video_name is None:
271
+ video_name = "video"
272
+ # else:
273
+ # raise TypeError("video_input must be either a file path (str) or a VideoReader object")
274
+
275
+ fps = video.get_avg_fps() if hasattr(video, 'get_avg_fps') else 30.0
276
+
277
+ # Create output directory
278
+ output_path = Path(output_dir)
279
+ output_path.mkdir(parents=True, exist_ok=True)
280
+
281
+ # Define output paths
282
+ left_hand_path = output_path / f"{video_name}_hand1.mp4"
283
+ right_hand_path = output_path / f"{video_name}_hand2.mp4"
284
+
285
+ # Remove existing files
286
+ if left_hand_path.exists():
287
+ left_hand_path.unlink()
288
+ if right_hand_path.exists():
289
+ right_hand_path.unlink()
290
+
291
+ # Create video writers
292
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
293
+ left_writer = cv2.VideoWriter(str(left_hand_path), fourcc, fps, self.output_size)
294
+ right_writer = cv2.VideoWriter(str(right_hand_path), fourcc, fps, self.output_size)
295
+
296
+ try:
297
+ # Extract hand frames
298
+ left_frames, right_frames = self.extract_hand_frames(video, landmarks_data)
299
+
300
+ # Write frames to video files
301
+ for left_frame, right_frame in zip(left_frames, right_frames):
302
+ left_writer.write(left_frame)
303
+ right_writer.write(right_frame)
304
+
305
+ finally:
306
+ # Clean up
307
+ left_writer.release()
308
+ right_writer.release()
309
+ del left_writer
310
+ del right_writer
311
+
312
+ return str(left_hand_path), str(right_hand_path)
313
+
314
+
315
+ # Convenience function for backward compatibility
316
+ def extract_hand_frames(video_input, landmarks_data: Dict[int, Any],
317
+ output_size: Tuple[int, int] = (224, 224)) -> Tuple[List[np.ndarray], List[np.ndarray]]:
318
+ """
319
+ Convenience function to extract hand frames from video.
320
+
321
+ Args:
322
+ video_input: Either a path to video file (str) or a decord.VideoReader object
323
+ landmarks_data: Dictionary containing pose and hand landmarks for each frame
324
+ output_size: Size of the output hand frames (width, height)
325
+
326
+ Returns:
327
+ Tuple of (left_hand_frames, right_hand_frames) as lists of numpy arrays
328
+ """
329
+ extractor = HandExtractor(output_size=output_size)
330
+ return extractor.extract_hand_frames(video_input, landmarks_data)
331
+
332
+
333
+ def video_holistic(video_file: str, hand_path: str, problem_file_path: str, pose_path: str):
334
+ """
335
+ Original function for backward compatibility with command-line usage.
336
+ """
337
+ try:
338
+ video = decord.VideoReader(video_file)
339
+ fps = video.get_avg_fps()
340
+
341
+ video_name = Path(video_file).stem
342
+ clip_hand1_path = Path(hand_path) / f"{video_name}_hand1.mp4"
343
+ clip_hand2_path = Path(hand_path) / f"{video_name}_hand2.mp4"
344
+ landmark_json_path = Path(pose_path) / f"{video_name}_pose.json"
345
+
346
+ # Load landmarks
347
+ with open(landmark_json_path, 'r') as rd:
348
+ landmarks_data = json.load(rd)
349
+
350
+ # Convert string keys to integers
351
+ landmarks_data = {int(k): v for k, v in landmarks_data.items()}
352
+
353
+ # Extract hand videos
354
+ extractor = HandExtractor()
355
+ extractor.extract_and_save_hand_videos(video, landmarks_data, hand_path, video_name)
356
+
357
+ except Exception as e:
358
+ print(f"Error processing {video_file}: {e}")
359
+ with open(problem_file_path, "a") as p:
360
+ p.write(video_file + "\n")
361
+
362
+
363
+ # Utility functions for batch processing
364
+ def load_file(filename: str):
365
+ """Load a pickled and gzipped file."""
366
+ with gzip.open(filename, "rb") as f:
367
+ return pickle.load(f)
368
+
369
+
370
+ def is_string_in_file(file_path: str, target_string: str) -> bool:
371
+ """Check if a string exists in a file."""
372
+ try:
373
+ with Path(file_path).open("r") as f:
374
+ for line in f:
375
+ if target_string in line:
376
+ return True
377
+ return False
378
+ except Exception as e:
379
+ print(f"Error: {e}")
380
+ return False
381
+
382
+
383
+ def main():
384
+ """Main function for command-line usage."""
385
+ parser = argparse.ArgumentParser()
386
+ parser.add_argument('--index', type=int, required=True,
387
+ help='index of the sub_list to work with')
388
+ parser.add_argument('--batch_size', type=int, required=True,
389
+ help='batch size')
390
+ parser.add_argument('--time_limit', type=int, required=True,
391
+ help='time limit')
392
+ parser.add_argument('--files_list', type=str, required=True,
393
+ help='files list')
394
+ parser.add_argument('--problem_file_path', type=str, required=True,
395
+ help='problem file path')
396
+ parser.add_argument('--pose_path', type=str, required=True,
397
+ help='pose path')
398
+ parser.add_argument('--hand_path', type=str, required=True,
399
+ help='hand path')
400
+
401
+ args = parser.parse_args()
402
+ start_time = time.time()
403
+
404
+ # Create directories if they do not exist
405
+ Path(args.hand_path).mkdir(parents=True, exist_ok=True)
406
+
407
+ # Load files list
408
+ fixed_list = load_file(args.files_list)
409
+
410
+ # Create problem file if it doesn't exist
411
+ if not os.path.exists(args.problem_file_path):
412
+ with open(args.problem_file_path, "w") as f:
413
+ f.write("")
414
+
415
+ # Process videos in batches
416
+ video_batches = [fixed_list[i:i + args.batch_size] for i in range(0, len(fixed_list), args.batch_size)]
417
+
418
+ for video_file in video_batches[args.index]:
419
+ current_time = time.time()
420
+ if current_time - start_time > args.time_limit:
421
+ print("Time limit reached. Stopping execution.")
422
+ break
423
+
424
+ video_name = Path(video_file).stem
425
+ clip_hand2_path = Path(args.hand_path) / f"{video_name}_hand2.mp4"
426
+
427
+ if clip_hand2_path.exists():
428
+ print(f"Skipping {video_file} - output already exists")
429
+ continue
430
+ elif is_string_in_file(args.problem_file_path, video_file):
431
+ print(f"Skipping {video_file} - found in problem file")
432
+ continue
433
+ else:
434
+ try:
435
+ print(f"Processing {video_file}")
436
+ video_holistic(video_file, args.hand_path, args.problem_file_path, args.pose_path)
437
+ print(f"Successfully processed {video_file}")
438
+ except Exception as e:
439
+ print(f"Error processing {video_file}: {e}")
440
+ with open(args.problem_file_path, "a") as p:
441
+ p.write(video_file + "\n")
442
+
443
+
444
+ if __name__ == "__main__":
445
+ main()
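
A minimal usage sketch for the HandExtractor class added above, under the same assumptions (illustrative clip.mp4 and a frame-indexed landmark dictionary loaded from the pose step):

    import json
    from crop_hands import HandExtractor

    with open("clip_pose.json") as f:
        landmarks = {int(k): v for k, v in json.load(f).items()}

    extractor = HandExtractor(output_size=(224, 224), scale_factor=1.5)
    # Two aligned streams of 224x224 crops, one per hand.
    left_frames, right_frames = extractor.extract_hand_frames("clip.mp4", landmarks)
    # Or save them as out_dir/clip_hand1.mp4 and out_dir/clip_hand2.mp4.
    left_path, right_path = extractor.extract_and_save_hand_videos("clip.mp4", landmarks, "out_dir", "clip")
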
dinov2_features.py ADDED
@@ -0,0 +1,351 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import decord
6
+ from decord import VideoReader
7
+ from decord import cpu, gpu
8
+ import numpy as np
9
+ import os
10
+ import pickle
11
+ import gzip
12
+ from pathlib import Path
13
+ import argparse
14
+ import json
15
+ import csv
16
+ import glob
17
+ import time
18
+ from typing import List, Union, Optional, Tuple
19
+
20
+
21
+ class DINOEmbedder:
22
+ """
23
+ A class for extracting DINOv2 embeddings from video frames or images.
24
+ """
25
+
26
+ def __init__(self, dino_model_path: str, batch_size: int = 128, device: Optional[str] = None):
27
+ """
28
+ Initialize the DINOEmbedder.
29
+
30
+ Args:
31
+ dino_model_path: Path to the fine-tuned DINOv2 model
32
+ batch_size: Batch size for processing frames
33
+ device: Device to use ('cuda' or 'cpu'). Auto-detected if None
34
+ """
35
+ self.dino_model_path = dino_model_path
36
+ self.batch_size = batch_size
37
+ self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+
39
+ # Initialize model
40
+ self.model = self._load_dino_model()
41
+ self.model.eval()
42
+
43
+ # Initialize transform
44
+ self.transform = transforms.Compose([
45
+ transforms.Resize((224, 224)),
46
+ transforms.ToTensor(),
47
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
48
+ ])
49
+
50
+ print(f"DINOEmbedder initialized on device: {self.device}")
51
+
52
+ def _load_dino_model(self) -> nn.Module:
53
+ """Load the fine-tuned DINOv2 model."""
54
+ # Load the original DINOv2 model with the correct architecture
55
+ model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg', pretrained=False)
56
+
57
+ # Load fine-tuned weights
58
+ pretrained = torch.load(self.dino_model_path, map_location=self.device)
59
+
60
+ # Make correct state dict for loading
61
+ new_state_dict = {}
62
+ for key, value in pretrained['teacher'].items():
63
+ if 'dino_head' in key:
64
+ continue # Skip dino_head layers
65
+ else:
66
+ new_key = key.replace('backbone.', '')
67
+ new_state_dict[new_key] = value
68
+
69
+ # Change shape of pos_embed
70
+ pos_embed = nn.Parameter(torch.zeros(1, 257, 384))
71
+ model.pos_embed = pos_embed
72
+
73
+ # Load state dict
74
+ model.load_state_dict(new_state_dict, strict=True)
75
+
76
+ # Move model to device
77
+ model.to(self.device)
78
+ return model
79
+
80
+ def _preprocess_frame(self, frame: np.ndarray) -> torch.Tensor:
81
+ """Preprocess a single frame."""
82
+ if isinstance(frame, np.ndarray):
83
+ image = Image.fromarray(frame)
84
+ else:
85
+ image = frame
86
+
87
+ tensor = self.transform(image)
88
+ # Ensure only RGB channels are considered
89
+ return tensor[:3]
90
+
91
+ def _preprocess_frames_batch(self, frames: List[np.ndarray]) -> torch.Tensor:
92
+ """Preprocess a batch of frames."""
93
+ batch_tensors = torch.stack([self._preprocess_frame(frame) for frame in frames])
94
+ return batch_tensors.to(self.device)
95
+
96
+ def extract_embeddings_from_frames(self, frames: List[np.ndarray]) -> np.ndarray:
97
+ """
98
+ Extract DINOv2 embeddings from a list of frames.
99
+
100
+ Args:
101
+ frames: List of frames as numpy arrays
102
+
103
+ Returns:
104
+ Numpy array of embeddings with shape (num_frames, embedding_dim)
105
+ """
106
+ all_embeddings = []
107
+
108
+ # Process frames in batches
109
+ for idx in range(0, len(frames), self.batch_size):
110
+ batch_frames = frames[idx:idx + self.batch_size]
111
+
112
+ # Preprocess batch
113
+ batch_tensors = self._preprocess_frames_batch(batch_frames)
114
+
115
+ # Extract embeddings
116
+ with torch.no_grad():
117
+ batch_embeddings = self.model(batch_tensors).cpu().numpy()
118
+
119
+ all_embeddings.append(batch_embeddings)
120
+
121
+ # Concatenate all embeddings
122
+ embeddings = np.concatenate(all_embeddings, axis=0)
123
+ return embeddings
124
+
125
+ def extract_embeddings_from_video(self, video_input: Union[str, VideoReader],
126
+ target_size: Tuple[int, int] = (224, 224)) -> np.ndarray:
127
+ """
128
+ Extract DINOv2 embeddings from a video.
129
+
130
+ Args:
131
+ video_input: Either a path to video file (str) or a VideoReader object
132
+ target_size: Target size for video frames (width, height)
133
+
134
+ Returns:
135
+ Numpy array of embeddings with shape (num_frames, embedding_dim)
136
+ """
137
+ # Handle different input types
138
+ if isinstance(video_input, str):
139
+ video_path = Path(video_input)
140
+ if not video_path.exists():
141
+ raise FileNotFoundError(f"Video file not found: {video_input}")
142
+ try:
143
+ vr = VideoReader(str(video_path), width=target_size[0], height=target_size[1])
144
+ except Exception as e:
145
+ raise RuntimeError(f"Error loading video {video_input}: {e}")
146
+ # elif hasattr(video_input, 'get_batch'):
147
+ else:
148
+ vr = video_input
149
+ # else:
150
+ # raise TypeError("video_input must be either a file path (str) or a VideoReader object")
151
+
152
+ total_frames = len(vr)
153
+ all_embeddings = []
154
+
155
+ # Process video in batches
156
+ for idx in range(0, total_frames, self.batch_size):
157
+ batch_indices = range(idx, min(idx + self.batch_size, total_frames))
158
+ # batch_frames = vr.get_batch(batch_indices).asnumpy()
159
+ batch_frames = vr[batch_indices]
160
+
161
+ # Preprocess batch
162
+ batch_tensors = self._preprocess_frames_batch(batch_frames)
163
+
164
+ # Extract embeddings
165
+ with torch.no_grad():
166
+ batch_embeddings = self.model(batch_tensors).cpu().numpy()
167
+
168
+ all_embeddings.append(batch_embeddings)
169
+
170
+ # Concatenate all embeddings
171
+ embeddings = np.concatenate(all_embeddings, axis=0)
172
+ return embeddings
173
+
174
+ def extract_embeddings_from_video_and_save(self, video_path: str, output_folder: str) -> str:
175
+ """
176
+ Extract embeddings from video and save to file.
177
+
178
+ Args:
179
+ video_path: Path to the video file
180
+ output_folder: Folder to save the embeddings
181
+
182
+ Returns:
183
+ Path to the saved embeddings file
184
+ """
185
+ # Create output folder if it doesn't exist
186
+ Path(output_folder).mkdir(parents=True, exist_ok=True)
187
+
188
+ # Extract embeddings
189
+ embeddings = self.extract_embeddings_from_video(video_path)
190
+
191
+ # Save embeddings
192
+ video_name = Path(video_path).stem
193
+ np_path = Path(output_folder) / f"{video_name}.npy"
194
+ np.save(np_path, embeddings)
195
+
196
+ return str(np_path)
197
+
198
+ def extract_embedding_from_single_image(self, image: Union[np.ndarray, Image.Image]) -> np.ndarray:
199
+ """
200
+ Extract DINOv2 embedding from a single image.
201
+
202
+ Args:
203
+ image: Image as numpy array or PIL Image
204
+
205
+ Returns:
206
+ Numpy array of embedding with shape (1, embedding_dim)
207
+ """
208
+ # Preprocess image
209
+ if isinstance(image, np.ndarray):
210
+ image = Image.fromarray(image)
211
+
212
+ tensor = self.transform(image).unsqueeze(0).to(self.device)
213
+
214
+ # Extract embedding
215
+ with torch.no_grad():
216
+ embedding = self.model(tensor).cpu().numpy()
217
+
218
+ return embedding
219
+
220
+
221
+ # Convenience functions for backward compatibility
222
+ def extract_embeddings_from_frames(frames: List[np.ndarray], dino_model_path: str,
223
+ batch_size: int = 128) -> np.ndarray:
224
+ """
225
+ Convenience function to extract embeddings from frames.
226
+
227
+ Args:
228
+ frames: List of frames as numpy arrays
229
+ dino_model_path: Path to the fine-tuned DINOv2 model
230
+ batch_size: Batch size for processing
231
+
232
+ Returns:
233
+ Numpy array of embeddings
234
+ """
235
+ embedder = DINOEmbedder(dino_model_path, batch_size)
236
+ return embedder.extract_embeddings_from_frames(frames)
237
+
238
+
239
+ def extract_embeddings_from_video(video_path: str, dino_model_path: str,
240
+ batch_size: int = 128) -> np.ndarray:
241
+ """
242
+ Convenience function to extract embeddings from video.
243
+
244
+ Args:
245
+ video_path: Path to the video file
246
+ dino_model_path: Path to the fine-tuned DINOv2 model
247
+ batch_size: Batch size for processing
248
+
249
+ Returns:
250
+ Numpy array of embeddings
251
+ """
252
+ embedder = DINOEmbedder(dino_model_path, batch_size)
253
+ return embedder.extract_embeddings_from_video(video_path)
254
+
255
+
256
+ def video_to_embeddings(video_path: str, output_folder: str, dino_path: str, batch_size: int = 128):
257
+ """
258
+ Original function for backward compatibility with command-line usage.
259
+ """
260
+ try:
261
+ embedder = DINOEmbedder(dino_path, batch_size)
262
+ embedder.extract_embeddings_from_video_and_save(video_path, output_folder)
263
+ except Exception as e:
264
+ print(f'Error processing {video_path}: {e}')
265
+
266
+
267
+ # Utility functions for batch processing
268
+ def get_mp4_files(directory: str) -> List[str]:
269
+ """Get all MP4 files in a directory."""
270
+ if not os.path.exists(directory):
271
+ raise FileNotFoundError(f'Directory not found: {directory}')
272
+
273
+ mp4_files = glob.glob(os.path.join(directory, '*.mp4'))
274
+ return [os.path.abspath(file) for file in mp4_files]
275
+
276
+
277
+ def load_file(filename: str):
278
+ """Load a pickled and gzipped file."""
279
+ with gzip.open(filename, "rb") as f:
280
+ return pickle.load(f)
281
+
282
+
283
+ def is_string_in_file(file_path: str, target_string: str) -> bool:
284
+ """Check if a string exists in a file."""
285
+ try:
286
+ with Path(file_path).open("r") as f:
287
+ for line in f:
288
+ if target_string in line:
289
+ return True
290
+ return False
291
+ except Exception as e:
292
+ print(f"Error: {e}")
293
+ return False
294
+
295
+
296
+ def main():
297
+ """Main function for command-line usage."""
298
+ parser = argparse.ArgumentParser()
299
+ parser.add_argument('--index', type=int, required=True,
300
+ help='index of the sub_list to work with')
301
+ parser.add_argument('--time_limit', type=int, required=True,
302
+ help='time limit in seconds')
303
+ parser.add_argument('--batch_size', type=int, required=True,
304
+ help='number of videos to process in this batch')
305
+ parser.add_argument('--files_list', type=str, required=True,
306
+ help='path to the files list file')
307
+ parser.add_argument('--output_folder', type=str, required=True,
308
+ help='path to the output folder')
309
+ parser.add_argument('--dino_path', type=str, required=True,
310
+ help='path to the dino model')
311
+
312
+ args = parser.parse_args()
313
+ start_time = time.time()
314
+
315
+ # Load files list
316
+ fixed_list = load_file(args.files_list)
317
+
318
+ # Create output folder if it doesn't exist
319
+ if not os.path.exists(args.output_folder):
320
+ os.makedirs(args.output_folder)
321
+
322
+ # Initialize embedder
323
+ embedder = DINOEmbedder(args.dino_path, batch_size=512)
324
+
325
+ # Process videos in batches
326
+ video_batches = [fixed_list[i:i + args.batch_size] for i in range(0, len(fixed_list), args.batch_size)]
327
+ print(f"Total number of video batches: {len(video_batches)}")
328
+
329
+ for video_path in video_batches[args.index]:
330
+ current_time = time.time()
331
+ if current_time - start_time > args.time_limit:
332
+ print("Time limit reached. Stopping execution.")
333
+ break
334
+
335
+ video_name = Path(video_path).stem
336
+ np_path = Path(args.output_folder) / f"{video_name}.npy"
337
+
338
+ if np_path.exists():
339
+ print(f"Skipping {video_path} - output already exists")
340
+ continue
341
+ else:
342
+ try:
343
+ print(f"Processing {video_path}")
344
+ embedder.extract_embeddings_from_video_and_save(video_path, args.output_folder)
345
+ print(f"Successfully processed {video_path}")
346
+ except Exception as e:
347
+ print(f"Error processing {video_path}: {e}")
348
+
349
+
350
+ if __name__ == "__main__":
351
+ main()
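
A minimal usage sketch for the DINOEmbedder class added above. The checkpoint path is illustrative; the backbone is dinov2_vits14_reg, so each frame yields a 384-dimensional embedding.

    from dinov2_features import DINOEmbedder

    embedder = DINOEmbedder("models/dinov2face.pth", batch_size=128)
    # (num_frames, 384) array, one embedding per frame of the cropped face video.
    embeddings = embedder.extract_embeddings_from_video("clip_face.mp4")
    print(embeddings.shape)
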
features.py ADDED
@@ -0,0 +1,115 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import decord
5
+ import torch.nn as nn
6
+ import json
7
+ import cv2
8
+ from kpe_mediapipe import video_holistic
9
+ from crop_hands import HandExtractor
10
+ from crop_face import FaceExtractor
11
+ from dinov2_features import extract_embeddings_from_frames
12
+ from body_features import process_pose_landmarks
13
+ # from shubert import SignHubertModel, SignHubertConfig
14
+ from inference import test
15
+ import subprocess
16
+
17
+
18
+
19
+ class SHuBERTProcessor:
20
+
21
+ def __init__(self, config):
22
+ self.config = config
23
+
24
+ def process_video(self, video_path):
25
+
26
+ # output_file = f"{output_path}/{os.path.basename(video_file)}"
27
+
28
+
29
+ # # Target FPS is 12.5
30
+ # cmd = [
31
+ # 'ffmpeg',
32
+ # '-i', video_path,
33
+ # '-filter:v', 'fps=15',
34
+ # '-c:v', 'libx264',
35
+ # '-preset', 'medium', # Balance between speed and quality
36
+ # '-crf', '23', # Quality level (lower is better)
37
+ # '-y', # Overwrite output file if it exists
38
+ # video_path
39
+ # ]
40
+
41
+
42
+ # try:
43
+ # subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
44
+ # print(f"Saved to {video_path} at 15 fps")
45
+ # except subprocess.CalledProcessError as e:
46
+ # print(f"Error reading video {video_path}: {e}")
47
+
48
+
49
+
50
+ # Step 1: Downsample the video toward the target fps by frame striding (a 30 fps clip ends up at ~15 fps)
51
+ signer_video = decord.VideoReader(video_path)
52
+
53
+ signer_video_fps = signer_video.get_avg_fps()
54
+ target_fps = 12
55
+ stride = max(1, int(round(signer_video_fps / target_fps)))
56
+ index_list = list(range(0, len(signer_video), stride))
57
+ signer_video = signer_video.get_batch(index_list)
58
+ signer_video = signer_video.asnumpy()
59
+
60
+ # Step 2: Extract pose using kpe_mediapipe
61
+ landmarks = video_holistic(
62
+ video_input=signer_video,
63
+ face_model_path=self.config['mediapipe_face_model_path'],
64
+ hand_model_path=self.config['mediapipe_hands_model_path'],
65
+ )
66
+
67
+ # Step 3: Extract stream features
68
+ hand_extractor = HandExtractor()
69
+ left_hand_frames, right_hand_frames = hand_extractor.extract_hand_frames(signer_video, landmarks)
70
+ left_hand_embeddings = extract_embeddings_from_frames(left_hand_frames, self.config['dino_hands_model_path'])
71
+ right_hand_embeddings = extract_embeddings_from_frames(right_hand_frames, self.config['dino_hands_model_path'])
72
+ del left_hand_frames, right_hand_frames
73
+
74
+ face_extractor = FaceExtractor()
75
+ face_frames = face_extractor.extract_face_frames(signer_video, landmarks)
76
+ face_embeddings = extract_embeddings_from_frames(face_frames, self.config['dino_face_model_path'])
77
+ del face_frames, signer_video
78
+
79
+ pose_embeddings = process_pose_landmarks(landmarks)
80
+ del landmarks
81
+
82
+ output_text = test(face_embeddings,
83
+ left_hand_embeddings,
84
+ right_hand_embeddings,
85
+ pose_embeddings,
86
+ self.config['slt_model_config'],
87
+ self.config['slt_model_checkpoint'],
88
+ self.config['slt_tokenizer_checkpoint'],
89
+ self.config['temp_dir'])
90
+
91
+ return output_text
92
+
93
+ if __name__ == "__main__":
94
+ config = {
95
+ 'yolov8_model_path': '/share/data/pals/shester/inference/models/yolov8n.pt',
96
+ 'dino_face_model_path': '/share/data/pals/shester/inference/models/dinov2face.pth',
97
+ 'dino_hands_model_path': '/share/data/pals/shester/inference/models/dinov2hand.pth',
98
+ 'mediapipe_face_model_path': '/share/data/pals/shester/inference/models/face_landmarker_v2_with_blendshapes.task',
99
+ 'mediapipe_hands_model_path': '/share/data/pals/shester/inference/models/hand_landmarker.task',
100
+ 'shubert_model_path': '/share/data/pals/shester/inference/models/checkpoint_836_400000.pt',
101
+ 'temp_dir': '/share/data/pals/shester/inference',
102
+ 'slt_model_config': '/share/data/pals/shester/inference/models/byt5_base/config.json',
103
+ 'slt_model_checkpoint': '/share/data/pals/shester/inference/models/checkpoint-11625',
104
+ 'slt_tokenizer_checkpoint': '/share/data/pals/shester/inference/models/byt5_base',
105
+ }
106
+
107
+ # input_clip = "/share/data/pals/shester/datasets/openasl/clips_bbox/J-0KHhPS_m4.029676-029733.mp4"
108
+ # input_clip = "/share/data/pals/shester/inference/recordings/sabrin30fps.mp4"
109
+ input_clip = "/share/data/pals/shester/inference/recordings/sabrina30fps.mp4"
110
+ processor = SHuBERTProcessor(config)
111
+ output_text = processor.process_video(input_clip)
112
+ print(f"The English translation is: {output_text}")
113
+
114
+ # /home-nfs/shesterg/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/attention.py
115
+ # /home-nfs/shesterg/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/block.py
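
As a quick check on the frame-striding arithmetic in process_video above, assuming a 30 fps recording like the example clips:

    src_fps, target_fps = 30.0, 12
    stride = max(1, int(round(src_fps / target_fps)))  # round(2.5) == 2 under Python's banker's rounding
    effective_fps = src_fps / stride                   # 15.0 fps actually reaches the feature extractors
    print(stride, effective_fps)                       # 2 15.0
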
inference.py ADDED
@@ -0,0 +1,738 @@
1
+ import os
2
+ import copy
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ import random
8
+ import warnings
9
+
10
+ from transformers import (
11
+ ByT5Tokenizer,
12
+ Seq2SeqTrainingArguments,
13
+ Seq2SeqTrainer,
14
+ )
15
+
16
+ from transformers.models.t5 import T5Config
17
+ from transformers.models.t5.modeling_t5 import *
18
+
19
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput
20
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
21
+ from torch.nn import CrossEntropyLoss
22
+
23
+ from collections.abc import Mapping
24
+ from dataclasses import dataclass
25
+ from random import randint
26
+ from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
27
+ from transformers.utils import PaddingStrategy
28
+
29
+ from shubert import SignHubertModel, SignHubertConfig
30
+
31
+ class SignHubertAdapter(nn.Module):
32
+ def __init__(self, channels):
33
+ super().__init__()
34
+ # Adjust intermediate_dim based on number of channels
35
+ intermediate_dim_shubert = 1024
36
+
37
+ self.signhubert = SignHubertModel(SignHubertConfig(
38
+ channels=channels,
39
+ intermediate_dim=intermediate_dim_shubert
40
+ ))
41
+
42
+ def forward(self, x):
43
+ features = self.signhubert.extract_features(x, padding_mask=None, kmeans_labels=None, mask=False)
44
+
45
+ # Extract layer outputs
46
+ layer_outputs = []
47
+ for layer in features['layer_results']:
48
+ layer_output = layer[-1] # Shape: [B, T, D]
49
+ layer_outputs.append(layer_output)
50
+
51
+ # Stack the outputs from all layers
52
+ stacked_features = torch.stack(layer_outputs, dim=1) # Shape: [B, L, T, D]
53
+ return stacked_features
54
+
55
+ class LinearAdapter(nn.Module):
56
+ def __init__(self, face_dim, hand_dim, pose_dim, representations_dim, out_dim, extraction_layer, channels):
57
+ super().__init__()
58
+
59
+ self.signhubert_adapter = SignHubertAdapter(channels)
60
+ self.layer_weights = nn.Parameter(torch.ones(12)) # Learnable weights for each layer
61
+ self.final_layer = nn.Linear(representations_dim, out_dim)
62
+ self.extraction_layer = extraction_layer
63
+
64
+ def forward(self, face_features, left_hand_features, right_hand_features, body_posture_features):
65
+ dtype = torch.float32
66
+ face_features = face_features.to(dtype=dtype)
67
+ left_hand_features = left_hand_features.to(dtype=dtype)
68
+ right_hand_features = right_hand_features.to(dtype=dtype)
69
+ body_posture_features = body_posture_features.to(dtype=dtype)
70
+
71
+ batch_size, seq_len = face_features.shape[:2]
72
+ dummy_labels = torch.zeros((seq_len, 1), dtype=dtype, device=face_features.device)
73
+
74
+ source = []
75
+ for i in range(batch_size):
76
+ source.append({
77
+ "face": face_features[i],
78
+ "left_hand": left_hand_features[i],
79
+ "right_hand": right_hand_features[i],
80
+ "body_posture": body_posture_features[i],
81
+ "label_face": dummy_labels,
82
+ "label_left_hand": dummy_labels,
83
+ "label_right_hand": dummy_labels,
84
+ "label_body_posture": dummy_labels
85
+ })
86
+
87
+ # Get representations from SignHubert
88
+ representations_features = self.signhubert_adapter(source) # [T, L, B, D]
89
+ representations_features = representations_features.permute(2, 1, 0, 3) # [B, L, T, D]
90
+
91
+ if self.extraction_layer == 0:
92
+ normalized_weights = self.layer_weights
93
+ weighted_representations = representations_features * normalized_weights.view(1, -1, 1, 1)
94
+ representations_for_downstream_task = torch.sum(weighted_representations, dim=1)
95
+ else:
96
+ representations_for_downstream_task = representations_features[:, self.extraction_layer-1, :, :]
97
+
98
+ final_output = self.final_layer(representations_for_downstream_task)
99
+
100
+ return final_output
101
+
102
+ class SignLanguageByT5Config(T5Config):
103
+ def __init__(
104
+ self,
105
+ representations_dim=768,
106
+ adapter="linear",
107
+ finetune_signhubert=False,
108
+ face_dim=384,
109
+ hand_dim=384,
110
+ pose_dim=14,
111
+ extraction_layer=0, # use last layer
112
+ channels="face,left_hand,right_hand,body_posture",
113
+ **kwargs
114
+ ):
115
+ self.representations_dim = representations_dim
116
+ self.adapter = adapter
117
+ self.finetune_signhubert = finetune_signhubert
118
+ self.face_dim = face_dim
119
+ self.hand_dim = hand_dim
120
+ self.pose_dim = pose_dim
121
+ self.extraction_layer = extraction_layer
122
+ self.channels = channels
123
+ super().__init__(**kwargs)
124
+
125
+ class SignLanguageByT5Encoder(T5PreTrainedModel):
126
+ def __init__(self, config):
127
+ super().__init__(config)
128
+
129
+ # Initialize the adapter based on the configuration
130
+ if config.adapter == "linear":
131
+ self.adapter = LinearAdapter(
132
+ config.face_dim,
133
+ config.hand_dim,
134
+ config.pose_dim,
135
+ config.representations_dim,
136
+ config.d_model,
137
+ config.extraction_layer,
138
+ config.channels
139
+ )
140
+ else:
141
+ raise NotImplementedError("Adapter type not implemented.")
142
+
143
+ self.is_decoder = config.is_decoder
144
+
145
+ # Define the encoder blocks
146
+ self.block = nn.ModuleList(
147
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
148
+ )
149
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
150
+ self.dropout = nn.Dropout(config.dropout_rate)
151
+
152
+ # Initialize weights and apply final processing
153
+ self.post_init()
154
+
155
+ # Model parallel settings
156
+ self.model_parallel = False
157
+ self.device_map = None
158
+ self.gradient_checkpointing = False
159
+
160
+ def parallelize(self, device_map=None):
161
+ warnings.warn(
162
+ "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
163
+ " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
164
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
165
+ " 'block.1': 1, ...}",
166
+ FutureWarning,
167
+ )
168
+ # Check validity of device_map
169
+ self.device_map = (
170
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
171
+ )
172
+ assert_device_map(self.device_map, len(self.block))
173
+ self.model_parallel = True
174
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
175
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
176
+ # Load onto devices
177
+ for k, v in self.device_map.items():
178
+ for layer in v:
179
+ cuda_device = "cuda:" + str(k)
180
+ self.block[layer] = self.block[layer].to(cuda_device)
181
+
182
+ # Set embed_tokens to first layer
183
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
184
+ # Set final layer norm to last device
185
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
186
+
187
+ def deparallelize(self):
188
+ warnings.warn(
189
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
190
+ FutureWarning,
191
+ )
192
+ self.model_parallel = False
193
+ self.device_map = None
194
+ self.first_device = "cpu"
195
+ self.last_device = "cpu"
196
+ for i in range(len(self.block)):
197
+ self.block[i] = self.block[i].to("cpu")
198
+ self.embed_tokens = self.embed_tokens.to("cpu")
199
+ self.final_layer_norm = self.final_layer_norm.to("cpu")
200
+ torch.cuda.empty_cache()
201
+
202
+ def get_input_embeddings(self):
203
+ return self.embed_tokens
204
+
205
+ def set_input_embeddings(self, new_embeddings):
206
+ self.embed_tokens = new_embeddings
207
+
208
+ def forward(
209
+ self,
210
+ face_features=None,
211
+ left_hand_features=None,
212
+ right_hand_features=None,
213
+ pose_features=None,
214
+ attention_mask=None,
215
+ head_mask=None,
216
+ encoder_hidden_states=None,
217
+ encoder_attention_mask=None,
218
+ cross_attn_head_mask=None,
219
+ past_key_values=None,
220
+ use_cache=None,
221
+ output_attentions=None,
222
+ output_hidden_states=None,
223
+ return_dict=None,
224
+ ):
225
+ # Set default values if not provided
226
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
227
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
228
+ output_hidden_states = (
229
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
230
+ )
231
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
232
+
233
+ # Use the adapter to convert representation features into embeddings
234
+ inputs_embeds = self.adapter(face_features, left_hand_features, right_hand_features, pose_features)
235
+
236
+ input_shape = inputs_embeds.shape[:2]
237
+ batch_size, seq_length = input_shape
238
+
239
+ mask_seq_length = seq_length
240
+
241
+ if attention_mask is None:
242
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
243
+
244
+ # Initialize past_key_values if not provided
245
+ if past_key_values is None:
246
+ past_key_values = [None] * len(self.block)
247
+
248
+ # Extend attention mask
249
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
250
+
251
+ # Prepare head mask if needed
252
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
253
+ present_key_value_states = () if use_cache else None
254
+ all_hidden_states = () if output_hidden_states else None
255
+ all_attentions = () if output_attentions else None
256
+
257
+ hidden_states = self.dropout(inputs_embeds)
258
+
259
+ # Iterate over each encoder block
260
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
261
+ layer_head_mask = head_mask[i]
262
+
263
+ if output_hidden_states:
264
+ all_hidden_states = all_hidden_states + (hidden_states,)
265
+
266
+ layer_outputs = layer_module(
267
+ hidden_states,
268
+ attention_mask=extended_attention_mask,
269
+ position_bias=None,
270
+ encoder_hidden_states=encoder_hidden_states,
271
+ encoder_attention_mask=encoder_attention_mask,
272
+ encoder_decoder_position_bias=None,
273
+ layer_head_mask=layer_head_mask,
274
+ cross_attn_layer_head_mask=cross_attn_head_mask,
275
+ past_key_value=past_key_value,
276
+ use_cache=use_cache,
277
+ output_attentions=output_attentions,
278
+ )
279
+
280
+ hidden_states = layer_outputs[0]
281
+
282
+ if use_cache:
283
+ present_key_value_states = present_key_value_states + (layer_outputs[1],)
284
+
285
+ if output_attentions:
286
+ all_attentions = all_attentions + (layer_outputs[2],)
287
+
288
+ hidden_states = self.final_layer_norm(hidden_states)
289
+ hidden_states = self.dropout(hidden_states)
290
+
291
+ # Add last hidden state
292
+ if output_hidden_states:
293
+ all_hidden_states = all_hidden_states + (hidden_states,)
294
+
295
+ if not return_dict:
296
+ return tuple(
297
+ v
298
+ for v in [hidden_states, present_key_value_states, all_hidden_states, all_attentions]
299
+ if v is not None
300
+ )
301
+ return BaseModelOutputWithPastAndCrossAttentions(
302
+ last_hidden_state=hidden_states,
303
+ past_key_values=present_key_value_states,
304
+ hidden_states=all_hidden_states,
305
+ attentions=all_attentions,
306
+ cross_attentions=None,
307
+ )
308
+
309
+ class SignLanguageByT5ForConditionalGeneration(T5PreTrainedModel):
310
+ _keys_to_ignore_on_load_unexpected = [
311
+ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
312
+ ]
313
+ _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
314
+
315
+ def __init__(self, config: T5Config):
316
+ super().__init__(config)
317
+ self.model_dim = config.d_model
318
+
319
+ # Initialize the decoder embedding
320
+ self.decoder_emb = nn.Embedding(config.vocab_size, config.d_model)
321
+
322
+ # Initialize the encoder with the custom SignLanguageByT5Encoder
323
+ encoder_config = copy.deepcopy(config)
324
+ encoder_config.is_decoder = False
325
+ encoder_config.use_cache = False
326
+ encoder_config.is_encoder_decoder = False
327
+ self.encoder = SignLanguageByT5Encoder(encoder_config)
328
+
329
+ # Initialize the decoder
330
+ decoder_config = copy.deepcopy(config)
331
+ decoder_config.is_decoder = True
332
+ decoder_config.is_encoder_decoder = False
333
+ decoder_config.num_layers = config.num_decoder_layers
334
+ self.decoder = T5Stack(decoder_config, self.decoder_emb)
335
+
336
+ # Initialize the language modeling head
337
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
338
+
339
+ # Initialize weights and apply final processing
340
+ self.post_init()
341
+
342
+ # Model parallel settings
343
+ self.model_parallel = False
344
+ self.device_map = None
345
+
346
+ def parallelize(self, device_map=None):
347
+ warnings.warn(
348
+ "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
349
+ " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
350
+ " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
351
+ " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
352
+ FutureWarning,
353
+ )
354
+ self.device_map = (
355
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
356
+ if device_map is None
357
+ else device_map
358
+ )
359
+ assert_device_map(self.device_map, len(self.encoder.block))
360
+ self.encoder.parallelize(self.device_map)
361
+ self.decoder.parallelize(self.device_map)
362
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
363
+ self.model_parallel = True
364
+
365
+ def deparallelize(self):
366
+ warnings.warn(
367
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
368
+ FutureWarning,
369
+ )
370
+ self.encoder.deparallelize()
371
+ self.decoder.deparallelize()
372
+ self.encoder = self.encoder.to("cpu")
373
+ self.decoder = self.decoder.to("cpu")
374
+ self.lm_head = self.lm_head.to("cpu")
375
+ self.model_parallel = False
376
+ self.device_map = None
377
+ torch.cuda.empty_cache()
378
+
379
+ def get_input_embeddings(self):
380
+ return self.decoder_emb
381
+
382
+ def set_input_embeddings(self, new_embeddings):
383
+ self.decoder_emb = new_embeddings
384
+ self.encoder.set_input_embeddings(new_embeddings)
385
+ self.decoder.set_input_embeddings(new_embeddings)
386
+
387
+ def set_output_embeddings(self, new_embeddings):
388
+ self.lm_head = new_embeddings
389
+
390
+ def get_output_embeddings(self):
391
+ return self.lm_head
392
+
393
+ def get_encoder(self):
394
+ return self.encoder
395
+
396
+ def get_decoder(self):
397
+ return self.decoder
398
+
399
+ def prepare_inputs_for_generation(
400
+ self,
401
+ input_ids,
402
+ past_key_values=None,
403
+ attention_mask=None,
404
+ head_mask=None,
405
+ decoder_head_mask=None,
406
+ decoder_attention_mask=None,
407
+ cross_attn_head_mask=None,
408
+ use_cache=None,
409
+ encoder_outputs=None,
410
+ **kwargs,
411
+ ):
412
+ # cut decoder_input_ids if past is used
413
+ if past_key_values is not None:
414
+ input_ids = input_ids[:, -1:]
415
+
416
+ return {
417
+ "decoder_input_ids": input_ids,
418
+ "past_key_values": past_key_values,
419
+ "encoder_outputs": encoder_outputs,
420
+ "attention_mask": attention_mask,
421
+ "head_mask": head_mask,
422
+ "decoder_head_mask": decoder_head_mask,
423
+ "decoder_attention_mask": decoder_attention_mask,
424
+ "cross_attn_head_mask": cross_attn_head_mask,
425
+ "use_cache": use_cache,
426
+ }
427
+
428
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
429
+ return self._shift_right(labels)
430
+
431
+ def _reorder_cache(self, past_key_values, beam_idx):
432
+ # if decoder past is not included in output
433
+ # speedy decoding is disabled and no need to reorder
434
+ if past_key_values is None:
435
+ logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
436
+ return past_key_values
437
+
438
+ reordered_decoder_past = ()
439
+ for layer_past_states in past_key_values:
440
+ # get the correct batch idx from layer past batch dim
441
+ # batch dim of `past` is at 2nd position
442
+ reordered_layer_past_states = ()
443
+ for layer_past_state in layer_past_states:
444
+ # need to set correct `past` for each of the four key / value states
445
+ reordered_layer_past_states = reordered_layer_past_states + (
446
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
447
+ )
448
+
449
+ if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
450
+ raise ValueError(
451
+ f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
452
+ )
453
+ if len(reordered_layer_past_states) != len(layer_past_states):
454
+ raise ValueError(
455
+ f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
456
+ )
457
+
458
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
459
+ return reordered_decoder_past
460
+
461
+ def forward(
462
+ self,
463
+ face_features=None,
464
+ left_hand_features=None,
465
+ right_hand_features=None,
466
+ pose_features=None,
467
+ attention_mask=None,
468
+ decoder_input_ids=None,
469
+ decoder_attention_mask=None,
470
+ head_mask=None,
471
+ decoder_head_mask=None,
472
+ cross_attn_head_mask=None,
473
+ encoder_outputs=None,
474
+ past_key_values=None,
475
+ decoder_inputs_embeds=None,
476
+ labels=None, # Keep this for training compatibility
477
+ use_cache=None,
478
+ output_attentions=None,
479
+ output_hidden_states=None,
480
+ return_dict=None,
481
+ ):
482
+ # Set default values if not provided
483
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
484
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
485
+
486
+ # Prepare head masks if needed
487
+ if head_mask is not None and decoder_head_mask is None:
488
+ if self.config.num_layers == self.config.num_decoder_layers:
489
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
490
+ decoder_head_mask = head_mask
491
+
492
+ # Encode if encoder outputs are not provided
493
+ if encoder_outputs is None:
494
+ encoder_outputs = self.encoder(
495
+ face_features=face_features,
496
+ left_hand_features=left_hand_features,
497
+ right_hand_features=right_hand_features,
498
+ pose_features=pose_features,
499
+ attention_mask=attention_mask,
500
+ head_mask=head_mask,
501
+ output_attentions=output_attentions,
502
+ output_hidden_states=output_hidden_states,
503
+ return_dict=return_dict,
504
+ )
505
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutputWithPastAndCrossAttentions):
506
+ encoder_outputs = BaseModelOutputWithPastAndCrossAttentions(
507
+ last_hidden_state=encoder_outputs[0],
508
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
509
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
510
+ )
511
+
512
+ hidden_states = encoder_outputs[0]
513
+
514
+ # Prepare decoder inputs
515
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
516
+ decoder_input_ids = self._shift_right(labels)
517
+
518
+ # Decode
519
+ decoder_outputs = self.decoder(
520
+ input_ids=decoder_input_ids,
521
+ attention_mask=decoder_attention_mask,
522
+ inputs_embeds=decoder_inputs_embeds,
523
+ past_key_values=past_key_values,
524
+ encoder_hidden_states=hidden_states,
525
+ encoder_attention_mask=attention_mask,
526
+ head_mask=decoder_head_mask,
527
+ cross_attn_head_mask=cross_attn_head_mask,
528
+ use_cache=use_cache,
529
+ output_attentions=output_attentions,
530
+ output_hidden_states=output_hidden_states,
531
+ return_dict=return_dict,
532
+ )
533
+
534
+ sequence_output = decoder_outputs[0]
535
+
536
+ # Scale sequence output if embeddings are tied
537
+ if self.config.tie_word_embeddings:
538
+ sequence_output = sequence_output * (self.model_dim ** -0.5)
539
+
540
+ # Compute language modeling logits
541
+ lm_logits = self.lm_head(sequence_output)
542
+
543
+ loss = None
544
+ if labels is not None:
545
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
546
+ labels = labels.to(lm_logits.device)
547
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
548
+
549
+ if not return_dict:
550
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
551
+ return ((loss,) + output) if loss is not None else output
552
+
553
+ return Seq2SeqLMOutput(
554
+ loss=loss,
555
+ logits=lm_logits,
556
+ past_key_values=decoder_outputs.past_key_values,
557
+ decoder_hidden_states=decoder_outputs.hidden_states,
558
+ decoder_attentions=decoder_outputs.attentions,
559
+ cross_attentions=decoder_outputs.cross_attentions,
560
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
561
+ encoder_hidden_states=encoder_outputs.hidden_states,
562
+ encoder_attentions=encoder_outputs.attentions,
563
+ )
564
+
565
+ def generate(
566
+ self,
567
+ face_features=None,
568
+ left_hand_features=None,
569
+ right_hand_features=None,
570
+ pose_features=None,
571
+ attention_mask=None,
572
+ **kwargs
573
+ ):
574
+ """
575
+ Generate method to handle sign language features and generate output sequences.
576
+ """
577
+ # Compute encoder outputs using sign language features
578
+ encoder_outputs = self.encoder(
579
+ face_features=face_features,
580
+ left_hand_features=left_hand_features,
581
+ right_hand_features=right_hand_features,
582
+ pose_features=pose_features,
583
+ attention_mask=attention_mask,
584
+ return_dict=True
585
+ )
586
+
587
+ # Pass encoder outputs to the decoder
588
+ kwargs["encoder_outputs"] = encoder_outputs
589
+
590
+ # Generate sequences using the parent class's generate method
591
+ return super().generate(
592
+ attention_mask=attention_mask,
593
+ **kwargs
594
+ )
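+
+     # Call sketch (illustrative, not invoked anywhere in this module): decoding
+     # from batched feature tensors of shape (batch, frames, dim), e.g.
+     #   ids = model.generate(
+     #       face_features=f, left_hand_features=lh, right_hand_features=rh,
+     #       pose_features=p, num_beams=5, max_length=128,
+     #   )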
595
+
596
+ @dataclass
597
+ class SignLanguageT5Collator:
598
+ model: Optional[Any] = None
599
+ padding: Union[bool, str, PaddingStrategy] = True
600
+ max_length: Optional[int] = None
601
+ pad_to_multiple_of: Optional[int] = None
602
+ label_pad_token_id: int = -100
603
+ return_tensors: str = "pt"
604
+
605
+ def __call__(self, features, return_tensors=None):
606
+ if return_tensors is None:
607
+ return_tensors = self.return_tensors
608
+
609
+ face_embeds = [feature["face_features"] for feature in features]
610
+ left_hand_embeds = [feature["left_hand_features"] for feature in features]
611
+ right_hand_embeds = [feature["right_hand_features"] for feature in features]
612
+ pose_embeds = [feature["pose_features"] for feature in features]
613
+
614
+ # Padding
615
+ max_len = max([emb.shape[0] for emb in face_embeds])
616
+
617
+ def pad_embeds(embeds):
618
+ padded_embeds = []
619
+ for emb in embeds:
620
+ if emb.dim() == 3: # For 3D tensors (pose features)
621
+ pad_len = max_len - emb.shape[1] # padding the second dimension (T)
622
+ emb_pad = torch.nn.functional.pad(emb, (0, 0, 0, pad_len, 0, 0), value=0)
623
+ else: # For 2D tensors (face, hand features)
624
+ pad_len = max_len - emb.shape[0]
625
+ emb_pad = torch.nn.functional.pad(emb, (0, 0, 0, pad_len), value=0)
626
+ padded_embeds.append(emb_pad)
627
+ return padded_embeds
628
+
629
+ padded_face_embeds = pad_embeds(face_embeds)
630
+ padded_left_hand_embeds = pad_embeds(left_hand_embeds)
631
+ padded_right_hand_embeds = pad_embeds(right_hand_embeds)
632
+ padded_pose_embeds = pad_embeds(pose_embeds)
633
+
634
+ batch = {}
635
+ batch["face_features"] = torch.stack(padded_face_embeds, dim=0)
636
+ batch["left_hand_features"] = torch.stack(padded_left_hand_embeds, dim=0)
637
+ batch["right_hand_features"] = torch.stack(padded_right_hand_embeds, dim=0)
638
+ batch["pose_features"] = torch.stack(padded_pose_embeds, dim=0)
639
+
640
+ # For inference, we don't need decoder_input_ids - the model.generate() will handle this
641
+ # Remove the decoder_input_ids requirement
642
+ return batch
643
+
644
+ class TranslationFeatures(torch.utils.data.Dataset):
645
+ def __init__(self, face_embeddings, left_hand_embeddings, right_hand_embeddings, body_posture_embeddings):
646
+ self.face_embeddings = face_embeddings
647
+ self.left_hand_embeddings = left_hand_embeddings
648
+ self.right_hand_embeddings = right_hand_embeddings
649
+ self.body_posture_embeddings = body_posture_embeddings
650
+
651
+ def __len__(self):
652
+ return 1
653
+
654
+ def __getitem__(self, idx):
655
+ return {
656
+ "face_features": torch.tensor(self.face_embeddings),
657
+ "left_hand_features": torch.tensor(self.left_hand_embeddings),
658
+ "right_hand_features": torch.tensor(self.right_hand_embeddings),
659
+ "pose_features": torch.tensor(self.body_posture_embeddings),
660
+ }
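+
+ # Illustrative sketch (not part of the original pipeline): how the collator and
+ # dataset defined above could be wired into a DataLoader. The feature widths
+ # used below are placeholders chosen only for illustration.
+ def _example_dataloader_sketch():
+     import numpy as np
+     import torch
+
+     num_frames = 16
+     dataset = TranslationFeatures(
+         face_embeddings=np.zeros((num_frames, 384), dtype=np.float32),
+         left_hand_embeddings=np.zeros((num_frames, 384), dtype=np.float32),
+         right_hand_embeddings=np.zeros((num_frames, 384), dtype=np.float32),
+         body_posture_embeddings=np.zeros((num_frames, 14), dtype=np.float32),
+     )
+     collator = SignLanguageT5Collator()
+     loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=collator)
+     batch = next(iter(loader))  # batch["face_features"]: (1, num_frames, 384)
+     return batch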
661
+
662
+ def generate_text_from_features(
663
+ face_embeddings: np.ndarray,
664
+ left_hand_embeddings: np.ndarray,
665
+ right_hand_embeddings: np.ndarray,
666
+ body_posture_embeddings: np.ndarray,
667
+ model_config: str,
668
+ model_checkpoint: str,
669
+ tokenizer_checkpoint: str,
670
+ output_dir: str,
671
+ generation_max_length: int = 2048,
672
+ generation_num_beams: int = 5,
673
+ ):
674
+ """
675
+ Direct inference function that generates text from sign language features.
676
+ """
677
+ # Load model and tokenizer
678
+ config = SignLanguageByT5Config.from_pretrained(model_config)
679
+ model = SignLanguageByT5ForConditionalGeneration.from_pretrained(
680
+ model_checkpoint,
681
+ # config=config,
682
+ cache_dir=os.path.join(output_dir, "cache"),
683
+ )
684
+ tokenizer = ByT5Tokenizer.from_pretrained(tokenizer_checkpoint)
685
+
686
+ # Move model to appropriate device
687
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
688
+ model.to(device)
689
+ model.eval()
690
+
691
+ # Convert inputs to tensors and move to device
692
+ face_tensor = torch.tensor(face_embeddings, dtype=torch.float32).unsqueeze(0).to(device)
693
+ left_hand_tensor = torch.tensor(left_hand_embeddings, dtype=torch.float32).unsqueeze(0).to(device)
694
+ right_hand_tensor = torch.tensor(right_hand_embeddings, dtype=torch.float32).unsqueeze(0).to(device)
695
+ pose_tensor = torch.tensor(body_posture_embeddings, dtype=torch.float32).unsqueeze(0).to(device)
696
+
697
+ # Generate text
698
+ with torch.no_grad():
699
+ generated_ids = model.generate(
700
+ face_features=face_tensor,
701
+ left_hand_features=left_hand_tensor,
702
+ right_hand_features=right_hand_tensor,
703
+ pose_features=pose_tensor,
704
+ max_length=generation_max_length,
705
+ num_beams=generation_num_beams,
706
+ early_stopping=True,
707
+ pad_token_id=tokenizer.pad_token_id,
708
+ eos_token_id=tokenizer.eos_token_id,
709
+ )
710
+
711
+ # Decode generated text
712
+ generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
713
+ return generated_text
714
+
715
+ def test(
716
+ face_embeddings: np.ndarray,
717
+ left_hand_embeddings: np.ndarray,
718
+ right_hand_embeddings: np.ndarray,
719
+ body_posture_embeddings: np.ndarray,
720
+ model_config: str,
721
+ model_checkpoint: str,
722
+ tokenizer_checkpoint: str,
723
+ output_dir: str,
724
+ ):
725
+ """
726
+ Test function for inference - generates text from sign language features.
727
+ This is a simpler wrapper around the direct inference function.
728
+ """
729
+ return generate_text_from_features(
730
+ face_embeddings=face_embeddings,
731
+ left_hand_embeddings=left_hand_embeddings,
732
+ right_hand_embeddings=right_hand_embeddings,
733
+ body_posture_embeddings=body_posture_embeddings,
734
+ model_config=model_config,
735
+ model_checkpoint=model_checkpoint,
736
+ tokenizer_checkpoint=tokenizer_checkpoint,
737
+ output_dir=output_dir,
738
+ )
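+
+
+ # Minimal usage sketch for this module. The .npy and checkpoint paths below are
+ # placeholders (illustrative assumptions), not assets shipped with this Space;
+ # substitute real paths before running the module directly.
+ if __name__ == "__main__":
+     example_text = generate_text_from_features(
+         face_embeddings=np.load("face_embeddings.npy"),
+         left_hand_embeddings=np.load("left_hand_embeddings.npy"),
+         right_hand_embeddings=np.load("right_hand_embeddings.npy"),
+         body_posture_embeddings=np.load("pose_embeddings.npy"),
+         model_config="path/to/model_config",
+         model_checkpoint="path/to/model_checkpoint",
+         tokenizer_checkpoint="google/byt5-small",
+         output_dir="/tmp/slt_output",
+     )
+     print(example_text)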
kpe_mediapipe.py ADDED
@@ -0,0 +1,408 @@
1
+ import mediapipe as mp
2
+ from mediapipe.tasks import python
3
+ from mediapipe.tasks.python import vision
4
+ import cv2
5
+ import numpy as np
6
+ import json
7
+ from pathlib import Path
8
+ import decord
9
+ from typing import Dict, Optional, Tuple, Any
10
+
11
+
12
+ class HolisticDetector:
13
+ """
14
+ A class for detecting face, hand, and pose landmarks in videos using MediaPipe.
15
+ """
16
+
17
+ def __init__(self, face_model_path: str, hand_model_path: str,
18
+ min_detection_confidence: float = 0.1,
19
+ min_hand_detection_confidence: float = 0.05,
20
+ max_faces: int = 6, max_hands: int = 6):
21
+ """
22
+ Initialize the HolisticDetector with model paths and configuration.
23
+
24
+ Args:
25
+ face_model_path: Path to the face detection model
26
+ hand_model_path: Path to the hand detection model
27
+ min_detection_confidence: Minimum confidence for pose detection
28
+ min_hand_detection_confidence: Minimum confidence for hand detection
29
+ max_faces: Maximum number of faces to detect
30
+ max_hands: Maximum number of hands to detect
31
+ """
32
+ self.face_model_path = face_model_path
33
+ self.hand_model_path = hand_model_path
34
+ self.min_detection_confidence = min_detection_confidence
35
+ self.min_hand_detection_confidence = min_hand_detection_confidence
36
+ self.max_faces = max_faces
37
+ self.max_hands = max_hands
38
+
39
+ self._initialize_detectors()
40
+
41
+ def _initialize_detectors(self):
42
+ """Initialize the MediaPipe detectors."""
43
+ # Initialize face detector
44
+ base_options_face = python.BaseOptions(model_asset_path=self.face_model_path)
45
+ options_face = vision.FaceLandmarkerOptions(
46
+ base_options=base_options_face,
47
+ output_face_blendshapes=True,
48
+ output_facial_transformation_matrixes=True,
49
+ num_faces=self.max_faces
50
+ )
51
+ self.face_detector = vision.FaceLandmarker.create_from_options(options_face)
52
+
53
+ # Initialize hand detector
54
+ base_options_hand = python.BaseOptions(model_asset_path=self.hand_model_path)
55
+ options_hand = vision.HandLandmarkerOptions(
56
+ base_options=base_options_hand,
57
+ num_hands=self.max_hands,
58
+ min_hand_detection_confidence=self.min_hand_detection_confidence
59
+ )
60
+ self.hand_detector = vision.HandLandmarker.create_from_options(options_hand)
61
+
62
+ # Initialize holistic model for pose
63
+ self.mp_holistic = mp.solutions.holistic.Holistic(
64
+ min_detection_confidence=self.min_detection_confidence
65
+ )
66
+
67
+ def detect_frame_landmarks(self, image: np.ndarray) -> Tuple[Dict[str, int], Dict[str, Any]]:
68
+ """
69
+ Detect landmarks in a single frame.
70
+
71
+ Args:
72
+ image: Input image as numpy array
73
+
74
+ Returns:
75
+ Tuple of (bounding_boxes_count, landmarks_data)
76
+ """
77
+ results = self.mp_holistic.process(image)
78
+
79
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
80
+ face_prediction = self.face_detector.detect(mp_image)
81
+ hand_prediction = self.hand_detector.detect(mp_image)
82
+
83
+ bounding_boxes = {}
84
+ landmarks_data = {}
85
+
86
+ # Process face landmarks
87
+ if face_prediction.face_landmarks:
88
+ bounding_boxes['#face'] = len(face_prediction.face_landmarks)
89
+ landmarks_data['face_landmarks'] = []
90
+ for face in face_prediction.face_landmarks:
91
+ landmarks_face = [[landmark.x, landmark.y, landmark.z] for landmark in face]
92
+ landmarks_data['face_landmarks'].append(landmarks_face)
93
+ else:
94
+ bounding_boxes['#face'] = 0
95
+ landmarks_data['face_landmarks'] = None
96
+
97
+ # Process hand landmarks
98
+ if hand_prediction.hand_landmarks:
99
+ bounding_boxes['#hands'] = len(hand_prediction.hand_landmarks)
100
+ landmarks_data['hand_landmarks'] = []
101
+ for hand in hand_prediction.hand_landmarks:
102
+ landmarks_hand = [[landmark.x, landmark.y, landmark.z] for landmark in hand]
103
+ landmarks_data['hand_landmarks'].append(landmarks_hand)
104
+ else:
105
+ bounding_boxes['#hands'] = 0
106
+ landmarks_data['hand_landmarks'] = None
107
+
108
+ # Process pose landmarks
109
+ if results.pose_landmarks:
110
+ bounding_boxes['#pose'] = 1
111
+ landmarks_data['pose_landmarks'] = []
112
+ pose_landmarks = [[landmark.x, landmark.y, landmark.z] for landmark in results.pose_landmarks.landmark]
113
+ landmarks_data['pose_landmarks'].append(pose_landmarks)
114
+ else:
115
+ bounding_boxes['#pose'] = 0
116
+ landmarks_data['pose_landmarks'] = None
117
+
118
+ return bounding_boxes, landmarks_data
119
+
120
+ def process_video(self, video_input, save_results: bool = False,
121
+ output_dir: Optional[str] = None, video_name: Optional[str] = None) -> Dict[int, Any]:
122
+ """
123
+ Process a video and extract landmarks from all frames.
124
+
125
+ Args:
126
+ video_input: Either a path to video file (str) or a decord.VideoReader object
127
+ save_results: Whether to save results to files
128
+ output_dir: Directory to save results (required if save_results=True)
129
+ video_name: Name for output files (required if save_results=True and video_input is VideoReader)
130
+
131
+ Returns:
132
+ Dictionary containing landmarks for each frame
133
+
134
+ Raises:
135
+ FileNotFoundError: If video file doesn't exist
136
+ ValueError: If save_results=True but output_dir is None, or if video_name is None when needed
137
+ TypeError: If video_input is neither string nor VideoReader
138
+ """
139
+ if save_results and output_dir is None:
140
+ raise ValueError("output_dir must be provided when save_results=True")
141
+
142
+ # Handle different input types
143
+ if isinstance(video_input, str):
144
+ # Input is a file path
145
+ video_path = Path(video_input)
146
+ if not video_path.exists():
147
+ raise FileNotFoundError(f"Video file not found: {video_input}")
148
+
149
+ try:
150
+ video = decord.VideoReader(str(video_path))
151
+ except Exception as e:
152
+ raise RuntimeError(f"Error loading video {video_input}: {e}")
153
+
154
+ file_name = video_path.stem
155
+
156
+ # elif hasattr(video_input, '__len__') and hasattr(video_input, '__getitem__'):
157
+ else:
158
+ # Input is a VideoReader object or similar
159
+ video = video_input
160
+ if save_results and video_name is None:
161
+ raise ValueError("video_name must be provided when save_results=True and video_input is a VideoReader object")
162
+ file_name = video_name or "video"
163
+
164
+ # else:
165
+ # raise TypeError("video_input must be either a file path (str) or a VideoReader object")
166
+
167
+ result_dict = {}
168
+ stats = {}
169
+
170
+ # Process each frame
171
+ for i in range(len(video)):
172
+ try:
173
+ frame = video[i]
+ # decord VideoReader items are NDArray objects; convert to numpy for MediaPipe
+ frame_rgb = frame.asnumpy() if hasattr(frame, "asnumpy") else frame
175
+ if hasattr(video, 'seek'):
176
+ video.seek(0)
177
+ bounding_boxes, landmarks = self.detect_frame_landmarks(frame_rgb)
178
+ result_dict[i] = landmarks
179
+ stats[i] = bounding_boxes
180
+ except Exception as e:
181
+ print(f"Error processing frame {i}: {e}")
182
+ result_dict[i] = None
183
+ stats[i] = {'#face': 0, '#hands': 0, '#pose': 0}
184
+
185
+ # Save results if requested
186
+ if save_results:
187
+ self._save_results(file_name, result_dict, stats, output_dir)
188
+
189
+ return result_dict
190
+
191
+ def process_video_frames(self, frames: list, save_results: bool = False,
192
+ output_dir: Optional[str] = None, video_name: str = "video") -> Dict[int, Any]:
193
+ """
194
+ Process a list of frames and extract landmarks.
195
+
196
+ Args:
197
+ frames: List of frame images as numpy arrays
198
+ save_results: Whether to save results to files
199
+ output_dir: Directory to save results (required if save_results=True)
200
+ video_name: Name for output files
201
+
202
+ Returns:
203
+ Dictionary containing landmarks for each frame
204
+ """
205
+ if save_results and output_dir is None:
206
+ raise ValueError("output_dir must be provided when save_results=True")
207
+
208
+ result_dict = {}
209
+ stats = {}
210
+
211
+ # Process each frame
212
+ for i, frame in enumerate(frames):
213
+ try:
214
+ bounding_boxes, landmarks = self.detect_frame_landmarks(frame)
215
+ result_dict[i] = landmarks
216
+ stats[i] = bounding_boxes
217
+ except Exception as e:
218
+ print(f"Error processing frame {i}: {e}")
219
+ result_dict[i] = None
220
+ stats[i] = {'#face': 0, '#hands': 0, '#pose': 0}
221
+
222
+ # Save results if requested
223
+ if save_results:
224
+ self._save_results(video_name, result_dict, stats, output_dir)
225
+
226
+ return result_dict
227
+
228
+ def _save_results(self, video_name: str, landmarks_data: Dict, stats_data: Dict, output_dir: str):
229
+ """Save landmarks and stats to JSON files."""
230
+ output_path = Path(output_dir)
231
+ output_path.mkdir(parents=True, exist_ok=True)
232
+
233
+ # Save landmarks
234
+ landmarks_file = output_path / f"{video_name}_pose.json"
235
+ with open(landmarks_file, 'w') as f:
236
+ json.dump(landmarks_data, f)
237
+
238
+ # Save stats
239
+ stats_file = output_path / f"{video_name}_stats.json"
240
+ with open(stats_file, 'w') as f:
241
+ json.dump(stats_data, f)
242
+
243
+ def compute_video_stats(self, landmarks_data: Dict) -> Dict[str, Any]:
244
+ """
245
+ Compute statistics from landmarks data.
246
+
247
+ Args:
248
+ landmarks_data: Dictionary containing landmarks for each frame
249
+
250
+ Returns:
251
+ Dictionary containing frame-by-frame stats and maximums
252
+ """
253
+ stats = {}
254
+ max_counts = {'#face': 0, '#hands': 0, '#pose': 0}
255
+
256
+ for frame, landmarks in landmarks_data.items():
257
+ if landmarks is None:
258
+ presence = {'#face': 0, '#hands': 0, '#pose': 0}
259
+ else:
260
+ presence = {
261
+ '#face': len(landmarks.get('face_landmarks', [])) if landmarks.get('face_landmarks') else 0,
262
+ '#hands': len(landmarks.get('hand_landmarks', [])) if landmarks.get('hand_landmarks') else 0,
263
+ '#pose': len(landmarks.get('pose_landmarks', [])) if landmarks.get('pose_landmarks') else 0
264
+ }
265
+ stats[frame] = presence
266
+
267
+ # Update max counts
268
+ for key in max_counts:
269
+ max_counts[key] = max(max_counts[key], presence[key])
270
+
271
+ stats['max'] = max_counts
272
+ return stats
273
+
274
+
275
+ # Convenience function for backward compatibility and simple usage
276
+ def video_holistic(video_input, face_model_path: str, hand_model_path: str,
277
+ save_results: bool = False, output_dir: Optional[str] = None,
278
+ video_name: Optional[str] = None) -> Dict[int, Any]:
279
+ """
280
+ Convenience function to process a video and extract holistic landmarks.
281
+
282
+ Args:
283
+ video_input: Either a path to video file (str) or a decord.VideoReader object
284
+ face_model_path: Path to the face detection model
285
+ hand_model_path: Path to the hand detection model
286
+ save_results: Whether to save results to files
287
+ output_dir: Directory to save results
288
+ video_name: Name for output files (required if save_results=True and video_input is VideoReader)
289
+
290
+ Returns:
291
+ Dictionary containing landmarks for each frame
292
+ """
293
+ detector = HolisticDetector(face_model_path, hand_model_path)
294
+ return detector.process_video(video_input, save_results, output_dir, video_name)
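+
+
+ # Usage sketch (the .task model paths are placeholders for MediaPipe model
+ # bundles, which are not part of this module):
+ #
+ #   landmarks = video_holistic(
+ #       "example.mp4",
+ #       face_model_path="face_landmarker.task",
+ #       hand_model_path="hand_landmarker.task",
+ #       save_results=True,
+ #       output_dir="./poses",
+ #   )
+ #   landmarks[0]["face_landmarks"]  # list of [x, y, z] points for frame 0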
295
+
296
+
297
+ # Utility functions for batch processing
298
+ def load_file(filename: str):
299
+ """Load a pickled and gzipped file."""
300
+ import pickle
301
+ import gzip
302
+ with gzip.open(filename, "rb") as f:
303
+ return pickle.load(f)
304
+
305
+
306
+ def is_string_in_file(file_path: str, target_string: str) -> bool:
307
+ """Check if a string exists in a file."""
308
+ try:
309
+ with Path(file_path).open("r") as f:
310
+ for line in f:
311
+ if target_string in line:
312
+ return True
313
+ return False
314
+ except Exception as e:
315
+ print(f"Error: {e}")
316
+ return False
317
+
318
+
319
+ def main():
320
+ """Main function for command-line usage."""
321
+ import argparse
322
+ import time
323
+ import os
324
+
325
+ parser = argparse.ArgumentParser()
326
+ parser.add_argument('--index', type=int, required=True,
327
+ help='index of the sub_list to work with')
328
+ parser.add_argument('--batch_size', type=int, required=True,
329
+ help='batch size')
330
+ parser.add_argument('--pose_path', type=str, required=True,
331
+ help='path to where the pose data will be saved')
332
+ parser.add_argument('--stats_path', type=str, required=True,
333
+ help='path to where the stats data will be saved')
334
+ parser.add_argument('--time_limit', type=int, required=True,
335
+ help='time limit')
336
+ parser.add_argument('--files_list', type=str, required=True,
337
+ help='files list')
338
+ parser.add_argument('--problem_file_path', type=str, required=True,
339
+ help='problem file path')
340
+ parser.add_argument('--face_model_path', type=str, required=True,
341
+ help='face model path')
342
+ parser.add_argument('--hand_model_path', type=str, required=True,
343
+ help='hand model path')
344
+
345
+ args = parser.parse_args()
346
+
347
+ start_time = time.time()
348
+
349
+ # Initialize detector
350
+ detector = HolisticDetector(args.face_model_path, args.hand_model_path)
351
+
352
+ # Load the files list
353
+ fixed_list = load_file(args.files_list)
354
+
355
+ # Create folders if they do not exist
356
+ Path(args.pose_path).mkdir(parents=True, exist_ok=True)
357
+ Path(args.stats_path).mkdir(parents=True, exist_ok=True)
358
+
359
+ # Create problem file if it doesn't exist
360
+ if not os.path.exists(args.problem_file_path):
361
+ with open(args.problem_file_path, 'w') as f:
362
+ pass
363
+
364
+ # Process videos in batches
365
+ video_batches = [fixed_list[i:i + args.batch_size] for i in range(0, len(fixed_list), args.batch_size)]
366
+
367
+ for video_file in video_batches[args.index]:
368
+ current_time = time.time()
369
+ if current_time - start_time > args.time_limit:
370
+ print("Time limit reached. Stopping execution.")
371
+ break
372
+
373
+ # Check if output files already exist
374
+ video_name = Path(video_file).stem
375
+ landmark_json_path = Path(args.pose_path) / f"{video_name}_pose.json"
376
+ stats_json_path = Path(args.stats_path) / f"{video_name}_stats.json"
377
+
378
+ if landmark_json_path.exists() and stats_json_path.exists():
379
+ print(f"Skipping {video_file} - output files already exist")
380
+ continue
381
+ elif is_string_in_file(args.problem_file_path, video_file):
382
+ print(f"Skipping {video_file} - found in problem file")
383
+ continue
384
+ else:
385
+ try:
386
+ print(f"Processing {video_file}")
387
+ result_dict = detector.process_video(
388
+ video_input=video_file,
389
+ save_results=True,
390
+ output_dir=args.pose_path
391
+ )
392
+
393
+ # Also save stats separately for compatibility
394
+ stats = detector.compute_video_stats(result_dict)
395
+ with open(stats_json_path, 'w') as f:
396
+ json.dump(stats, f)
397
+
398
+ print(f"Successfully processed {video_file}")
399
+
400
+ except Exception as e:
401
+ print(f"Error processing {video_file}: {e}")
402
+ # Add to problem file
403
+ with open(args.problem_file_path, "a") as p:
404
+ p.write(video_file + "\n")
405
+
406
+
407
+ if __name__ == "__main__":
408
+ main()
shubert.py ADDED
@@ -0,0 +1,479 @@
1
+ import logging
2
+ from dataclasses import dataclass, field
3
+ from typing import Optional
4
+
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.distributed as dist
10
+
11
+ import numpy as np
12
+ import random
13
+ import os
14
+ import sys
15
+
16
+ from fairseq.data.data_utils import compute_mask_indices
17
+ from fairseq.models import BaseFairseqModel, register_model
18
+ from fairseq.models.wav2vec import (
19
+ Wav2Vec2Config,
20
+ TransformerEncoder,
21
+ )
22
+
23
+ # Debug print to show where Wav2Vec2Config is defined
24
+ print(f"Wav2Vec2Config is imported from: {Wav2Vec2Config.__module__}")
25
+ print(f"Full path: {sys.modules[Wav2Vec2Config.__module__].__file__}")
26
+
27
+ from fairseq.modules import (
28
+ LayerNorm,
29
+ )
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ @dataclass
35
+ class SignHubertConfig(Wav2Vec2Config):
36
+ # pos_conv_kernel: int = field(default=32)
37
+ conv_pos: int = field(default=32)
38
+ discrete: bool = field(default=False)
39
+ codebook_size: int = field(default=256)
40
+ channels_embed_dim: int = field(default=384)
41
+ channels_pose_embed_dim: int = field(default=14)
42
+ intermediate_dim: int = field(default=1024) # This will be overridden if needed
43
+ mask_strategy: str = field(default="random")
44
+ channels: str = field(default="face,left_hand,right_hand,body_posture")
45
+
46
+
47
+ @register_model("signhubert_onlyhands", dataclass=SignHubertConfig)
48
+ class SignHubertModel(BaseFairseqModel):
49
+ def __init__(self, cfg: SignHubertConfig):
50
+ super().__init__()
51
+ self.cfg = cfg
52
+ # print(cfg)
53
+ self.discrete = cfg.discrete # since it's hubert this will always be discrete anyways
54
+
55
+ self.embed = cfg.encoder_embed_dim # whether it is small(384), base(768), large, etc.
56
+ self.channel_embed = cfg.channels_embed_dim # embedding dimension for face, left_hand and right_hand (default: 384)
57
+ self.channel_pose_embed = cfg.channels_pose_embed_dim # embedding dimension for pose (default: 14)
58
+ self.intermediate_dim = cfg.intermediate_dim # intermediate dimension before the projection layer to encoder_embed_dim (default: 1024)
59
+
60
+ self.channels = cfg.channels.split(",")
61
+
62
+ self.post_extract_proj = nn.Linear(cfg.intermediate_dim, cfg.encoder_embed_dim) # 4 channels concatenated
63
+
64
+ self.mask_prob = cfg.mask_prob
65
+ self.mask_selection = cfg.mask_selection
66
+ self.mask_strategy = cfg.mask_strategy
67
+ self.mask_other = cfg.mask_other
68
+ self.mask_length = cfg.mask_length
69
+ self.no_mask_overlap = cfg.no_mask_overlap
70
+ self.mask_min_space = cfg.mask_min_space
71
+
72
+ self.mask_channel_prob = cfg.mask_channel_prob
73
+ self.mask_channel_before = cfg.mask_channel_before
74
+ self.mask_channel_selection = cfg.mask_channel_selection
75
+ self.mask_channel_other = cfg.mask_channel_other
76
+ self.mask_channel_length = cfg.mask_channel_length
77
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
78
+ self.mask_channel_min_space = cfg.mask_channel_min_space
79
+
80
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
81
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
82
+
83
+ self.feature_grad_mult = cfg.feature_grad_mult
84
+
85
+ self.mask_emb = nn.Parameter(
86
+ torch.FloatTensor(1, 1, 1, cfg.intermediate_dim // len(self.channels)).uniform_()
87
+ )
88
+
89
+ self.encoder = TransformerEncoder(cfg)
90
+ self.layer_norm = LayerNorm(self.channel_embed * len(self.channels))
91
+
92
+
93
+ if "face" in self.channels:
94
+ self.layer_norm_face = LayerNorm(self.channel_embed)
95
+ self.face_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // len(self.channels))
96
+ if "left_hand" in self.channels:
97
+ self.layer_norm_lhand = LayerNorm(self.channel_embed)
98
+ self.left_hand_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // len(self.channels))
99
+ if "right_hand" in self.channels:
100
+ self.layer_norm_rhand = LayerNorm(self.channel_embed)
101
+ self.right_hand_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // len(self.channels))
102
+ if "body_posture" in self.channels:
103
+ self.layer_norm_body = LayerNorm(self.channel_pose_embed)
104
+ self.body_posture_proj = nn.Linear(self.channel_pose_embed, cfg.intermediate_dim // len(self.channels))
105
+
106
+ self.codebook_size = cfg.codebook_size # number of codebook vectors
107
+
108
+ self.heads = []
109
+ for i in range(len(self.channels)):
110
+ self.heads.append(nn.Linear(cfg.encoder_embed_dim, cfg.codebook_size))
111
+
112
+ self.heads = torch.nn.ModuleList(self.heads)
113
+
114
+ # self.heads = torch.nn.ModuleList([
115
+ # nn.Linear(cfg.encoder_embed_dim, cfg.codebook_size) ,
116
+ # nn.Linear(cfg.encoder_embed_dim, cfg.codebook_size),
117
+ # nn.Linear(cfg.encoder_embed_dim, cfg.codebook_size),
118
+ # ]
119
+ # )
120
+
121
+
122
+
123
+ # # Define separate linear layers for each channel
124
+ # self.face_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // 4)
125
+ # self.left_hand_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // 4)
126
+ # self.right_hand_proj = nn.Linear(self.channel_embed, cfg.intermediate_dim // 4)
127
+ # self.body_posture_proj = nn.Linear(self.channel_pose_embed, cfg.intermediate_dim // 4)
128
+
129
+
130
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
131
+
132
+ state = super().state_dict(destination, prefix, keep_vars)
133
+
134
+ return state
135
+
136
+
137
+
138
+ @classmethod
139
+ def build_model(cls, cfg: SignHubertConfig, task=None):
140
+ """Build a new model instance."""
141
+
142
+ return cls(cfg)
143
+
144
+ def apply_mask(
145
+ self,
146
+ x,
147
+ padding_mask,
148
+ mask_indices=None,
149
+ mask_channel_indices=None,
150
+ ):
151
+ B, T, C, D = x.shape
152
+
153
+ # Initialize a mask vector with ones (same shape as x)
154
+ mask = torch.ones_like(x)
155
+
156
+ # channel masking
157
+ if self.mask_prob > 0 and self.mask_strategy == "channel":
158
+ if mask_indices is None:
159
+ mask_indices = torch.zeros_like(x[:,:,:,0], dtype=bool)
160
+ num_channels_to_mask = int(C * self.mask_prob)
161
+ num_channels_to_mask = max(1, num_channels_to_mask)
162
+
163
+ for i in range(B):
164
+ channels_to_mask = np.random.choice(C, num_channels_to_mask, replace=False)
165
+ mask_indices[i, :, channels_to_mask] = True
166
+
167
+ mask[mask_indices.unsqueeze(-1).expand(-1, -1, -1, D)] = 0
168
+
169
+ # gloss/time masking
170
+ elif self.mask_prob > 0 and self.mask_strategy == "gloss":
171
+ if mask_indices is None:
172
+ mask_indices_channel = compute_mask_indices(
173
+ (B, T),
174
+ padding_mask,
175
+ self.mask_prob,
176
+ self.mask_length,
177
+ self.mask_selection,
178
+ self.mask_other,
179
+ min_masks=1,
180
+ no_overlap=self.no_mask_channel_overlap,
181
+ min_space=self.mask_min_space,
182
+ require_same_masks=self.cfg.require_same_masks,
183
+ mask_dropout=self.cfg.mask_dropout,
184
+ )
185
+ mask_indices_channel = torch.from_numpy(mask_indices_channel).to(x.device)
186
+
187
+ # Apply the same mask to all channels
188
+ mask_indices = mask_indices_channel.unsqueeze(2).expand(-1, -1, C)
189
+ mask_indices = mask_indices.unsqueeze(3).expand(-1, -1, -1, D)
190
+ mask[mask_indices] = 0
191
+
192
+ # random masking
193
+ elif self.mask_prob > 0 and self.mask_strategy == "random":
194
+ if mask_indices is None:
195
+ mask_indices = compute_mask_indices(
196
+ (B, T*C), # Note: T*C instead of T
197
+ padding_mask,
198
+ self.mask_prob,
199
+ self.mask_length,
200
+ self.mask_selection,
201
+ self.mask_other,
202
+ min_masks=1,
203
+ no_overlap=self.no_mask_channel_overlap,
204
+ min_space=self.mask_min_space,
205
+ require_same_masks=self.cfg.require_same_masks,
206
+ mask_dropout=self.cfg.mask_dropout,
207
+ )
208
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
209
+ mask_indices = mask_indices.view(B, T, C)
210
+ mask_indices = mask_indices.unsqueeze(3).expand(-1, -1, -1, D)
211
+ mask[mask_indices] = 0
212
+ else:
213
+ raise ValueError(f"unknown mask strategy {self.mask_strategy}")
214
+
215
+ # Apply the mask to x and return the masked tensor with the same shape as x
216
+ # x = x * mask
217
+ x = x * mask + self.mask_emb * (1 - mask)
218
+
219
+ return x, mask
220
+ # mask is a tensor of shape BxTx4x256 where 0 means the value is masked and 1 means the value is not masked
221
+
222
+
223
+ def forward(
224
+ self,
225
+ source,
226
+ padding_mask=None,
227
+ mask=True,
228
+ features_only=False,
229
+ layer=None,
230
+ mask_indices=None,
231
+ mask_channel_indices=None,
232
+ padding_count=None,
233
+ kmeans_labels=None,
234
+ ):
235
+
236
+ channels_to_use = []
237
+ for c in self.channels:
238
+ if c in source[0]:
239
+ channels_to_use.append(c)
240
+
241
+ for c in channels_to_use:
242
+ if c == "face":
243
+ face_features_list = []
244
+ label_face_features_list = []
245
+ elif c == "left_hand":
246
+ left_hand_features_list = []
247
+ label_left_hand_features_list = []
248
+ elif c == "right_hand":
249
+ right_hand_features_list = []
250
+ label_right_hand_features_list = []
251
+ elif c == "body_posture":
252
+ body_posture_features_list = []
253
+ label_body_posture_features_list = []
254
+
255
+ # # source is a list of dictionaries with keys "face", "left_hand", "right_hand", "body_posture"
256
+ # face_features_list = []
257
+ # left_hand_features_list = []
258
+ # right_hand_features_list = []
259
+ # body_posture_features_list = []
260
+ # label_face_features_list = []
261
+ # label_left_hand_features_list = []
262
+ # label_right_hand_features_list = []
263
+ # label_body_posture_features_list = []
264
+
265
+ # for sample in source:
266
+ # face_features_list.append(sample["face"]) # Tx384
267
+ # left_hand_features_list.append(sample["left_hand"]) # Tx384
268
+ # right_hand_features_list.append(sample["right_hand"]) # Tx384
269
+ # body_posture_features_list.append(sample["body_posture"]) # Tx14
270
+ # label_face_features_list.append(sample["label_face"]) # Tx1
271
+ # label_left_hand_features_list.append(sample["label_left_hand"]) # Tx1
272
+ # label_right_hand_features_list.append(sample["label_right_hand"]) # Tx1
273
+ # label_body_posture_features_list.append(sample["label_body_posture"]) # Tx1
274
+
275
+ for sample in source:
276
+ for c in channels_to_use:
277
+ if c == "face":
278
+ face_features_list.append(sample["face"]) # Tx384
279
+ label_face_features_list.append(sample["label_face"]) # Tx1
280
+ elif c == "left_hand":
281
+ left_hand_features_list.append(sample["left_hand"]) # Tx384
282
+ label_left_hand_features_list.append(sample["label_left_hand"]) # Tx1
283
+ elif c == "right_hand":
284
+ right_hand_features_list.append(sample["right_hand"]) # Tx384
285
+ label_right_hand_features_list.append(sample["label_right_hand"]) # Tx1
286
+ elif c == "body_posture":
287
+ body_posture_features_list.append(sample["body_posture"]) # Tx14
288
+ label_body_posture_features_list.append(sample["label_body_posture"]) # Tx1
289
+
290
+
291
+
292
+
293
+ # face_features = torch.stack(face_features_list) # BxTx384
294
+ # left_hand_features = torch.stack(left_hand_features_list) # BxTx384
295
+ # right_hand_features = torch.stack(right_hand_features_list) # BxTx384
296
+ # body_posture_features = torch.stack(body_posture_features_list) # BxTx14
297
+ # face_labels = torch.stack(label_face_features_list) # BxTx1
298
+ # left_hand_labels = torch.stack(label_left_hand_features_list) # BxTx1
299
+ # right_hand_labels = torch.stack(label_right_hand_features_list) # BxTx1
300
+ # body_posture_labels = torch.stack(label_body_posture_features_list) # BxTx1
301
+
302
+
303
+ # # Apply layer normalization to each part
304
+ # face_features = self.layer_norm_face(face_features) # BxTx384
305
+ # left_hand_features = self.layer_norm_lhand(left_hand_features) # BxTx384
306
+ # right_hand_features = self.layer_norm_rhand(right_hand_features) # BxTx384
307
+ # body_posture_features = self.layer_norm_body(body_posture_features) # BxTx14
308
+
309
+ # # Apply separate linear projections for each channel
310
+ # face_features = self.face_proj(face_features) # BxTx256
311
+ # left_hand_features = self.left_hand_proj(left_hand_features) # BxTx256
312
+ # right_hand_features = self.right_hand_proj(right_hand_features) # BxTx256
313
+ # body_posture_features = self.body_posture_proj(body_posture_features) # BxTx256
314
+
315
+ features_list = []
316
+ labels_list = []
317
+
318
+ for c in channels_to_use:
319
+ if c == "face":
320
+ face_features = torch.stack(face_features_list) # BxTx384
321
+ face_labels = torch.stack(label_face_features_list) # BxTx1
322
+ face_features = self.layer_norm_face(face_features) # BxTx384
323
+ face_features = self.face_proj(face_features) # BxTx256
324
+ features_list.append(face_features)
325
+ labels_list.append(face_labels)
326
+ elif c == "left_hand":
327
+ left_hand_features = torch.stack(left_hand_features_list) # BxTx384
328
+ left_hand_labels = torch.stack(label_left_hand_features_list) # BxTx1
329
+ left_hand_features = self.layer_norm_lhand(left_hand_features) # BxTx384
330
+ left_hand_features = self.left_hand_proj(left_hand_features) # BxTx256
331
+ features_list.append(left_hand_features)
332
+ labels_list.append(left_hand_labels)
333
+ elif c == "right_hand":
334
+ right_hand_features = torch.stack(right_hand_features_list) # BxTx384
335
+ right_hand_labels = torch.stack(label_right_hand_features_list) # BxTx1
336
+ right_hand_features = self.layer_norm_rhand(right_hand_features) # BxTx384
337
+ right_hand_features = self.right_hand_proj(right_hand_features) # BxTx256
338
+ features_list.append(right_hand_features)
339
+ labels_list.append(right_hand_labels)
340
+ elif c == "body_posture":
341
+ body_posture_features = torch.stack(body_posture_features_list) # BxTx14
342
+ body_posture_labels = torch.stack(label_body_posture_features_list) # BxTx1
343
+ body_posture_features = self.layer_norm_body(body_posture_features) # BxTx14
344
+ body_posture_features = self.body_posture_proj(body_posture_features) # BxTx256
345
+ features_list.append(body_posture_features)
346
+ labels_list.append(body_posture_labels)
347
+
348
+
349
+ # concatenate the projected features to have dimension BxTxCxD where C=4 and D=256
350
+ # features = torch.stack(
351
+ # [
352
+ # face_features,
353
+ # left_hand_features,
354
+ # right_hand_features,
355
+ # body_posture_features
356
+ # ],
357
+ # dim=2) # BxTx4x256
358
+
359
+ features = torch.stack(features_list, dim=2) # BxTx4x256
360
+
361
+ if mask:
362
+ x, mask_indices = self.apply_mask(
363
+ features,
364
+ padding_mask,
365
+ mask_indices=mask_indices,
366
+ mask_channel_indices=mask_channel_indices,
367
+ )
368
+ # mask_indices is a tensor of shape BxTx4x256 where 0 means the value is masked and 1 means the value is not masked
369
+ else:
370
+ x = features
371
+ mask_indices = None
372
+
373
+
374
+ x = self.dropout_input(x) # BxTx4x256
375
+
376
+ x = x.view(x.size(0), x.size(1), -1) # BxTx1024
377
+ if self.post_extract_proj is not None:
378
+ x = self.post_extract_proj(x) # BxTx768
379
+
380
+ x, layer_results = self.encoder(
381
+ x,
382
+ padding_mask=padding_mask,
383
+ layer=layer,
384
+ )
385
+
386
+ if features_only:
387
+ return {
388
+ "x": x,
389
+ "padding_mask": padding_mask,
390
+ "layer_results": layer_results,
391
+ }
392
+
393
+ result = {
394
+ "losses": {},
395
+ }
396
+
397
+ # use linear heads to compute the discrete prediction for each channel and make it into a single tensor of shape BxTxCxcodebook_size
398
+ predictions = []
399
+ for i, head in enumerate(self.heads):
400
+ channel_pred = head(x) # BxTxcodebook_size
401
+ predictions.append(channel_pred)
402
+ predictions = torch.stack(predictions, dim=2) # BxTx4xcodebook_size
403
+
404
+ # labels = torch.stack(
405
+ # [
406
+ # face_labels,
407
+ # left_hand_labels,
408
+ # right_hand_labels,
409
+ # body_posture_labels
410
+ # ],
411
+ # dim=2) # BxTx4x1
412
+
413
+ labels = torch.stack(labels_list, dim=2) # BxTx4x1
414
+ # print(f"predictions shape: {predictions.shape} and labels shape: {labels.shape}")
415
+
416
+ predictions_flat = predictions.view(-1, self.codebook_size) # Shape: (B * T * C, codebook_size)
417
+ labels_flat = labels.view(-1) # Shape: (B * T * C)
418
+
419
+ # Ensure labels are of correct shape
420
+ labels_flat = labels_flat.squeeze(-1) # Remove the last dimension if it's size 1
421
+
422
+ # Correct the mask_indices to match the shape of predictions_flat
423
+ mask_indices_reduced = mask_indices.any(dim=-1) # Reduce mask to (B, T, C) by collapsing last dimension
424
+ mask_indices_flat = mask_indices_reduced.view(-1) # Flatten to match the shape of (B * T * C)
425
+
426
+ # Calculate the loss only for the masked positions (where mask_indices_flat is zero)
427
+ masked_loss = F.cross_entropy(
428
+ predictions_flat[mask_indices_flat == 0],
429
+ labels_flat[mask_indices_flat == 0],
430
+ reduction='none'
431
+ )
432
+
433
+ # Store the result
434
+ result['losses']['kmeans_loss'] = masked_loss
435
+
436
+
437
+
438
+ if "sample_size" not in result:
439
+ result['sample_size'] = masked_loss.numel()
440
+
441
+ return result
442
+
443
+ @staticmethod
444
+ def compute_var(y):
445
+ y = y.view(-1, y.size(-1))
446
+ if dist.is_initialized():
447
+ zc = torch.tensor(y.size(0)).cuda()
448
+ zs = y.sum(dim=0)
449
+ zss = (y ** 2).sum(dim=0)
450
+
451
+ dist.all_reduce(zc)
452
+ dist.all_reduce(zs)
453
+ dist.all_reduce(zss)
454
+
455
+ var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
456
+ return torch.sqrt(var + 1e-6).mean()
457
+ else:
458
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
459
+
460
+ def extract_features(
461
+ self, source, padding_mask, kmeans_labels, mask=False, layer=None
462
+ ):
463
+ res = self.forward(
464
+ source,
465
+ padding_mask,
466
+ mask=mask,
467
+ features_only=True,
468
+ layer=layer,
469
+ kmeans_labels=kmeans_labels,
470
+ )
471
+ return res
472
+
473
+ def remove_pretraining_modules(self, last_layer=None):
474
+ self.heads = None
475
+ self.final_proj = None
476
+ if last_layer is not None:
477
+ self.encoder.layers = nn.ModuleList(
478
+ l for i, l in enumerate(self.encoder.layers) if i <= last_layer
479
+ )
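+
+
+ # Reference for the expected input layout (summarised from the forward pass
+ # above): each element of `source` is a dict with "face"/"left_hand"/"right_hand"
+ # tensors of shape (T, channels_embed_dim), "body_posture" of shape
+ # (T, channels_pose_embed_dim), and matching "label_*" tensors of shape (T, 1).
+ # A rough call sketch, assuming a config compatible with SignHubertConfig defaults:
+ #
+ #   model = SignHubertModel(SignHubertConfig())
+ #   out = model.extract_features(source, padding_mask=None, kmeans_labels=None)
+ #   out["x"].shape             # (B, T, encoder_embed_dim)
+ #   len(out["layer_results"])  # one entry per transformer layer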
shubert_inference.py ADDED
@@ -0,0 +1,439 @@
1
+ import torch
+ import numpy as np
+ import csv
+ import os
+ from tqdm import tqdm
+ import argparse
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple, Union, Any
+ from examples.shubert.models.shubert import SHubertModel, SHubertConfig
+ # ByT5 checkpoints ship with the T5 seq2seq architecture; transformers provides a
+ # dedicated ByT5 tokenizer, but the model itself is loaded via the generic T5 class.
+ from transformers import ByT5Tokenizer, T5ForConditionalGeneration
+ 
+ 
+ class SHubertProcessor:
+     """
+     A class for processing multi-modal embeddings through the SHuBERT model.
+     """
+ 
+     def __init__(self, checkpoint_path: str, device: Optional[str] = None):
+         """
+         Initialize the SHubertProcessor.
+ 
+         Args:
+             checkpoint_path: Path to the SHuBERT model checkpoint
+             device: Device to use ('cuda' or 'cpu'). Auto-detected if None
+         """
+         self.checkpoint_path = checkpoint_path
+         self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ 
+         # Load the model
+         self.model = self._load_model()
+ 
+         print(f"SHubertProcessor initialized on device: {self.device}")
+ 
+     def _load_model(self) -> SHubertModel:
+         """Load the SHuBERT model from checkpoint."""
+         # Initialize configuration
+         cfg = SHubertConfig()
+ 
+         # Initialize the model
+         model = SHubertModel(cfg)
+ 
+         # Load the checkpoint
+         checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
+ 
+         # Extract state dict
+         if 'model' in checkpoint:
+             state_dict = checkpoint['model']
+         else:
+             state_dict = checkpoint
+ 
+         # Load the state dictionary into the model
+         model.load_state_dict(state_dict, strict=False)
+ 
+         model.eval()
+         model.to(self.device)
+         return model
+ 
+     def process_embeddings(self, face_embeddings: np.ndarray,
+                            left_hand_embeddings: np.ndarray,
+                            right_hand_embeddings: np.ndarray,
+                            pose_embeddings: np.ndarray) -> np.ndarray:
+         """
+         Process multi-modal embeddings through the SHuBERT model.
+ 
+         Args:
+             face_embeddings: Face embeddings array of shape (num_frames, embedding_dim)
+             left_hand_embeddings: Left hand embeddings array of shape (num_frames, embedding_dim)
+             right_hand_embeddings: Right hand embeddings array of shape (num_frames, embedding_dim)
+             pose_embeddings: Pose embeddings array of shape (num_frames, pose_dim)
+ 
+         Returns:
+             Numpy array of SHuBERT features with shape (num_layers, num_frames, feature_dim)
+         """
+         # Convert to tensors and move to device
+         face = torch.from_numpy(face_embeddings).float().to(self.device)
+         left_hand = torch.from_numpy(left_hand_embeddings).float().to(self.device)
+         right_hand = torch.from_numpy(right_hand_embeddings).float().to(self.device)
+         body_posture = torch.from_numpy(pose_embeddings).float().to(self.device)
+ 
+         length = face.shape[0]
+ 
+         # Prepare input in the format expected by SHuBERT
+         source = [{
+             "face": face,
+             "left_hand": left_hand,
+             "right_hand": right_hand,
+             "body_posture": body_posture,
+             # Add dummy labels to match the expected input format
+             "label_face": torch.zeros((length, 1)).to(self.device),
+             "label_left_hand": torch.zeros((length, 1)).to(self.device),
+             "label_right_hand": torch.zeros((length, 1)).to(self.device),
+             "label_body_posture": torch.zeros((length, 1)).to(self.device)
+         }]
+ 
+         # Extract features
+         with torch.no_grad():
+             result = self.model.extract_features(source, padding_mask=None, kmeans_labels=None, mask=False)
+ 
+         # Extract layer outputs
+         layer_outputs = []
+         for layer in result['layer_results']:
+             # layer_output has shape [T, B, D]; batch size B is 1, so squeeze it
+             layer_output = layer[-1]
+             layer_output = layer_output.squeeze(1)  # Shape: [T, D]
+             layer_outputs.append(layer_output.cpu().numpy())  # Convert to NumPy array
+ 
+         # Stack the outputs from all layers to get an array of shape [L, T, D]
+         features = np.stack(layer_outputs, axis=0)  # Shape: [L, T, D]
+         return features
+ 
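A minimal usage sketch for process_embeddings (the checkpoint path, frame count, and cue dimensions below are illustrative assumptions, not values fixed by this file):

import numpy as np

T, D_CUE, D_POSE = 120, 768, 134                         # hypothetical sizes for illustration
processor = SHubertProcessor("checkpoints/shubert.pt")   # hypothetical checkpoint path
feats = processor.process_embeddings(
    np.random.randn(T, D_CUE).astype(np.float32),        # face
    np.random.randn(T, D_CUE).astype(np.float32),        # left hand
    np.random.randn(T, D_CUE).astype(np.float32),        # right hand
    np.random.randn(T, D_POSE).astype(np.float32),       # body posture
)
print(feats.shape)  # (num_layers, T, feature_dim)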
+     def process_embeddings_from_files(self, face_path: str, left_hand_path: str,
+                                       right_hand_path: str, pose_path: str) -> np.ndarray:
+         """
+         Process embeddings loaded from files.
+ 
+         Args:
+             face_path: Path to face embeddings .npy file
+             left_hand_path: Path to left hand embeddings .npy file
+             right_hand_path: Path to right hand embeddings .npy file
+             pose_path: Path to pose embeddings .npy file
+ 
+         Returns:
+             Numpy array of SHuBERT features with shape (num_layers, num_frames, feature_dim)
+         """
+         # Load numpy arrays
+         face_embeddings = np.load(face_path)
+         left_hand_embeddings = np.load(left_hand_path)
+         right_hand_embeddings = np.load(right_hand_path)
+         pose_embeddings = np.load(pose_path)
+ 
+         return self.process_embeddings(face_embeddings, left_hand_embeddings,
+                                        right_hand_embeddings, pose_embeddings)
+ 
+     def process_and_save_embeddings(self, face_embeddings: np.ndarray,
+                                     left_hand_embeddings: np.ndarray,
+                                     right_hand_embeddings: np.ndarray,
+                                     pose_embeddings: np.ndarray,
+                                     output_path: str) -> str:
+         """
+         Process embeddings and save to file.
+ 
+         Args:
+             face_embeddings: Face embeddings array
+             left_hand_embeddings: Left hand embeddings array
+             right_hand_embeddings: Right hand embeddings array
+             pose_embeddings: Pose embeddings array
+             output_path: Path to save the output file
+ 
+         Returns:
+             Path to the saved file
+         """
+         # Process embeddings
+         features = self.process_embeddings(face_embeddings, left_hand_embeddings,
+                                            right_hand_embeddings, pose_embeddings)
+ 
+         # Create output directory if it doesn't exist
+         output_dir = Path(output_path).parent
+         output_dir.mkdir(parents=True, exist_ok=True)
+ 
+         # Save features
+         np.save(output_path, features)
+ 
+         return str(output_path)
+ 
+     def process_from_files_and_save(self, face_path: str, left_hand_path: str,
+                                     right_hand_path: str, pose_path: str,
+                                     output_path: str) -> str:
+         """
+         Process embeddings from files and save results.
+ 
+         Args:
+             face_path: Path to face embeddings .npy file
+             left_hand_path: Path to left hand embeddings .npy file
+             right_hand_path: Path to right hand embeddings .npy file
+             pose_path: Path to pose embeddings .npy file
+             output_path: Path to save the output file
+ 
+         Returns:
+             Path to the saved file
+         """
+         # Process embeddings
+         features = self.process_embeddings_from_files(face_path, left_hand_path,
+                                                       right_hand_path, pose_path)
+ 
+         # Create output directory if it doesn't exist
+         output_dir = Path(output_path).parent
+         output_dir.mkdir(parents=True, exist_ok=True)
+ 
+         # Save features
+         np.save(output_path, features)
+ 
+         return str(output_path)
+ 
+ 
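The file-based variants follow the same path; a hedged sketch, with placeholder checkpoint, .npy, and output paths:

processor = SHubertProcessor("checkpoints/shubert.pt")   # hypothetical checkpoint path
saved = processor.process_from_files_and_save(
    "cues/video1_face.npy", "cues/video1_left_hand.npy",
    "cues/video1_right_hand.npy", "cues/video1_pose.npy",
    output_path="shubert_features/video1.npy",
)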
+ class SHuBERTTextGenerator:
+     """
+     A class that combines SHuBERT feature extraction with ByT5 text generation.
+     """
+ 
+     def __init__(self, shubert_checkpoint: str, byt5_model_name: str = "google/byt5-base",
+                  device: Optional[str] = None):
+         """
+         Initialize with SHuBERT and ByT5 models.
+ 
+         Args:
+             shubert_checkpoint: Path to SHuBERT model checkpoint
+             byt5_model_name: Name of ByT5 model (default: "google/byt5-base")
+             device: Device to use ('cuda' or 'cpu')
+         """
+         self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ 
+         # Initialize SHuBERT processor
+         self.shubert_processor = SHubertProcessor(shubert_checkpoint, self.device)
+ 
+         # Initialize ByT5 model (ByT5 checkpoints use the T5 seq2seq architecture)
+         self.tokenizer = ByT5Tokenizer.from_pretrained(byt5_model_name)
+         self.model = T5ForConditionalGeneration.from_pretrained(byt5_model_name).to(self.device)
+ 
+     def generate_text(self, face_embeddings: np.ndarray,
+                       left_hand_embeddings: np.ndarray,
+                       right_hand_embeddings: np.ndarray,
+                       pose_embeddings: np.ndarray,
+                       max_length: int = 1024,
+                       num_beams: int = 5) -> str:
+         """
+         Generate text from multi-modal embeddings.
+ 
+         Args:
+             face_embeddings: Face embeddings array
+             left_hand_embeddings: Left hand embeddings array
+             right_hand_embeddings: Right hand embeddings array
+             pose_embeddings: Pose embeddings array
+             max_length: Maximum length of generated text
+             num_beams: Number of beams for beam search
+ 
+         Returns:
+             Generated text string
+         """
+         # Get SHuBERT features
+         features = self.shubert_processor.process_embeddings(
+             face_embeddings, left_hand_embeddings, right_hand_embeddings, pose_embeddings)
+ 
+         # Select features from a specific layer (default: last layer)
+         features = features[-1]  # Shape: [T, D]
+ 
+         # Convert to tensor and add batch dimension
+         features = torch.from_numpy(features).float().unsqueeze(0).to(self.device)
+ 
+         # Generate text
+         generated_ids = self.model.generate(
+             inputs_embeds=features,
+             max_length=max_length,
+             num_beams=num_beams,
+             early_stopping=True
+         )
+ 
+         # Decode generated tokens to text
+         return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+ 
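One caveat worth flagging as an editorial note: generate(inputs_embeds=...) feeds the selected SHuBERT layer output straight into the ByT5 encoder, which implicitly assumes the SHuBERT feature dimension equals the ByT5 hidden size (config.d_model). If they differ, a learned projection would be needed first; a sketch with purely illustrative sizes and a hypothetical adapter not present in this file:

import torch
import torch.nn as nn

T, shubert_dim, byt5_d_model = 120, 768, 1536   # illustrative sizes only
features = torch.randn(1, T, shubert_dim)
proj = nn.Linear(shubert_dim, byt5_d_model)     # hypothetical bridging layer
features = proj(features)                       # (1, T, byt5_d_model), usable as inputs_embeds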
+ 
+ def generate_text_from_features(face_embeddings: np.ndarray,
+                                 left_hand_embeddings: np.ndarray,
+                                 right_hand_embeddings: np.ndarray,
+                                 pose_embeddings: np.ndarray,
+                                 shubert_checkpoint: str,
+                                 byt5_model_name: str = "google/byt5-base",
+                                 max_length: int = 1024,
+                                 num_beams: int = 5) -> str:
+     """
+     Convenience function to generate text from features.
+     """
+     generator = SHuBERTTextGenerator(shubert_checkpoint, byt5_model_name)
+     return generator.generate_text(
+         face_embeddings, left_hand_embeddings, right_hand_embeddings, pose_embeddings,
+         max_length=max_length, num_beams=num_beams
+     )
+ 
+ 
+ # Convenience functions for backward compatibility
+ def process_shubert_embeddings(face_embeddings: np.ndarray,
+                                left_hand_embeddings: np.ndarray,
+                                right_hand_embeddings: np.ndarray,
+                                pose_embeddings: np.ndarray,
+                                checkpoint_path: str) -> np.ndarray:
+     """
+     Convenience function to process embeddings through SHuBERT.
+ 
+     Args:
+         face_embeddings: Face embeddings array
+         left_hand_embeddings: Left hand embeddings array
+         right_hand_embeddings: Right hand embeddings array
+         pose_embeddings: Pose embeddings array
+         checkpoint_path: Path to the SHuBERT model checkpoint
+ 
+     Returns:
+         Numpy array of SHuBERT features
+     """
+     processor = SHubertProcessor(checkpoint_path)
+     return processor.process_embeddings(face_embeddings, left_hand_embeddings,
+                                         right_hand_embeddings, pose_embeddings)
+ 
+ 
+ def process_sample(model: SHubertModel, face_path: str, left_hand_path: str,
+                    right_hand_path: str, body_posture_path: str) -> np.ndarray:
+     """
+     Original function for backward compatibility with command-line usage.
+     """
+     # Load numpy arrays
+     face_np = np.load(face_path)
+     left_hand_np = np.load(left_hand_path)
+     right_hand_np = np.load(right_hand_path)
+     body_posture_np = np.load(body_posture_path)
+ 
+     face = torch.from_numpy(face_np).float().cuda()
+     left_hand = torch.from_numpy(left_hand_np).float().cuda()
+     right_hand = torch.from_numpy(right_hand_np).float().cuda()
+     body_posture = torch.from_numpy(body_posture_np).float().cuda()
+ 
+     length = face.shape[0]
+ 
+     # Prepare input
+     source = [{
+         "face": face,
+         "left_hand": left_hand,
+         "right_hand": right_hand,
+         "body_posture": body_posture,
+         # Add dummy labels to match the expected input format
+         "label_face": torch.zeros((length, 1)).cuda(),
+         "label_left_hand": torch.zeros((length, 1)).cuda(),
+         "label_right_hand": torch.zeros((length, 1)).cuda(),
+         "label_body_posture": torch.zeros((length, 1)).cuda()
+     }]
+ 
+     # Extract features
+     with torch.no_grad():
+         result = model.extract_features(source, padding_mask=None, kmeans_labels=None, mask=False)
+ 
+     # Extract layer outputs
+     layer_outputs = []
+     for layer in result['layer_results']:
+         # layer_output has shape [T, B, D]; batch size B is 1, so squeeze it
+         layer_output = layer[-1]
+         layer_output = layer_output.squeeze(1)  # Shape: [T, D]
+         layer_outputs.append(layer_output.cpu().numpy())  # Convert to NumPy array
+ 
+     # Stack the outputs from all layers to get an array of shape [L, T, D]
+     features = np.stack(layer_outputs, axis=0)  # Shape: [L, T, D]
+     return features
+ 
+ 
+ def load_model(checkpoint_path: str) -> SHubertModel:
+     """
+     Original function for backward compatibility with command-line usage.
+     """
+     cfg = SHubertConfig()
+ 
+     # Initialize the model
+     model = SHubertModel(cfg)
+ 
+     # Load the checkpoint
+     checkpoint = torch.load(checkpoint_path)
+ 
+     # If the checkpoint is saved with a 'model' key
+     if 'model' in checkpoint:
+         state_dict = checkpoint['model']
+     else:
+         state_dict = checkpoint
+ 
+     # Load the state dictionary into the model
+     model.load_state_dict(state_dict, strict=False)
+ 
+     model.eval()
+     model.cuda()  # Move to GPU if available
+     return model
+ 
+ 
+ def main(csv_list: List[List[str]], checkpoint_path: str, output_dir: str, index: int):
+     """
+     Original main function for backward compatibility with command-line usage.
+     """
+     model = load_model(checkpoint_path)
+ 
+     os.makedirs(output_dir, exist_ok=True)
+ 
+     for row in csv_list:
+         cues_list = row[0].split('\t')
+         face_path, left_hand_path, right_hand_path, body_posture_path = cues_list[0], cues_list[1], cues_list[2], cues_list[3]
+ 
+         output_filename = f"{os.path.basename(face_path).rsplit('.', 1)[0].rsplit('_', 1)[0]}.npy"
+         output_path = os.path.join(output_dir, output_filename)
+ 
+         # Check if the output file already exists
+         if os.path.exists(output_path):
+             print(f"Skipping {output_path} as it already exists")
+             continue
+ 
+         # Process the sample
+         features = process_sample(model, face_path, left_hand_path, right_hand_path, body_posture_path)
+ 
+         np.save(output_path, features)
+ 
+ 
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--index', type=int, required=True,
+                         help='index of the sub_list to work with')
+     parser.add_argument('--csv_path', type=str, required=True,
+                         help='path to the CSV file')
+     parser.add_argument('--checkpoint_path', type=str, required=True,
+                         help='path to the checkpoint file')
+     parser.add_argument('--output_dir', type=str, required=True,
+                         help='directory to save output files')
+     parser.add_argument('--batch_size', type=int, required=True,
+                         help='batch size for processing')
+ 
+     args = parser.parse_args()
+     index = args.index
+     csv_path = args.csv_path
+     checkpoint_path = args.checkpoint_path
+     output_dir = args.output_dir
+     batch_size = int(args.batch_size)
+ 
+     # Make output dir
+     os.makedirs(output_dir, exist_ok=True)
+ 
+     # Load CSV data
+     fixed_list = []
+     with open(csv_path, 'r') as csvfile:
+         reader = csv.reader(csvfile)
+         for row in reader:
+             fixed_list.append(row)
+ 
+     # Process in batches
+     video_batches = [fixed_list[i:i + batch_size] for i in range(0, len(fixed_list), batch_size)]
+ 
+     csv_list = video_batches[index]
+     main(csv_list, checkpoint_path, output_dir, index)
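For reference, a hedged example of running the standalone script (all paths below are placeholders). Each manifest row is read as a single field and then split on tabs, so a row is one tab-separated string of face, left-hand, right-hand, and pose .npy paths, and the output name is derived by dropping the trailing cue suffix (e.g. video1_face.npy -> video1.npy):

# one manifest row (tab-separated; shown with \t for clarity):
# cues/video1_face.npy\tcues/video1_left_hand.npy\tcues/video1_right_hand.npy\tcues/video1_pose.npy

# python shubert_inference.py --index 0 --csv_path cues_manifest.csv \
#     --checkpoint_path checkpoints/shubert.pt --output_dir shubert_features --batch_size 32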