Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	File size: 1,163 Bytes
			
			| bb65ef0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | import torch
from typing import Tuple, Union
import torch
from einops import rearrange
from torch import nn
def make_triple(value: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]:
    """Normalize ``value`` to a 3-tuple.

    Args:
        value: Either a single ``int`` (repeated for all three positions) or a
            length-3 iterable.

    Returns:
        A tuple ``(a, b, c)``. A lone int ``n`` becomes ``(n, n, n)``; a
        length-3 iterable (e.g. a list) is converted to a tuple so the result
        always matches the annotated return type.

    Raises:
        ValueError: If ``value`` is an iterable whose length is not 3.
    """
    if isinstance(value, int):
        return (value,) * 3
    value = tuple(value)
    # Explicit check instead of ``assert``: asserts are removed under ``python -O``.
    if len(value) != 3:
        raise ValueError(
            f"expected an int or a sequence of length 3, got length {len(value)}"
        )
    return value
class AudioPack(nn.Module):
    """Cut a 5-D tensor into non-overlapping patches and linearly embed each one.

    The last three input axes are split into ``(t, h, w)`` patches. Each patch,
    together with all of its channels, is flattened and projected to ``dim``
    features. A ``LayerNorm`` can optionally follow the projection.
    """

    def __init__(
            self,
            in_channels: int,
            patch_size: Union[int, Tuple[int, int, int]],
            dim: int,
            layernorm: bool = False,
    ):
        """
        Args:
            in_channels: Size of the channel axis (``c``) of the input tensor.
            patch_size: Patch extent along the last three input axes; a single
                int is used for all three.
            dim: Number of output features per patch.
            layernorm: If True, apply ``nn.LayerNorm(dim)`` after the projection.
        """
        super().__init__()
        t, h, w = make_triple(patch_size)
        self.patch_size = t, h, w
        # A flattened patch holds t*h*w positions, each carrying in_channels values.
        self.proj = nn.Linear(in_channels * t * h * w, dim)
        self.norm_out = nn.LayerNorm(dim) if layernorm else None

    def forward(
            self,
            vid: torch.Tensor,
    ) -> torch.Tensor:
        """Patchify and embed ``vid``.

        Args:
            vid: Tensor laid out as ``(b, c, T*t, H*h, W*w)``, as required by the
                rearrange pattern below. The last three axes must be divisible
                by the corresponding patch sizes.

        Returns:
            Tensor of shape ``(b, T, H, W, dim)``.
        """
        t, h, w = self.patch_size
        # Move each patch's values, in (t, h, w, c) order, into the last axis.
        vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
        vid = self.proj(vid)
        if self.norm_out is not None:
            vid = self.norm_out(vid)
        return vid
