# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from itertools import chain
from typing import Callable, Dict, List, Tuple

import einops
import torch


def flatten(
    hid: List[torch.FloatTensor],  # List of (*** c)
) -> Tuple[
    torch.FloatTensor,  # (L c)
    torch.LongTensor,  # (b n)
]:
    assert len(hid) > 0
    shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid])
    hid = torch.cat([x.flatten(0, -2) for x in hid])
    return hid, shape


def unflatten(
    hid: torch.FloatTensor,  # (L c) or (L ... c)
    hid_shape: torch.LongTensor,  # (b n)
) -> List[torch.Tensor]:  # List of (*** c) or (*** ... c)
    hid_len = hid_shape.prod(-1)
    hid = hid.split(hid_len.tolist())
    hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)]
    return hid
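
# Usage sketch (illustrative only; the shapes below are hypothetical):
#   hid = [torch.randn(4, 6, 16), torch.randn(2, 3, 16)]  # two samples of (h w c)
#   packed, shape = flatten(hid)         # packed: (30, 16), shape: [[4, 6], [2, 3]]
#   restored = unflatten(packed, shape)  # recovers the original list of (h w c) tensors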


def concat(
    vid: torch.FloatTensor,  # (VL ... c)
    txt: torch.FloatTensor,  # (TL ... c)
    vid_len: torch.LongTensor,  # (b)
    txt_len: torch.LongTensor,  # (b)
) -> torch.FloatTensor:  # (L ... c)
    vid = torch.split(vid, vid_len.tolist())
    txt = torch.split(txt, txt_len.tolist())
    return torch.cat(list(chain(*zip(vid, txt))))
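
# Usage sketch (illustrative only): with vid_len = [3, 2] and txt_len = [2, 1], concat
# interleaves the per-sample chunks as [vid_0 (3), txt_0 (2), vid_1 (2), txt_1 (1)]
# along dim 0, yielding one packed sequence of length 8.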


def concat_idx(
    vid_len: torch.LongTensor,  # (b)
    txt_len: torch.LongTensor,  # (b)
) -> Tuple[
    Callable,
    Callable,
]:
    device = vid_len.device
    vid_idx = torch.arange(vid_len.sum(), device=device)
    txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
    tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx),
        lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]),
    )
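
# Usage sketch (illustrative only): concat_idx precomputes gather indices so the interleave
# and its inverse become plain index_select calls.
#   merge_fn, split_fn = concat_idx(vid_len, txt_len)
#   packed = merge_fn(vid, txt)          # same result as concat(vid, txt, vid_len, txt_len)
#   vid_out, txt_out = split_fn(packed)  # recovers vid and txt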


def unconcat(
    all: torch.FloatTensor,  # (L ... c)
    vid_len: torch.LongTensor,  # (b)
    txt_len: torch.LongTensor,  # (b)
) -> Tuple[
    torch.FloatTensor,  # (VL ... c)
    torch.FloatTensor,  # (TL ... c)
]:
    interleave_len = list(chain(*zip(vid_len.tolist(), txt_len.tolist())))
    all = all.split(interleave_len)
    vid = torch.cat(all[0::2])
    txt = torch.cat(all[1::2])
    return vid, txt
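
# Usage sketch (illustrative only): unconcat is the inverse of concat.
#   packed = concat(vid, txt, vid_len, txt_len)
#   vid_out, txt_out = unconcat(packed, vid_len, txt_len)  # vid_out == vid, txt_out == txt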


def repeat_concat(
    vid: torch.FloatTensor,  # (VL ... c)
    txt: torch.FloatTensor,  # (TL ... c)
    vid_len: torch.LongTensor,  # (n*b)
    txt_len: torch.LongTensor,  # (b)
    txt_repeat: List,  # (n)
) -> torch.FloatTensor:  # (L ... c)
    vid = torch.split(vid, vid_len.tolist())
    txt = torch.split(txt, txt_len.tolist())
    txt = [[x] * n for x, n in zip(txt, txt_repeat)]
    txt = list(chain(*txt))
    return torch.cat(list(chain(*zip(vid, txt))))
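
# Usage sketch (illustrative only): one text prompt shared by several video chunks.
# With vid_len = [3, 3, 3], txt_len = [2], txt_repeat = [3], the single txt chunk (2 rows)
# is re-inserted after each of the three vid chunks, giving a packed length of 9 + 3 * 2 = 15.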


def repeat_concat_idx(
    vid_len: torch.LongTensor,  # (n*b)
    txt_len: torch.LongTensor,  # (b)
    txt_repeat: torch.LongTensor,  # (n)
) -> Tuple[
    Callable,
    Callable,
]:
    device = vid_len.device
    vid_idx = torch.arange(vid_len.sum(), device=device)
    txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
    txt_repeat_list = txt_repeat.tolist()
    tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat)
    src_idx = torch.argsort(tgt_idx)
    txt_idx_len = len(tgt_idx) - len(vid_idx)
    repeat_txt_len = (txt_len * txt_repeat).tolist()

    def unconcat_coalesce(all):
        """
        Un-concat vid & txt, and coalesce the repeated txt.
        e.g. vid [0 1 2 3 4 5 6 7 8] -> 3 splits -> [0 1 2] [3 4 5] [6 7 8]
             txt [9 10]
             repeat_concat ==> [0 1 2 9 10 3 4 5 9 10 6 7 8 9 10]
        1. argsort re-index ==> [0 1 2 3 4 5 6 7 8 9 9 9 10 10 10]
           split ==> vid_out [0 1 2 3 4 5 6 7 8] txt_out [9 9 9 10 10 10]
        2. reshape & mean for each sample to coalesce the repeated txt.
        """
        vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len])
        txt_out_coalesced = []
        for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list):
            txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1)
            txt_out_coalesced.append(txt)
        return vid_out, torch.cat(txt_out_coalesced)

    # Note: The backward of torch.index_select is non-deterministic when the index contains
    # repeated entries, and the differences can accumulate as with torch.repeat_interleave,
    # so we use vanilla indexing here.
    return (
        lambda vid, txt: torch.cat([vid, txt])[tgt_idx],
        lambda all: unconcat_coalesce(all),
    )
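
# Usage sketch (illustrative only): like concat_idx, but the shared txt is duplicated on the
# way in and averaged back into a single copy on the way out.
#   merge_fn, split_fn = repeat_concat_idx(vid_len, txt_len, txt_repeat)
#   packed = merge_fn(vid, txt)          # same result as repeat_concat(...)
#   vid_out, txt_out = split_fn(packed)  # txt_out has the original (un-repeated) length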


def rearrange(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    pattern: str,
    **kwargs: Dict[str, int],
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    return flatten([einops.rearrange(h, pattern, **kwargs) for h in unflatten(hid, hid_shape)])
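
# Usage sketch (illustrative only): apply an einops pattern to every sample inside the
# packed layout, e.g. transposing the spatial axes of each (h w c) sample:
#   hid, hid_shape = rearrange(hid, hid_shape, "h w c -> w h c")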


def rearrange_idx(
    hid_shape: torch.LongTensor,  # (b n)
    pattern: str,
    **kwargs: Dict[str, int],
) -> Tuple[Callable, Callable, torch.LongTensor]:
    hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
    tgt_idx, tgt_shape = rearrange(hid_idx, hid_shape, pattern, **kwargs)
    tgt_idx = tgt_idx.squeeze(-1)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda hid: torch.index_select(hid, 0, tgt_idx),
        lambda hid: torch.index_select(hid, 0, src_idx),
        tgt_shape,
    )
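
# Usage sketch (illustrative only, assuming the pattern only permutes the non-channel axes):
# precompute the permutation once and reuse it for tensors sharing the same packed shape.
#   fwd, bwd, new_shape = rearrange_idx(hid_shape, "h w c -> w h c")
#   out = fwd(hid)   # equivalent to rearrange(hid, hid_shape, "h w c -> w h c")[0]
#   hid = bwd(out)   # inverse permutation back to the original order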


def repeat(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    pattern: str,
    **kwargs: Dict[str, torch.LongTensor],  # (b)
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    hid = unflatten(hid, hid_shape)
    kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
    return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])
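
# Usage sketch (illustrative only): repetition counts are given per sample as a LongTensor,
# e.g. repeating each (h w c) sample along a new leading axis a different number of times:
#   hid, hid_shape = repeat(hid, hid_shape, "h w c -> t h w c", t=torch.tensor([2, 4]))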


def pack(
    samples: List[torch.Tensor],  # List of (h w c).
) -> Tuple[
    List[torch.Tensor],  # groups [(b1 h1 w1 c1), (b2 h2 w2 c2)]
    List[List[int]],  # reversal indices.
]:
    batches = {}
    indices = {}
    for i, sample in enumerate(samples):
        shape = sample.shape
        batches[shape] = batches.get(shape, [])
        indices[shape] = indices.get(shape, [])
        batches[shape].append(sample)
        indices[shape].append(i)
    batches = list(map(torch.stack, batches.values()))
    indices = list(indices.values())
    return batches, indices


def unpack(
    batches: List[torch.Tensor],
    indices: List[List[int]],
) -> List[torch.Tensor]:
    samples = [None] * (max(chain(*indices)) + 1)
    for batch, index in zip(batches, indices):
        for sample, i in zip(batch.unbind(), index):
            samples[i] = sample
    return samples
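
# Usage sketch (illustrative only): group same-shaped samples into dense batches and restore
# the original ordering afterwards.
#   samples = [torch.randn(4, 4, 8), torch.randn(2, 2, 8), torch.randn(4, 4, 8)]
#   batches, indices = pack(samples)    # [(2, 4, 4, 8), (1, 2, 2, 8)], indices [[0, 2], [1]]
#   samples = unpack(batches, indices)  # back in the original order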


def window(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
    hid = unflatten(hid, hid_shape)
    hid = list(map(window_fn, hid))
    hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device)
    hid, hid_shape = flatten(list(chain(*hid)))
    return hid, hid_shape, hid_windows
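
# Usage sketch (illustrative only, with a hypothetical window_fn): split each sample into
# fixed-size windows along its first axis, then re-flatten all windows into one packed tensor.
#   window_fn = lambda x: list(x.split(4))  # windows of up to 4 rows per sample
#   hid, hid_shape, hid_windows = window(hid, hid_shape, window_fn)
#   # hid_windows[i] is the number of windows produced for sample i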


def window_idx(
    hid_shape: torch.LongTensor,  # (b n)
    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
    hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
    tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn)
    tgt_idx = tgt_idx.squeeze(-1)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda hid: torch.index_select(hid, 0, tgt_idx),
        lambda hid: torch.index_select(hid, 0, src_idx),
        tgt_shape,
        tgt_windows,
    )
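
# Usage sketch (illustrative only, assuming window_fn re-partitions every row exactly once):
# window_idx precomputes the windowing as index permutations, so it can be applied to
# (and undone on) any tensor with a matching packed length.
#   fwd, bwd, win_shape, win_counts = window_idx(hid_shape, window_fn)
#   hid_win = fwd(hid)   # windowed layout
#   hid = bwd(hid_win)   # original layout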