Spaces:
Running
on
A100
Running
on
A100
| # Copyright (c) 2025 NVIDIA CORPORATION. | |
| # Licensed under the MIT license. | |
| # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. | |
| # LICENSE is in incl_licenses directory. | |
import math
from functools import partial
from typing import Any, Dict, List, Optional, Tuple

import torch

from .basic import BasicVideoEncoder
# Public API of this module: the only name exported by ``from ... import *``.
__all__ = [
    "TSPVideoEncoder",
]
def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
    """Average-pool ``x`` along one axis using non-overlapping windows of ``size``.

    The axis ``dim`` (length ``n``) is split into ``(n // size, size)`` and the
    trailing ``size`` axis is averaged away. The result has the same rank as
    ``x``, with ``x.shape[dim]`` reduced to ``n // size``.

    Args:
        x: Input tensor.
        size: Window length (and stride) along ``dim``. ``x.shape[dim]`` must be
            divisible by it.
        dim: Axis to pool over. Negative values count from the end, as
            elsewhere in the torch API.

    Returns:
        The pooled tensor.

    Raises:
        RuntimeError: Raised by ``Tensor.view`` when ``x.shape[dim]`` is not
            divisible by ``size``.
    """
    # Normalize negative axes. With a raw negative ``dim``, the slices below and
    # ``mean(dim + 1)`` select the wrong axes: e.g. ``dim=-1`` turns
    # ``x.shape[dim + 1:]`` into the *entire* shape.
    if dim < 0:
        dim += x.dim()
    # Splitting one axis into two is always expressible as a view (no copy).
    windows = x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :])
    return windows.mean(dim + 1)
class TSPVideoEncoder(BasicVideoEncoder):
    """Video encoder that emits frame features at several pooling granularities.

    For every ``(t, h, w)`` entry in ``pool_sizes``, one video's features are
    reshaped to ``(frames, side, side, -1)`` and average-pooled by ``t`` over
    frames and by ``h`` / ``w`` over the two axes of the square token grid. The
    pooled features are then passed through ``BasicVideoEncoder._process_features``
    with the start/end token embeddings. Separator embeddings, if any, are
    appended after each granularity. All granularities are concatenated
    along dim 0.
    """

    def __init__(
        self,
        parent: torch.nn.Module,
        pool_sizes: List[Tuple[int, int, int]],
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
        sep_tokens: Optional[str] = None,
    ) -> None:
        """Configure the encoder.

        Args:
            parent: Model handed to ``BasicVideoEncoder``. ``forward`` calls
                ``self.parent.encode_images`` on the stacked frames.
            pool_sizes: One ``(t, h, w)`` pooling-factor triple per granularity.
                Each factor must divide the matching axis (frame count, grid
                side, grid side).
            start_tokens: Forwarded to ``BasicVideoEncoder`` (default ``None``).
            end_tokens: Forwarded to ``BasicVideoEncoder`` (default ``"\\n"``).
            sep_tokens: Text embedded via ``embed_tokens`` and appended after
                each granularity's features. A ``None`` embedding means no
                separator is added.
        """
        super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
        self.pool_sizes = pool_sizes
        self.sep_tokens = sep_tokens

    def _process_features(
        self,
        inputs: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
        sep_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Pool one video's features at every configured granularity.

        Args:
            inputs: Features of a single video. Dim 0 indexes frames and dim 1
                indexes tokens, whose count must be a perfect square.
            start_token_embeds: Forwarded to ``BasicVideoEncoder._process_features``.
            end_token_embeds: Forwarded to ``BasicVideoEncoder._process_features``.
            sep_token_embeds: Appended (dim 0) after each granularity's output
                when not ``None``.

        Returns:
            The per-granularity outputs concatenated along dim 0, in
            ``self.pool_sizes`` order.

        Raises:
            ValueError: If the per-frame token count is not a perfect square.
        """
        nt, ns = inputs.shape[:2]
        # Exact integer square root plus an explicit check. With a float sqrt
        # and no check, a non-square ``ns`` can still pass the ``view`` below
        # (leftover tokens get folded into the last axis), silently
        # scrambling the features.
        nl = math.isqrt(ns)
        if nl * nl != ns:
            raise ValueError(
                f"Expected a square grid of tokens per frame, got {ns} tokens."
            )
        # Loop-invariant: (frames, grid rows, grid cols, remaining dims).
        grid = inputs.view(nt, nl, nl, -1)
        outputs = []
        for pool_size in self.pool_sizes:
            features = grid
            # pool_size[0] pools frames; pool_size[1:] pool the two grid axes.
            for dim, p in enumerate(pool_size):
                features = pool(features, p, dim=dim)
            # Merge the two grid axes back into a single token axis.
            features = features.flatten(1, 2)
            features = super()._process_features(
                features,
                start_token_embeds=start_token_embeds,
                end_token_embeds=end_token_embeds,
            )
            if sep_token_embeds is not None:
                features = torch.cat([features, sep_token_embeds], dim=0)
            outputs.append(features)
        return torch.cat(outputs, dim=0)

    def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
        """Encode a batch of videos.

        All frames are encoded in a single ``parent.encode_images`` call. The
        result is split back per video and processed by ``_process_features``.

        Args:
            videos: One tensor per video. Dim 0 indexes frames.
            config: Not used by this implementation.

        Returns:
            One feature tensor per input video, in input order. An empty
            ``videos`` list yields an empty list.
        """
        if not videos:
            # ``torch.cat`` rejects an empty list, and there is nothing to encode.
            return []
        num_frames = [video.shape[0] for video in videos]
        images = torch.cat(videos, dim=0)
        features = self.parent.encode_images(images)
        # Undo the batching: one chunk of frame features per video.
        features = torch.split(features, num_frames)
        # Embed the special tokens once per batch rather than once per video.
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
            sep_token_embeds=self.embed_tokens(self.sep_tokens),
        )
        return [process_features(f) for f in features]