# Spaces: Running on Zero
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
from PIL import Image | |
import decord | |
from decord import VideoReader | |
from decord import cpu, gpu | |
import numpy as np | |
import os | |
import pickle | |
import gzip | |
from pathlib import Path | |
import argparse | |
import json | |
import csv | |
import glob | |
import time | |
from typing import List, Union, Optional, Tuple | |
class DINOEmbedder:
    """
    A class for extracting DINOv2 embeddings from video frames or images.

    The fine-tuned checkpoint is loaded once in the constructor; every
    ``extract_*`` method then runs the model in inference mode (no gradients)
    and returns plain numpy arrays.
    """

    def __init__(self, dino_model_path: str, batch_size: int = 128, device: Optional[str] = None):
        """
        Initialize the DINOEmbedder.

        Args:
            dino_model_path: Path to the fine-tuned DINOv2 model
            batch_size: Batch size for processing frames
            device: Device to use ('cuda' or 'cpu'). Auto-detected if None
        """
        self.dino_model_path = dino_model_path
        self.batch_size = batch_size
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize model
        self.model = self._load_dino_model()
        self.model.eval()

        # Initialize transform
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        print(f"DINOEmbedder initialized on device: {self.device}")

    def _load_dino_model(self) -> nn.Module:
        """
        Load the fine-tuned DINOv2 model.

        Builds the ``dinov2_vits14_reg`` architecture from torch.hub without
        pretrained weights, then loads the ``'teacher'`` weights from the
        checkpoint at ``self.dino_model_path``.

        Returns:
            The model with fine-tuned weights, moved to ``self.device``
        """
        # Load the original DINOv2 model with the correct architecture
        model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg', pretrained=False)

        # Load fine-tuned weights
        pretrained = torch.load(self.dino_model_path, map_location=self.device)

        # Keep only backbone weights: drop the dino_head layers and strip the
        # 'backbone.' prefix so keys match the hub model's parameter names.
        new_state_dict = {}
        for key, value in pretrained['teacher'].items():
            if 'dino_head' in key:
                continue
            new_state_dict[key.replace('backbone.', '')] = value

        # Replace pos_embed with a (1, 257, 384) parameter so the strict load
        # below accepts the checkpoint's positional embedding.
        # NOTE(review): 257 is presumably 1 class token + 16x16 patches of a
        # 224x224 input at patch size 14 - confirm against the training config.
        model.pos_embed = nn.Parameter(torch.zeros(1, 257, 384))

        # Load state dict
        model.load_state_dict(new_state_dict, strict=True)

        # Move model to device
        model.to(self.device)
        return model

    def _preprocess_frame(self, frame: np.ndarray) -> torch.Tensor:
        """
        Preprocess a single frame.

        Args:
            frame: Frame as a numpy array; any other object (e.g. a PIL Image)
                is handed to the transform as-is

        Returns:
            Normalized image tensor restricted to at most 3 channels
        """
        image = Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame

        # Normalize() carries 3-channel statistics, so RGBA / grayscale images
        # must be converted *before* the transform; slicing afterwards is too
        # late because Normalize would already have raised.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        tensor = self.transform(image)
        # Ensure only RGB channels are considered (kept for non-PIL inputs)
        return tensor[:3]

    def _preprocess_frames_batch(self, frames: List[np.ndarray]) -> torch.Tensor:
        """Preprocess a batch of frames and move the stacked tensor to the device."""
        batch_tensors = torch.stack([self._preprocess_frame(frame) for frame in frames])
        return batch_tensors.to(self.device)

    def _embed_batch(self, frames) -> np.ndarray:
        """Preprocess one batch of frames and return the model output as a numpy array."""
        batch_tensors = self._preprocess_frames_batch(frames)
        with torch.no_grad():
            return self.model(batch_tensors).cpu().numpy()

    @staticmethod
    def _read_frames(reader, indices: range):
        """
        Fetch the frames at ``indices`` from a video reader.

        Readers exposing ``get_batch`` (the decord ``VideoReader`` API) are
        read through it, because decord's ``__getitem__`` only accepts an int
        or a slice. Anything else is indexed directly, as before.
        """
        get_batch = getattr(reader, "get_batch", None)
        if get_batch is None:
            return reader[indices]
        frames = get_batch(list(indices))
        # decord returns its own NDArray type; convert it so each frame is a
        # numpy array that _preprocess_frame understands.
        return frames.asnumpy() if hasattr(frames, "asnumpy") else frames

    def extract_embeddings_from_frames(self, frames: List[np.ndarray]) -> np.ndarray:
        """
        Extract DINOv2 embeddings from a list of frames.

        Args:
            frames: List of frames as numpy arrays

        Returns:
            Numpy array of embeddings with shape (num_frames, embedding_dim).
            An empty input yields an array with zero rows.
        """
        if len(frames) == 0:
            # np.concatenate([]) would raise; return a well-formed empty result.
            embed_dim = getattr(self.model, "embed_dim", 0)
            return np.empty((0, embed_dim), dtype=np.float32)

        # Process frames in batches
        all_embeddings = [
            self._embed_batch(frames[idx:idx + self.batch_size])
            for idx in range(0, len(frames), self.batch_size)
        ]

        # Concatenate all embeddings
        return np.concatenate(all_embeddings, axis=0)

    def extract_embeddings_from_video(self, video_input: Union[str, VideoReader],
                                      target_size: Tuple[int, int] = (224, 224)) -> np.ndarray:
        """
        Extract DINOv2 embeddings from a video.

        Args:
            video_input: Either a path to video file (str or Path) or a
                VideoReader-like object supporting ``len()`` and frame access
            target_size: Target size for video frames (width, height); only
                used when the video is opened from a path

        Returns:
            Numpy array of embeddings with shape (num_frames, embedding_dim)

        Raises:
            FileNotFoundError: If a path is given and the file does not exist
            RuntimeError: If the video file cannot be opened
            ValueError: If the video contains no frames
        """
        # Handle different input types
        if isinstance(video_input, (str, Path)):
            video_path = Path(video_input)
            if not video_path.exists():
                raise FileNotFoundError(f"Video file not found: {video_input}")
            try:
                vr = VideoReader(str(video_path), width=target_size[0], height=target_size[1])
            except Exception as e:
                raise RuntimeError(f"Error loading video {video_input}: {e}") from e
        else:
            vr = video_input

        total_frames = len(vr)
        if total_frames == 0:
            # Same exception type np.concatenate used to raise, but explicit.
            raise ValueError("Video contains no frames; nothing to embed")

        # Process video in batches
        all_embeddings = []
        for idx in range(0, total_frames, self.batch_size):
            batch_indices = range(idx, min(idx + self.batch_size, total_frames))
            batch_frames = self._read_frames(vr, batch_indices)
            all_embeddings.append(self._embed_batch(batch_frames))

        # Concatenate all embeddings
        return np.concatenate(all_embeddings, axis=0)

    def extract_embeddings_from_video_and_save(self, video_path: str, output_folder: str) -> str:
        """
        Extract embeddings from video and save to file.

        Args:
            video_path: Path to the video file
            output_folder: Folder to save the embeddings

        Returns:
            Path to the saved embeddings file (``<output_folder>/<video stem>.npy``)
        """
        # Create output folder if it doesn't exist
        Path(output_folder).mkdir(parents=True, exist_ok=True)

        # Extract embeddings
        embeddings = self.extract_embeddings_from_video(video_path)

        # Save embeddings
        video_name = Path(video_path).stem
        np_path = Path(output_folder) / f"{video_name}.npy"
        np.save(np_path, embeddings)
        return str(np_path)

    def extract_embedding_from_single_image(self, image: Union[np.ndarray, Image.Image]) -> np.ndarray:
        """
        Extract DINOv2 embedding from a single image.

        Args:
            image: Image as numpy array or PIL Image

        Returns:
            Numpy array of embedding with shape (1, embedding_dim)
        """
        # Reuse the frame preprocessing so single images get the same RGB
        # handling as video frames.
        tensor = self._preprocess_frame(image).unsqueeze(0).to(self.device)

        # Extract embedding
        with torch.no_grad():
            embedding = self.model(tensor).cpu().numpy()
        return embedding
# Convenience functions for backward compatibility
def extract_embeddings_from_frames(frames: List[np.ndarray], dino_model_path: str,
                                   batch_size: int = 128) -> np.ndarray:
    """
    One-shot helper: build a DINOEmbedder and embed a list of frames.

    A new embedder (and therefore a fresh model load) is created on every
    call; construct a DINOEmbedder directly to reuse the model.

    Args:
        frames: List of frames as numpy arrays
        dino_model_path: Path to the fine-tuned DINOv2 model
        batch_size: Batch size for processing

    Returns:
        Numpy array of embeddings
    """
    return DINOEmbedder(dino_model_path, batch_size).extract_embeddings_from_frames(frames)
def extract_embeddings_from_video(video_path: str, dino_model_path: str,
                                  batch_size: int = 128) -> np.ndarray:
    """
    One-shot helper: build a DINOEmbedder and embed every frame of a video.

    A new embedder (and therefore a fresh model load) is created on every
    call; construct a DINOEmbedder directly to reuse the model.

    Args:
        video_path: Path to the video file
        dino_model_path: Path to the fine-tuned DINOv2 model
        batch_size: Batch size for processing

    Returns:
        Numpy array of embeddings
    """
    return DINOEmbedder(dino_model_path, batch_size).extract_embeddings_from_video(video_path)
def video_to_embeddings(video_path: str, output_folder: str, dino_path: str, batch_size: int = 128):
    """
    Embed one video and save the result, kept for backward compatibility
    with command-line usage.

    This is a best-effort call: any exception raised while loading the model
    or processing the video is printed and swallowed, and nothing is returned.

    Args:
        video_path: Path to the video file
        output_folder: Folder to save the embeddings
        dino_path: Path to the fine-tuned DINOv2 model
        batch_size: Batch size for processing frames
    """
    try:
        DINOEmbedder(dino_path, batch_size).extract_embeddings_from_video_and_save(
            video_path, output_folder
        )
    except Exception as e:
        print(f'Error processing {video_path}: {e}')
# Utility functions for batch processing
def get_mp4_files(directory: str) -> List[str]:
    """
    Get all MP4 files in a directory.

    Only files directly inside ``directory`` whose name matches ``*.mp4`` are
    returned; subdirectories are not searched.

    Args:
        directory: Directory to scan

    Returns:
        Absolute paths of the matching files, sorted so the order is
        reproducible across runs and machines

    Raises:
        FileNotFoundError: If ``directory`` does not exist
    """
    if not os.path.exists(directory):
        raise FileNotFoundError(f'Directory not found: {directory}')

    # Escape the directory part so names containing glob metacharacters
    # ('[', ']', '*', '?') are matched literally instead of as a pattern.
    pattern = os.path.join(glob.escape(directory), '*.mp4')

    # glob returns matches in arbitrary order; sort for determinism.
    return sorted(os.path.abspath(file) for file in glob.glob(pattern))
def load_file(filename: str):
    """
    Read a gzip-compressed pickle file and return the unpickled object.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.

    Args:
        filename: Path to the gzipped pickle file

    Returns:
        The object stored in the file
    """
    with gzip.open(filename, "rb") as stream:
        payload = pickle.load(stream)
    return payload
def is_string_in_file(file_path: str, target_string: str) -> bool:
    """
    Report whether any line of a text file contains ``target_string``.

    The file is scanned line by line and scanning stops at the first match.
    Any error (missing file, decoding problem, ...) is printed and treated
    as "not found".

    Args:
        file_path: Path to the text file
        target_string: Substring to look for

    Returns:
        True if some line contains the substring, otherwise False
    """
    try:
        with Path(file_path).open("r") as handle:
            return any(target_string in row for row in handle)
    except Exception as e:
        print(f"Error: {e}")
        return False
def main():
    """
    Main function for command-line usage.

    Loads a gzipped-pickle list of video paths, splits it into sub-lists of
    ``--batch_size`` videos, and processes the sub-list selected by
    ``--index``. Videos whose ``.npy`` output already exists are skipped, a
    failure on one video is printed without stopping the run, and processing
    stops once ``--time_limit`` seconds have elapsed since start-up.

    Invalid ``--batch_size`` or ``--index`` values end the program through
    ``parser.error`` before the model is loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--index', type=int, required=True,
                        help='index of the sub_list to work with')
    parser.add_argument('--time_limit', type=int, required=True,
                        help='time limit in seconds')
    parser.add_argument('--batch_size', type=int, required=True,
                        help='number of videos to process in this batch')
    parser.add_argument('--files_list', type=str, required=True,
                        help='path to the files list file')
    parser.add_argument('--output_folder', type=str, required=True,
                        help='path to the output folder')
    parser.add_argument('--dino_path', type=str, required=True,
                        help='path to the dino model')
    args = parser.parse_args()

    # A zero step would crash range(); a negative one would yield no batches.
    if args.batch_size <= 0:
        parser.error('--batch_size must be a positive integer')

    start_time = time.time()

    # Load files list
    fixed_list = load_file(args.files_list)

    # Split the videos into sub-lists and pick ours *before* loading the
    # model, so a bad --index fails fast with a readable message.
    video_batches = [fixed_list[i:i + args.batch_size] for i in range(0, len(fixed_list), args.batch_size)]
    print(f"Total number of video batches: {len(video_batches)}")
    try:
        selected_videos = video_batches[args.index]
    except IndexError:
        parser.error(f'--index {args.index} is out of range for {len(video_batches)} video batches')

    # Create output folder if it doesn't exist (no exists/makedirs race)
    os.makedirs(args.output_folder, exist_ok=True)

    # Initialize embedder; 512 is the per-forward-pass frame batch size and
    # is unrelated to --batch_size, which counts videos.
    embedder = DINOEmbedder(args.dino_path, batch_size=512)

    for video_path in selected_videos:
        if time.time() - start_time > args.time_limit:
            print("Time limit reached. Stopping execution.")
            break

        video_name = Path(video_path).stem
        np_path = Path(args.output_folder) / f"{video_name}.npy"
        if np_path.exists():
            print(f"Skipping {video_path} - output already exists")
            continue

        # Best effort per video: report the failure and move on to the next.
        try:
            print(f"Processing {video_path}")
            embedder.extract_embeddings_from_video_and_save(video_path, args.output_folder)
            print(f"Successfully processed {video_path}")
        except Exception as e:
            print(f"Error processing {video_path}: {e}")


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()