vikhyatk commited on
Commit
d8914b8
·
verified ·
1 Parent(s): ea7d26c

Upload Moondream

Browse files
Files changed (5) hide show
  1. config.json +15 -0
  2. configuration_moondream.py +1 -1
  3. modeling_phi.py +2 -33
  4. moondream.py +6 -4
  5. vision_encoder.py +67 -27
config.json CHANGED
@@ -1,4 +1,5 @@
1
  {
 
2
  "architectures": [
3
  "Moondream"
4
  ],
@@ -10,6 +11,20 @@
10
  "phi_config": {
11
  "model_type": "phi"
12
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  "torch_dtype": "float16",
14
  "transformers_version": "4.38.2"
15
  }
 
1
  {
2
+ "_name_or_path": "vikhyatk/moondream2",
3
  "architectures": [
4
  "Moondream"
5
  ],
 
11
  "phi_config": {
12
  "model_type": "phi"
13
  },
14
+ "text_config": {
15
+ "architectures": [
16
+ "Moondream"
17
+ ],
18
+ "auto_map": {
19
+ "AutoConfig": "configuration_moondream.MoondreamConfig",
20
+ "AutoModelForCausalLM": "moondream.Moondream"
21
+ },
22
+ "model_type": "phi",
23
+ "phi_config": {
24
+ "model_type": "phi"
25
+ },
26
+ "torch_dtype": "float16"
27
+ },
28
  "torch_dtype": "float16",
29
  "transformers_version": "4.38.2"
30
  }
configuration_moondream.py CHANGED
@@ -94,5 +94,5 @@ class MoondreamConfig(PretrainedConfig):
94
  model_type = "moondream1"
95
 
96
  def __init__(self, **kwargs):
97
- self.phi_config = PhiConfig(**kwargs)
98
  super().__init__(**kwargs)
 
94
  model_type = "moondream1"
95
 
96
  def __init__(self, **kwargs):
97
+ self.text_config = PhiConfig(**kwargs)
98
  super().__init__(**kwargs)
modeling_phi.py CHANGED
@@ -400,40 +400,10 @@ class PhiAttention(nn.Module):
400
  key_states = repeat_kv(key_states, self.num_key_value_groups)
401
  value_states = repeat_kv(value_states, self.num_key_value_groups)
402
 
403
- # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
404
- attn_weights = torch.matmul(
405
- query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
406
- ) / math.sqrt(self.head_dim)
407
-
408
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
409
- raise ValueError(
410
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
411
- f" {attn_weights.size()}"
412
- )
413
-
414
- if attention_mask is not None:
415
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
416
- raise ValueError(
417
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
418
- )
419
- attn_weights = attn_weights + attention_mask
420
-
421
- # upcast attention to fp32
422
- attn_weights = nn.functional.softmax(
423
- attn_weights, dim=-1, dtype=torch.float32
424
- ).to(value_states.dtype)
425
- attn_weights = nn.functional.dropout(
426
- attn_weights, p=self.attention_dropout, training=self.training
427
  )
428
 
429
- attn_output = torch.matmul(attn_weights, value_states)
430
-
431
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
432
- raise ValueError(
433
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
434
- f" {attn_output.size()}"
435
- )
436
-
437
  attn_output = attn_output.transpose(1, 2).contiguous()
438
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
439
 
@@ -1115,7 +1085,6 @@ class PhiForCausalLM(PhiPreTrainedModel):
1115
 
1116
  hidden_states = outputs[0]
1117
  logits = self.lm_head(hidden_states)
1118
- logits = logits.float()
1119
 
1120
  loss = None
1121
  if labels is not None:
 
400
  key_states = repeat_kv(key_states, self.num_key_value_groups)
401
  value_states = repeat_kv(value_states, self.num_key_value_groups)
402
 
403
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
404
+ query_states, key_states, value_states, attn_mask=attention_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  )
406
 
 
 
 
 
 
 
 
 
407
  attn_output = attn_output.transpose(1, 2).contiguous()
408
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
409
 
 
1085
 
1086
  hidden_states = outputs[0]
1087
  logits = self.lm_head(hidden_states)
 
1088
 
1089
  loss = None
1090
  if labels is not None:
moondream.py CHANGED
@@ -12,14 +12,16 @@ class Moondream(PreTrainedModel):
12
 
13
  def __init__(self, config):
14
  super().__init__(config)
15
- self.vision_encoder = VisionEncoder()
 
 
16
 
17
- if type(config.phi_config) == dict:
18
  phi_config = PhiConfig(
19
- **config.phi_config, attn_implementation=config._attn_implementation
20
  )
21
  else:
22
- phi_config = config.phi_config
23
  self.text_model = PhiForCausalLM(phi_config)
24
 
25
  @property
 
12
 
13
  def __init__(self, config):
14
  super().__init__(config)
15
+ self.vision_encoder = VisionEncoder(
16
+ use_flash_attn=config._attn_implementation == "flash_attention_2"
17
+ )
18
 
19
+ if type(config.text_config) == dict:
20
  phi_config = PhiConfig(
21
+ **config.text_config, attn_implementation=config._attn_implementation
22
  )
23
  else:
24
+ phi_config = config.text_config
25
  self.text_model = PhiForCausalLM(phi_config)
26
 
27
  @property
vision_encoder.py CHANGED
@@ -10,10 +10,20 @@ from torchvision.transforms.v2 import (
10
  ToDtype,
11
  Normalize,
12
  )
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  class Attention(nn.Module):
16
- def __init__(self, dim, num_heads=16):
 
17
  super().__init__()
18
  assert dim % num_heads == 0, "dim should be divisible by num_heads"
19
 
@@ -23,6 +33,11 @@ class Attention(nn.Module):
23
  self.qkv = nn.Linear(dim, dim * 3)
24
  self.proj = nn.Linear(dim, dim)
25
 
 
 
 
 
 
26
  torch.nn.init.kaiming_normal_(
27
  self.qkv.weight, mode="fan_in", nonlinearity="relu"
28
  )
@@ -31,25 +46,36 @@ class Attention(nn.Module):
31
  )
32
 
33
  def forward(self, x: torch.Tensor) -> torch.Tensor:
34
- B, N, C = x.shape
35
- qkv = (
36
- self.qkv(x)
37
- .reshape(B, N, 3, self.num_heads, self.head_dim)
38
- .permute(2, 0, 3, 1, 4)
39
- )
40
- q, k, v = qkv.unbind(0)
41
-
42
- x = F.scaled_dot_product_attention(q, k, v)
43
-
44
- x = x.transpose(1, 2).reshape(B, N, C)
45
- x = self.proj(x)
46
- return x
 
 
 
 
 
 
 
 
 
 
47
 
48
 
49
  class VitBlock(nn.Module):
50
- def __init__(self, embed_dim):
 
51
  super().__init__()
52
- self.attn = Attention(embed_dim)
53
  self.mlp = MLP(embed_dim, 4304)
54
  self.norm1 = nn.LayerNorm(embed_dim)
55
  self.norm2 = nn.LayerNorm(embed_dim)
@@ -62,7 +88,7 @@ class VitBlock(nn.Module):
62
 
63
  class VisionTransformer(nn.Module):
64
 
65
- def __init__(self):
66
  super().__init__()
67
 
68
  embed_len = 729
@@ -70,7 +96,9 @@ class VisionTransformer(nn.Module):
70
 
71
  self.patch_embed = LinearPatchEmbedding()
72
  self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
73
- self.blocks = nn.Sequential(*[VitBlock(embed_dim) for _ in range(27)])
 
 
74
  self.norm = nn.LayerNorm(embed_dim)
75
 
76
  def forward(self, x):
@@ -83,9 +111,9 @@ class VisionTransformer(nn.Module):
83
 
84
  class EncoderWrapper(nn.Module):
85
 
86
- def __init__(self):
87
  super().__init__()
88
- self.model = nn.ModuleDict({"visual": VisionTransformer()})
89
 
90
  def forward(self, x):
91
  return self.model["visual"](x)
@@ -98,6 +126,13 @@ class LinearPatchEmbedding(nn.Module):
98
  self.linear = nn.Linear(588, 1152)
99
 
100
  def forward(self, x):
 
 
 
 
 
 
 
101
  return self.linear(x)
102
 
103
 
@@ -148,10 +183,11 @@ class VisionProjection(nn.Module):
148
 
149
 
150
  class VisionEncoder(nn.Module):
151
- def __init__(self) -> None:
 
152
  super().__init__()
153
 
154
- self.encoder = EncoderWrapper()
155
  self.projection = VisionProjection()
156
 
157
  self.preprocess = Compose(
@@ -172,16 +208,20 @@ class VisionEncoder(nn.Module):
172
  return self.projection.mlp.fc1.weight.dtype
173
 
174
  def __call__(self, images) -> torch.Tensor:
175
- if not isinstance(images, list):
176
  images = [images]
177
 
178
  with torch.no_grad():
179
- x = torch.stack(
180
- [self.preprocess(image.convert("RGB")) for image in images]
181
- ).to(self.device, dtype=self.dtype)
 
 
182
 
183
- x = rearrange(x, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=14, p2=14)
 
184
 
 
185
  x = self.encoder(x)
186
  x = self.projection(x)
187
 
 
10
  ToDtype,
11
  Normalize,
12
  )
13
+ from transformers.utils import is_flash_attn_2_available
14
+
15
+ try:
16
+ if is_flash_attn_2_available():
17
+ from flash_attn.modules.mha import FlashSelfAttention
18
+ else:
19
+ FlashSelfAttention = None
20
+ except ImportError:
21
+ FlashSelfAttention = None
22
 
23
 
24
  class Attention(nn.Module):
25
+
26
+ def __init__(self, dim, num_heads=16, use_flash_attn=False):
27
  super().__init__()
28
  assert dim % num_heads == 0, "dim should be divisible by num_heads"
29
 
 
33
  self.qkv = nn.Linear(dim, dim * 3)
34
  self.proj = nn.Linear(dim, dim)
35
 
36
+ if use_flash_attn and FlashSelfAttention is not None:
37
+ self.flash_attn = FlashSelfAttention()
38
+ else:
39
+ self.flash_attn = None
40
+
41
  torch.nn.init.kaiming_normal_(
42
  self.qkv.weight, mode="fan_in", nonlinearity="relu"
43
  )
 
46
  )
47
 
48
  def forward(self, x: torch.Tensor) -> torch.Tensor:
49
+ if self.flash_attn is not None:
50
+ qkv = self.qkv(x)
51
+ qkv = rearrange(
52
+ qkv, "... (three h d) -> ... three h d", three=3, h=self.num_heads
53
+ )
54
+ attn_output = self.flash_attn(qkv)
55
+ output = rearrange(attn_output, "... h d -> ... (h d)")
56
+ output = self.proj(output)
57
+ return output
58
+ else:
59
+ B, N, C = x.shape
60
+ qkv = (
61
+ self.qkv(x)
62
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
63
+ .permute(2, 0, 3, 1, 4)
64
+ )
65
+ q, k, v = qkv.unbind(0)
66
+
67
+ x = F.scaled_dot_product_attention(q, k, v)
68
+
69
+ x = x.transpose(1, 2).reshape(B, N, C)
70
+ x = self.proj(x)
71
+ return x
72
 
73
 
74
  class VitBlock(nn.Module):
75
+
76
+ def __init__(self, embed_dim, use_flash_attn=False):
77
  super().__init__()
78
+ self.attn = Attention(embed_dim, use_flash_attn=use_flash_attn)
79
  self.mlp = MLP(embed_dim, 4304)
80
  self.norm1 = nn.LayerNorm(embed_dim)
81
  self.norm2 = nn.LayerNorm(embed_dim)
 
88
 
89
  class VisionTransformer(nn.Module):
90
 
91
+ def __init__(self, use_flash_attn=False):
92
  super().__init__()
93
 
94
  embed_len = 729
 
96
 
97
  self.patch_embed = LinearPatchEmbedding()
98
  self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
99
+ self.blocks = nn.Sequential(
100
+ *[VitBlock(embed_dim, use_flash_attn=use_flash_attn) for _ in range(27)]
101
+ )
102
  self.norm = nn.LayerNorm(embed_dim)
103
 
104
  def forward(self, x):
 
111
 
112
  class EncoderWrapper(nn.Module):
113
 
114
+ def __init__(self, use_flash_attn=False):
115
  super().__init__()
116
+ self.model = nn.ModuleDict({"visual": VisionTransformer(use_flash_attn)})
117
 
118
  def forward(self, x):
119
  return self.model["visual"](x)
 
126
  self.linear = nn.Linear(588, 1152)
127
 
128
  def forward(self, x):
129
+ b, c, hp1, wp2 = x.shape
130
+ p1, p2 = 14, 14
131
+ h, w = hp1 // p1, wp2 // p2
132
+ x = x.reshape(b, c, h, p1, w, p2)
133
+ x = x.permute(0, 2, 4, 1, 3, 5)
134
+ x = x.reshape(b, h * w, c * p1 * p2)
135
+
136
  return self.linear(x)
137
 
138
 
 
183
 
184
 
185
  class VisionEncoder(nn.Module):
186
+
187
+ def __init__(self, use_flash_attn=False):
188
  super().__init__()
189
 
190
+ self.encoder = EncoderWrapper(use_flash_attn)
191
  self.projection = VisionProjection()
192
 
193
  self.preprocess = Compose(
 
208
  return self.projection.mlp.fc1.weight.dtype
209
 
210
  def __call__(self, images) -> torch.Tensor:
211
+ if not isinstance(images, list) and not isinstance(images, torch.Tensor):
212
  images = [images]
213
 
214
  with torch.no_grad():
215
+ # Skip preprocess if images are already tensors
216
+ if not isinstance(images, torch.Tensor) and not isinstance(
217
+ images[0], torch.Tensor
218
+ ):
219
+ images = [self.preprocess(image.convert("RGB")) for image in images]
220
 
221
+ if isinstance(images, list):
222
+ images = torch.stack(images)
223
 
224
+ x = images.to(self.device, dtype=self.dtype)
225
  x = self.encoder(x)
226
  x = self.projection(x)
227