Commit 66fd0f6 (verified) by sachin · Parent(s): 82edaab

newer model

Files changed (1): model.py (+85, -0)
model.py ADDED
@@ -0,0 +1,85 @@
from typing import Union

import einops
import torch
from torch import nn
from transformers import BatchEncoding, PreTrainedModel

# VLMConfig, projection_layers, and LABEL_MASK are defined elsewhere in this
# repo; LABEL_MASK is presumably -100, the label index that Hugging Face loss
# functions ignore.


class Model(PreTrainedModel):
    config_class = VLMConfig

    def __init__(
        self,
        config: VLMConfig,
        image_model,
        language_model,
        num_projections: int,
        tokenizer,
        prepend_text: str,
        image_tokens: int,
    ):
        super().__init__(config)
        self.image_model = image_model
        self.language_model = language_model
        # Projector mapping image-encoder features into the language model's
        # embedding space.
        self.projector = nn.Sequential(
            *projection_layers(
                image_model.num_features,
                language_model.config.hidden_size,
                num_projections,
            )
        )

        self.tokenizer = tokenizer
        self.eos_token = tokenizer.eos_token
        self.prepend_text = prepend_text
        self.image_tokens = image_tokens

        # prepend_text contains a single EOS token acting as an image
        # placeholder; the prompt embeddings are split around it so the image
        # tokens can be spliced in at that position.
        input_ids = tokenizer(prepend_text, return_tensors="pt").input_ids
        eos_token_index = (input_ids[0] == tokenizer.eos_token_id).nonzero(as_tuple=True)[0].item()
        text_embeddings = self.language_model.get_input_embeddings()(input_ids).detach()
        self.prepend_embeddings = text_embeddings[:, :eos_token_index]
        self.postpend_embeddings = text_embeddings[:, eos_token_index:]
        # Prompt and image positions are always attended to, never scored.
        self.attention_mask = torch.ones(1, text_embeddings.shape[1] + image_tokens)
        self.labels = torch.full((1, self.attention_mask.shape[1]), LABEL_MASK)

    def project_image_features(self, images: torch.Tensor):
        # Flatten the (bs, dim, h, w) feature map into a (bs, h*w, dim) token
        # sequence, then project it to the language model's hidden size.
        image_features = self.image_model.forward_features(images)
        image_features = einops.rearrange(image_features, "bs dim h w -> bs (h w) dim")
        encoder_outputs = self.projector(image_features)
        return encoder_outputs

    def forward(self, images: torch.Tensor, tokenized_captions: BatchEncoding):
        image_outputs = self.project_image_features(images)
        # Caption embeddings are detached so gradients reach the projector
        # (via inputs_embeds) rather than the LM's embedding table.
        caption_embeddings = self.language_model.get_input_embeddings()(tokenized_captions.input_ids).detach()
        device = images.device
        # Sequence layout: <prepend prompt> <image tokens> <postpend prompt> <caption>
        embeddings = torch.cat(
            [
                self.prepend_embeddings.to(device).expand(len(images), -1, -1),
                image_outputs,
                self.postpend_embeddings.to(device).expand(len(images), -1, -1),
                caption_embeddings,
            ],
            dim=1,
        )
        attention_mask = torch.cat(
            [
                self.attention_mask.to(device).expand(len(images), -1),
                tokenized_captions.attention_mask,
            ],
            dim=1,
        )
        # Loss is computed on caption tokens only: prompt and image positions
        # carry LABEL_MASK, and padded caption positions are masked below.
        labels = torch.cat(
            [
                self.labels.to(device).expand(len(images), -1),
                tokenized_captions.input_ids.clone(),
            ],
            dim=1,
        )
        labels[attention_mask == 0] = LABEL_MASK

        return self.language_model(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            labels=labels,
        )

    def generate(self, images: torch.Tensor, generator_kwargs: dict[str, Union[int, float]]):
        image_outputs = self.project_image_features(images)
        device = images.device
        # Same layout as forward(), minus the caption: generation continues
        # from the end of the prompt.
        embeddings = torch.cat(
            [
                self.prepend_embeddings.to(device).expand(len(images), -1, -1),
                image_outputs,
                self.postpend_embeddings.to(device).expand(len(images), -1, -1),
            ],
            dim=1,
        )
        attention_mask = self.attention_mask.to(device).expand(len(images), -1)
        return self.language_model.generate(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            eos_token_id=self.tokenizer.eos_token_id,
            **generator_kwargs,
        )
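The file references projection_layers without defining or importing it here. A minimal sketch of what such a helper might look like, assuming a plain stack of Linear layers with GELU activations between them (the structure and activation are assumptions, not taken from this commit):

from torch import nn

def projection_layers(in_features: int, out_features: int, num_projections: int):
    # Hypothetical sketch: map the image feature width to the LM hidden size,
    # then stack additional Linear blocks at that width.
    layers = [nn.Linear(in_features, out_features)]
    for _ in range(num_projections - 1):
        layers += [nn.GELU(), nn.Linear(out_features, out_features)]
    return layers

And a hedged usage sketch. The backbones (timm's resnet50, Hugging Face's gpt2), the prompt string, and image_tokens=49 (resnet50's 7x7 feature grid at 224x224 input) are illustrative assumptions, not values from this commit; VLMConfig is assumed to come from elsewhere in the repo:

import timm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

image_model = timm.create_model("resnet50", pretrained=True)   # num_features == 2048
language_model = AutoModelForCausalLM.from_pretrained("gpt2")  # hidden_size == 768
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

model = Model(
    VLMConfig(),  # assumed: defined elsewhere in this repo
    image_model=image_model,
    language_model=language_model,
    num_projections=2,
    tokenizer=tokenizer,
    prepend_text=f"Caption this image: {tokenizer.eos_token}",  # EOS marks the image slot
    image_tokens=49,  # 7x7 spatial grid from resnet50 at 224x224 input
)

images = torch.randn(2, 3, 224, 224)
captions = tokenizer(["a dog on grass", "a cat"], return_tensors="pt", padding=True)
loss = model(images, captions).loss

Note that forward() moves its cached prompt buffers to images.device but leaves tokenized_captions untouched, so the caller is responsible for placing the captions on the same device as the images.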