vikhyatk commited on
Commit
1ad313c
·
verified ·
1 Parent(s): e32b732

Upload HfMoondream

Browse files
Files changed (4) hide show
  1. config.json +1 -1
  2. generation_config.json +1 -1
  3. model.safetensors +1 -1
  4. moondream.py +101 -38
config.json CHANGED
@@ -9,5 +9,5 @@
9
  "config": {},
10
  "model_type": "moondream1",
11
  "torch_dtype": "float16",
12
- "transformers_version": "4.48.0"
13
  }
 
9
  "config": {},
10
  "model_type": "moondream1",
11
  "torch_dtype": "float16",
12
+ "transformers_version": "4.44.0"
13
  }
generation_config.json CHANGED
@@ -1,4 +1,4 @@
1
  {
2
  "_from_model_config": true,
3
- "transformers_version": "4.48.0"
4
  }
 
1
  {
2
  "_from_model_config": true,
3
+ "transformers_version": "4.44.0"
4
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fadcffea8c17fe8a20ea68af3a013cf3184a63787ee4453cc9eb75206c7c1f9b
3
  size 3854538376
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96dce588e4a319fde7af3c70fbf27e726f4850e22522d0fdc4b165d5e6003ad5
3
  size 3854538376
moondream.py CHANGED
@@ -15,13 +15,26 @@ from .region import decode_coordinate, encode_coordinate, decode_size, encode_si
15
  from .utils import remove_outlier_points
16
 
17
 
18
- SamplingSettings = TypedDict(
19
- "SamplingSettings",
20
- {"max_tokens": int},
 
 
 
 
 
 
 
 
 
 
21
  total=False,
22
  )
23
 
24
  DEFAULT_MAX_TOKENS = 768
 
 
 
25
 
26
 
27
  @dataclass(frozen=True)
@@ -144,7 +157,7 @@ class MoondreamModel(nn.Module):
144
  def _decode_one_tok(
145
  self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
146
  ):
147
- hidden = text_decoder(x[None], self.text, attn_mask, pos_ids, self.config.text)
148
  logits = lm_head(hidden, self.text)
149
  return logits, hidden
150
 
@@ -209,7 +222,19 @@ class MoondreamModel(nn.Module):
209
  ],
210
  )
211
 
212
- def _prefill_prompt(self, prompt_tokens: torch.Tensor, pos: int):
 
 
 
 
 
 
 
 
 
 
 
 
213
  with torch.inference_mode():
214
  prompt_emb = text_encoder(prompt_tokens, self.text)
215
  torch._dynamo.mark_dynamic(prompt_emb, 1)
@@ -217,7 +242,14 @@ class MoondreamModel(nn.Module):
217
  pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
218
  hidden = self._prefill(prompt_emb, mask, pos_ids)
219
  logits = lm_head(hidden, self.text)
220
- next_token = torch.argmax(logits, dim=-1)
 
 
 
 
 
 
 
221
  pos = pos + prompt_emb.size(1)
222
  return logits, hidden, next_token, pos
223
 
@@ -225,9 +257,23 @@ class MoondreamModel(nn.Module):
225
  self,
226
  prompt_tokens: torch.Tensor,
227
  pos: int,
228
- max_tokens: int,
229
  ):
230
- _, _, next_token, pos = self._prefill_prompt(prompt_tokens, pos)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  def generator(next_token, pos):
233
  mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
@@ -275,7 +321,14 @@ class MoondreamModel(nn.Module):
275
  mask[:, :, pos], pos_ids[0] = 1, pos
276
  logits, _ = self._decode_one_tok(next_emb, mask, pos_ids)
277
  pos += 1
278
- next_token = torch.argmax(logits, dim=-1)
 
 
 
 
 
 
 
279
  generated_tokens += 1
280
 
281
  # Flush any remaining text in the cache
@@ -292,7 +345,7 @@ class MoondreamModel(nn.Module):
292
  image: Union[Image.Image, EncodedImage],
293
  question: str,
294
  stream: bool = False,
295
- settings: Optional[SamplingSettings] = None,
296
  ):
297
  if self.config.tokenizer.templates["query"] is None:
298
  raise NotImplementedError("Model does not support querying.")
@@ -309,12 +362,8 @@ class MoondreamModel(nn.Module):
309
  device=self.device,
310
  )
311
 
312
- max_tokens = DEFAULT_MAX_TOKENS
313
- if settings:
314
- max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS)
315
-
316
  def generator():
317
- for token in self._generate_text(prompt_tokens, image.pos, max_tokens):
318
  yield token
319
 
320
  if stream:
@@ -332,7 +381,7 @@ class MoondreamModel(nn.Module):
332
  image: Union[Image.Image, EncodedImage],
333
  length: Literal["normal", "short", "long"] = "normal",
334
  stream: bool = False,
335
- settings: Optional[SamplingSettings] = None,
336
  ):
337
  if self.config.tokenizer.templates["caption"] is None:
338
  raise NotImplementedError("Model does not support captioning.")
@@ -346,12 +395,8 @@ class MoondreamModel(nn.Module):
346
  [self.config.tokenizer.templates["caption"][length]], device=self.device
347
  )
348
 
349
- max_tokens = DEFAULT_MAX_TOKENS
350
- if settings:
351
- max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS)
352
-
353
  def generator():
354
- for token in self._generate_text(prompt_tokens, image.pos, max_tokens):
355
  yield token
356
 
357
  if stream:
@@ -365,7 +410,7 @@ class MoondreamModel(nn.Module):
365
  next_token: torch.Tensor,
366
  pos: int,
367
  include_size: bool = True,
368
- max_points: int = 50,
369
  ):
370
  out = []
371
  mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
@@ -375,13 +420,13 @@ class MoondreamModel(nn.Module):
375
  with torch.inference_mode():
376
  while (
377
  next_token.item() != self.config.tokenizer.eos_id
378
- and len(out) < max_points
379
  ):
380
  x_logits = decode_coordinate(hidden, self.region)
381
  x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
382
  next_emb = encode_coordinate(
383
  x_center.to(dtype=x_logits.dtype), self.region
384
- )
385
 
386
  # Decode y-coordinate
387
  mask[:, :, pos], pos_ids[0] = 1, pos
@@ -391,7 +436,7 @@ class MoondreamModel(nn.Module):
391
  y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
392
  next_emb = encode_coordinate(
393
  y_center.to(dtype=y_logits.dtype), self.region
394
- )
395
 
396
  # Decode size
397
  if include_size:
@@ -409,12 +454,16 @@ class MoondreamModel(nn.Module):
409
  w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
410
  h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
411
 
412
- next_emb = encode_size(
413
- torch.tensor(
414
- [w, h], device=self.device, dtype=size_logits.dtype
415
- ),
416
- self.region,
417
- )[None]
 
 
 
 
418
 
419
  # Add object
420
  out.append(
@@ -440,7 +489,7 @@ class MoondreamModel(nn.Module):
440
  self,
441
  image: Union[Image.Image, EncodedImage],
442
  object: str,
443
- settings: Optional[SamplingSettings] = None,
444
  ):
445
  if self.config.tokenizer.templates["detect"] is None:
446
  raise NotImplementedError("Model does not support object detection.")
@@ -457,11 +506,18 @@ class MoondreamModel(nn.Module):
457
  device=self.device,
458
  )
459
 
460
- _, hidden, next_token, pos = self._prefill_prompt(prompt_tokens, image.pos)
 
 
461
  hidden = hidden[:, -1:, :]
462
 
 
 
 
 
 
463
  objects = self._generate_points(
464
- hidden, next_token, pos, include_size=True, max_points=50
465
  )
466
 
467
  return {"objects": objects}
@@ -470,7 +526,7 @@ class MoondreamModel(nn.Module):
470
  self,
471
  image: Union[Image.Image, EncodedImage],
472
  object: str,
473
- settings: Optional[SamplingSettings] = None,
474
  ):
475
  if self.config.tokenizer.templates["point"] is None:
476
  raise NotImplementedError("Model does not support pointing.")
@@ -487,11 +543,18 @@ class MoondreamModel(nn.Module):
487
  device=self.device,
488
  )
489
 
490
- _, hidden, next_token, pos = self._prefill_prompt(prompt_tokens, image.pos)
 
 
491
  hidden = hidden[:, -1:, :]
492
 
 
 
 
 
 
493
  objects = self._generate_points(
494
- hidden, next_token, pos, include_size=False, max_points=50
495
  )
496
 
497
  return {"points": objects}
@@ -545,7 +608,7 @@ class MoondreamModel(nn.Module):
545
  return None
546
 
547
  gaze = self._generate_points(
548
- hidden, next_token, pos, include_size=False, max_points=1
549
  )
550
  return gaze[0]
551
 
 
15
  from .utils import remove_outlier_points
16
 
17
 
18
+ TextSamplingSettings = TypedDict(
19
+ "TextSamplingSettings",
20
+ {
21
+ "max_tokens": int,
22
+ "temperature": float,
23
+ "top_p": float,
24
+ },
25
+ total=False,
26
+ )
27
+
28
+ ObjectSamplingSettings = TypedDict(
29
+ "ObjectSamplingSettings",
30
+ {"max_objects": int},
31
  total=False,
32
  )
33
 
34
  DEFAULT_MAX_TOKENS = 768
35
+ DEFAULT_TEMPERATURE = 0.5
36
+ DEFAULT_TOP_P = 0.3
37
+ DEFAULT_MAX_OBJECTS = 50
38
 
39
 
40
  @dataclass(frozen=True)
 
157
  def _decode_one_tok(
158
  self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
159
  ):
160
+ hidden = text_decoder(x, self.text, attn_mask, pos_ids, self.config.text)
161
  logits = lm_head(hidden, self.text)
162
  return logits, hidden
163
 
 
222
  ],
223
  )
224
 
225
+ def _apply_top_p(self, probs: torch.Tensor, top_p: float):
226
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
227
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
228
+ mask = probs_sum - probs_sort > top_p
229
+ probs_sort[mask] = 0.0
230
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
231
+ next_probs = torch.zeros_like(probs)
232
+ next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)
233
+ return next_probs
234
+
235
+ def _prefill_prompt(
236
+ self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
237
+ ):
238
  with torch.inference_mode():
239
  prompt_emb = text_encoder(prompt_tokens, self.text)
240
  torch._dynamo.mark_dynamic(prompt_emb, 1)
 
242
  pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
243
  hidden = self._prefill(prompt_emb, mask, pos_ids)
244
  logits = lm_head(hidden, self.text)
245
+
246
+ if temperature == 0:
247
+ next_token = torch.argmax(logits, dim=-1).unsqueeze(1)
248
+ else:
249
+ probs = torch.softmax(logits / temperature, dim=-1)
250
+ probs = self._apply_top_p(probs, top_p)
251
+ next_token = torch.multinomial(probs, num_samples=1)
252
+
253
  pos = pos + prompt_emb.size(1)
254
  return logits, hidden, next_token, pos
255
 
 
257
  self,
258
  prompt_tokens: torch.Tensor,
259
  pos: int,
260
+ settings: Optional[TextSamplingSettings] = None,
261
  ):
262
+ max_tokens = (
263
+ settings.get("max_tokens", DEFAULT_MAX_TOKENS)
264
+ if settings
265
+ else DEFAULT_MAX_TOKENS
266
+ )
267
+ temperature = (
268
+ settings.get("temperature", DEFAULT_TEMPERATURE)
269
+ if settings
270
+ else DEFAULT_TEMPERATURE
271
+ )
272
+ top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
273
+
274
+ _, _, next_token, pos = self._prefill_prompt(
275
+ prompt_tokens, pos, temperature, top_p
276
+ )
277
 
278
  def generator(next_token, pos):
279
  mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
 
321
  mask[:, :, pos], pos_ids[0] = 1, pos
322
  logits, _ = self._decode_one_tok(next_emb, mask, pos_ids)
323
  pos += 1
324
+
325
+ if temperature == 0:
326
+ next_token = torch.argmax(logits, dim=-1).unsqueeze(1) # (1, 1)
327
+ else:
328
+ probs = torch.softmax(logits / temperature, dim=-1) # (1, V)
329
+ probs = self._apply_top_p(probs, top_p)
330
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
331
+
332
  generated_tokens += 1
333
 
334
  # Flush any remaining text in the cache
 
345
  image: Union[Image.Image, EncodedImage],
346
  question: str,
347
  stream: bool = False,
348
+ settings: Optional[TextSamplingSettings] = None,
349
  ):
350
  if self.config.tokenizer.templates["query"] is None:
351
  raise NotImplementedError("Model does not support querying.")
 
362
  device=self.device,
363
  )
364
 
 
 
 
 
365
  def generator():
366
+ for token in self._generate_text(prompt_tokens, image.pos, settings):
367
  yield token
368
 
369
  if stream:
 
381
  image: Union[Image.Image, EncodedImage],
382
  length: Literal["normal", "short", "long"] = "normal",
383
  stream: bool = False,
384
+ settings: Optional[TextSamplingSettings] = None,
385
  ):
386
  if self.config.tokenizer.templates["caption"] is None:
387
  raise NotImplementedError("Model does not support captioning.")
 
395
  [self.config.tokenizer.templates["caption"][length]], device=self.device
396
  )
397
 
 
 
 
 
398
  def generator():
399
+ for token in self._generate_text(prompt_tokens, image.pos, settings):
400
  yield token
401
 
402
  if stream:
 
410
  next_token: torch.Tensor,
411
  pos: int,
412
  include_size: bool = True,
413
+ max_objects: int = DEFAULT_MAX_OBJECTS,
414
  ):
415
  out = []
416
  mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
 
420
  with torch.inference_mode():
421
  while (
422
  next_token.item() != self.config.tokenizer.eos_id
423
+ and len(out) < max_objects
424
  ):
425
  x_logits = decode_coordinate(hidden, self.region)
426
  x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
427
  next_emb = encode_coordinate(
428
  x_center.to(dtype=x_logits.dtype), self.region
429
+ ).unsqueeze(0)
430
 
431
  # Decode y-coordinate
432
  mask[:, :, pos], pos_ids[0] = 1, pos
 
436
  y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
437
  next_emb = encode_coordinate(
438
  y_center.to(dtype=y_logits.dtype), self.region
439
+ ).unsqueeze(0)
440
 
441
  # Decode size
442
  if include_size:
 
454
  w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
455
  h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
456
 
457
+ next_emb = (
458
+ encode_size(
459
+ torch.tensor(
460
+ [w, h], device=self.device, dtype=size_logits.dtype
461
+ ),
462
+ self.region,
463
+ )
464
+ .unsqueeze(0)
465
+ .unsqueeze(0)
466
+ )
467
 
468
  # Add object
469
  out.append(
 
489
  self,
490
  image: Union[Image.Image, EncodedImage],
491
  object: str,
492
+ settings: Optional[ObjectSamplingSettings] = None,
493
  ):
494
  if self.config.tokenizer.templates["detect"] is None:
495
  raise NotImplementedError("Model does not support object detection.")
 
506
  device=self.device,
507
  )
508
 
509
+ _, hidden, next_token, pos = self._prefill_prompt(
510
+ prompt_tokens, image.pos, temperature=0, top_p=0
511
+ )
512
  hidden = hidden[:, -1:, :]
513
 
514
+ max_objects = (
515
+ settings.get("max_objects", DEFAULT_MAX_OBJECTS)
516
+ if settings
517
+ else DEFAULT_MAX_OBJECTS
518
+ )
519
  objects = self._generate_points(
520
+ hidden, next_token, pos, include_size=True, max_objects=max_objects
521
  )
522
 
523
  return {"objects": objects}
 
526
  self,
527
  image: Union[Image.Image, EncodedImage],
528
  object: str,
529
+ settings: Optional[ObjectSamplingSettings] = None,
530
  ):
531
  if self.config.tokenizer.templates["point"] is None:
532
  raise NotImplementedError("Model does not support pointing.")
 
543
  device=self.device,
544
  )
545
 
546
+ _, hidden, next_token, pos = self._prefill_prompt(
547
+ prompt_tokens, image.pos, temperature=0, top_p=0
548
+ )
549
  hidden = hidden[:, -1:, :]
550
 
551
+ max_objects = (
552
+ settings.get("max_objects", DEFAULT_MAX_OBJECTS)
553
+ if settings
554
+ else DEFAULT_MAX_OBJECTS
555
+ )
556
  objects = self._generate_points(
557
+ hidden, next_token, pos, include_size=False, max_objects=max_objects
558
  )
559
 
560
  return {"points": objects}
 
608
  return None
609
 
610
  gaze = self._generate_points(
611
+ hidden, next_token, pos, include_size=False, max_objects=1
612
  )
613
  return gaze[0]
614