newer model
model.py
ADDED
@@ -0,0 +1,85 @@
import einops
import torch
from torch import nn
from typing import Union

from transformers import PreTrainedModel

# VLMConfig, projection_layers, and LABEL_MASK are assumed to be defined
# elsewhere in this repo: the model's config class, a factory that builds the
# projector MLP layers, and the ignore index used for masked label positions
# (typically -100).


class Model(PreTrainedModel):
    config_class = VLMConfig

    def __init__(self, config: VLMConfig, image_model, language_model, num_projections: int, tokenizer, prepend_text: str, image_tokens: int):
        super().__init__(config)
        self.image_model = image_model
        self.language_model = language_model
        # Projects image features from the vision backbone's feature width to
        # the language model's hidden size.
        self.projector = nn.Sequential(
            *projection_layers(image_model.num_features, language_model.config.hidden_size, num_projections)
        )

        self.tokenizer = tokenizer
        self.eos_token = tokenizer.eos_token
        self.prepend_text = prepend_text

        self.image_tokens = image_tokens

        # Split the prompt embeddings at the EOS token: everything before it is
        # placed ahead of the image tokens, the EOS token and anything after it
        # comes right after them. The image embeddings are slotted in between
        # at forward time.
        input_ids = tokenizer(prepend_text, return_tensors="pt").input_ids
        eos_token_index = (input_ids[0] == tokenizer.eos_token_id).nonzero(as_tuple=True)[0].item()
        text_embeddings = self.language_model.get_input_embeddings()(input_ids).detach()
        self.prepend_embeddings = text_embeddings[:, :eos_token_index]
        self.postpend_embeddings = text_embeddings[:, eos_token_index:]
        # Prompt and image positions are always attended to and never scored.
        self.attention_mask = torch.ones(1, text_embeddings.shape[1] + image_tokens)
        self.labels = torch.full((1, self.attention_mask.shape[1]), LABEL_MASK)

    def project_image_features(self, images: torch.Tensor):
        # Backbone features (bs, dim, w, h) -> a sequence of image tokens
        # (bs, w*h, hidden_size) in the language model's embedding space.
        image_features = self.image_model.forward_features(images)
        image_features = einops.rearrange(image_features, "bs dim w h -> bs (w h) dim")
        encoder_outputs = self.projector(image_features)
        return encoder_outputs

    def forward(self, images: torch.Tensor, tokenized_captions: dict[str, torch.Tensor]):
        image_outputs = self.project_image_features(images)
        caption_embeddings = self.language_model.get_input_embeddings()(tokenized_captions.input_ids).detach()
        device = images.device
        # Sequence layout: prompt prefix | image tokens | prompt suffix | caption.
        embeddings = torch.cat(
            [
                self.prepend_embeddings.to(device).expand(len(images), -1, -1),
                image_outputs,
                self.postpend_embeddings.to(device).expand(len(images), -1, -1),
                caption_embeddings,
            ],
            dim=1,
        )
        attention_mask = torch.cat(
            [
                self.attention_mask.to(device).expand(len(images), -1),
                tokenized_captions.attention_mask,
            ],
            dim=1,
        )
        # Loss is computed only on the caption tokens; prompt, image, and
        # padding positions are masked out.
        labels = torch.cat(
            [
                self.labels.to(device).expand(len(images), -1),
                tokenized_captions.input_ids.clone(),
            ],
            dim=1,
        )
        labels[attention_mask == 0] = LABEL_MASK

        return self.language_model(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            labels=labels,
        )

    def generate(self, images: torch.Tensor, generator_kwargs: dict[str, Union[int, float]]):
        # Same layout as forward, minus the caption: the language model
        # continues the sequence after the prompt suffix.
        image_outputs = self.project_image_features(images)
        device = images.device
        embeddings = torch.cat(
            [
                self.prepend_embeddings.to(device).expand(len(images), -1, -1),
                image_outputs,
                self.postpend_embeddings.to(device).expand(len(images), -1, -1),
            ],
            dim=1,
        )
        attention_mask = self.attention_mask.to(device).expand(len(images), -1)
        return self.language_model.generate(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            eos_token_id=self.tokenizer.eos_token_id,
            **generator_kwargs,
        )
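For context, a minimal usage sketch of the class above. Nothing here is part of the commit: the backbone, language model, prompt text, and the arguments to VLMConfig and num_projections are all assumptions chosen so the shapes line up (a timm resnet50 yields a 7x7 feature map, i.e. 49 image tokens, and its 2048 features are projected to GPT-2's hidden size).

# Hypothetical usage sketch, not part of this commit.
import timm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

language_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
image_model = timm.create_model("resnet50", pretrained=True, num_classes=0)

model = Model(
    config=VLMConfig(),  # assumed config class from this repo; constructor args unknown
    image_model=image_model,
    language_model=language_model,
    num_projections=2,   # assumed depth for the projector MLP
    tokenizer=tokenizer,
    # The EOS token marks where the image embeddings are inserted.
    prepend_text="Caption this image:" + tokenizer.eos_token,
    image_tokens=49,     # 7x7 feature map from a 224x224 input
)

images = torch.randn(2, 3, 224, 224)
captions = tokenizer(["a dog on the grass", "a cat on a sofa"], return_tensors="pt", padding=True)
loss = model(images, captions).loss                      # training step
generated_ids = model.generate(images, {"max_new_tokens": 20})  # captioning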