from typing import Tuple, Union

import torch
from einops import rearrange
from torch import nn


def make_triple(value: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]:
    """Broadcast a single int to a (t, h, w) triple; pass tuples through after validation."""
    value = (value,) * 3 if isinstance(value, int) else value
    assert len(value) == 3
    return value


class AudioPack(nn.Module):
    def __init__(
        self,
        in_channels: int,
        patch_size: Union[int, Tuple[int, int, int]],
        dim: int,
        layernorm: bool = False,
    ):
        super().__init__()
        t, h, w = make_triple(patch_size)
        self.patch_size = t, h, w
        # Each flattened (t, h, w) patch of `in_channels` features is projected to `dim`.
        self.proj = nn.Linear(in_channels * t * h * w, dim)
        if layernorm:
            self.norm_out = nn.LayerNorm(dim)
        else:
            self.norm_out = None

    def forward(
        self,
        vid: torch.Tensor,
    ) -> torch.Tensor:
        t, h, w = self.patch_size
        # [b, c, T*t, H*h, W*w] -> [b, T, H, W, t*h*w*c]: split the input into
        # non-overlapping patches and flatten each patch into a feature vector.
        vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
        vid = self.proj(vid)
        if self.norm_out is not None:
            vid = self.norm_out(vid)
        return vid
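

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: the input shape
    # and hyperparameters below are assumptions chosen only to show the packing.
    # A 5-D input [batch, channels, frames, height, width] is packed into
    # per-patch embeddings of shape [batch, T, H, W, dim].
    pack = AudioPack(in_channels=4, patch_size=(1, 2, 2), dim=64, layernorm=True)
    x = torch.randn(2, 4, 8, 16, 16)
    out = pack(x)
    print(out.shape)  # torch.Size([2, 8, 8, 8, 64])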