Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| class ReduxImageEncoder(torch.nn.Module): | |
| def __init__( | |
| self, | |
| redux_dim: int = 1152, | |
| txt_in_features: int = 4096, | |
| device=None, | |
| dtype=None, | |
| ) -> None: | |
| super().__init__() | |
| self.redux_dim = redux_dim | |
| self.device = device | |
| self.dtype = dtype | |
| self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype) | |
| self.redux_down = nn.Linear( | |
| txt_in_features * 3, txt_in_features, dtype=dtype) | |
| def forward(self, sigclip_embeds) -> torch.Tensor: | |
| x = self.redux_up(sigclip_embeds) | |
| x = torch.nn.functional.silu(x) | |
| projected_x = self.redux_down(x) | |
| return projected_x | |