File size: 3,508 Bytes
05d640e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from transformers import PreTrainedModel, PretrainedConfig

from .config import MoondreamConfig
from .moondream import MoondreamModel

# Files sometimes don't get loaded without these...
from .image_crops import *
from .vision import *
from .text import *
from .region import *
from .utils import *


def extract_question(text):
    """Pull the question out of a legacy Moondream prompt string.

    The legacy prompt template is the literal ``<image>`` tag, two newlines,
    ``Question: ``, the question itself, two newlines, then ``Answer:``.

    Args:
        text: The full prompt string.

    Returns:
        The text between the template's fixed head and tail when ``text``
        matches the template exactly at both ends, otherwise ``None``.
    """
    head = "<image>\n\nQuestion: "
    tail = "\n\nAnswer:"

    if not (text.startswith(head) and text.endswith(tail)):
        return None
    return text[len(head):len(text) - len(tail)]


class HfConfig(PretrainedConfig):
    """Hugging Face configuration wrapper for the Moondream model.

    The native Moondream settings live in ``self.config`` as a plain dict;
    ``HfMoondream`` turns that dict into a ``MoondreamConfig`` via
    ``MoondreamConfig.from_dict``.
    """

    _auto_class = "AutoConfig"
    model_type = "moondream1"

    def __init__(self, **kwargs):
        # Take the nested Moondream settings out before delegating. The base
        # class would otherwise store a ``config`` kwarg as an attribute, and
        # it used to be overwritten with ``{}`` right afterwards, silently
        # discarding any settings passed in (e.g. loaded from config.json).
        config = kwargs.pop("config", None)
        super().__init__(**kwargs)
        # Copy so later mutation of the caller's dict does not leak in.
        self.config = dict(config) if config is not None else {}
        self.config = {}


class HfMoondream(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` adapter around ``MoondreamModel``.

    The native Moondream API (``encode_image``, ``query``, ``caption``,
    ``detect``, ``point``, ``detect_gaze``) is exposed as pass-through
    properties. A few legacy entry points (``answer_question``,
    ``batch_answer``, ``generate``) are kept for backwards compatibility and
    are implemented on top of the native API.
    """

    _auto_class = "AutoModelForCausalLM"
    config_class = HfConfig

    def __init__(self, config):
        super().__init__(config)
        # ``config.config`` is the dict of native Moondream settings.
        self.model = MoondreamModel(MoondreamConfig.from_dict(config.config))

    # --- Native API: forwarded to the wrapped MoondreamModel ---------------

    @property
    def encode_image(self):
        """Forward to ``MoondreamModel.encode_image``."""
        return self.model.encode_image

    @property
    def query(self):
        """Forward to ``MoondreamModel.query``."""
        return self.model.query

    @property
    def caption(self):
        """Forward to ``MoondreamModel.caption``."""
        return self.model.caption

    @property
    def detect(self):
        """Forward to ``MoondreamModel.detect``."""
        return self.model.detect

    @property
    def point(self):
        """Forward to ``MoondreamModel.point``."""
        return self.model.point

    @property
    def detect_gaze(self):
        """Forward to ``MoondreamModel.detect_gaze``."""
        return self.model.detect_gaze

    # --- Legacy API ---------------------------------------------------------

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer=None,
        chat_history="",
        result_queue=None,
        max_new_tokens=256,
        **kwargs
    ):
        """Answer ``question`` about an image using ``query``.

        Kept for backwards compatibility: ``tokenizer``, ``chat_history``,
        ``max_new_tokens`` and any extra keyword arguments are accepted but
        ignored.

        Args:
            image_embeds: Image argument passed straight through to ``query``.
            question: The question text.
            result_queue: Optional object with a ``put`` method; when given,
                the answer is also pushed onto it.

        Returns:
            The answer string with surrounding whitespace stripped.
        """
        answer = self.query(image_embeds, question)["answer"].strip()

        if result_queue is not None:
            result_queue.put(answer)
        return answer

    def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
        """Answer one prompt per image, one ``query`` call at a time.

        ``images`` and ``prompts`` are paired with ``zip``, so surplus items in
        the longer sequence are ignored. ``tokenizer`` and extra keyword
        arguments are accepted for backwards compatibility but ignored.

        Returns:
            A list of stripped answer strings in input order.
        """
        return [
            self.query(image, prompt)["answer"].strip()
            for image, prompt in zip(images, prompts)
        ]

    def _unsupported_exception(self):
        """Raise ``NotImplementedError`` for removed legacy methods."""
        raise NotImplementedError(
            "This method is not supported in the latest version of moondream. "
            "Consider upgrading to the updated API spec, or alternately pin "
            "to 'revision=2024-08-26'."
        )

    def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
        """Generate text for ``prompt`` about an image (legacy API).

        The signature is unchanged for backwards compatibility; ``tokenizer``
        and extra keyword arguments are ignored.

        If ``prompt`` matches the legacy question template (see
        ``extract_question``), the extracted question is answered with
        ``query`` and ``max_new_tokens`` is not used. Otherwise the raw prompt
        is tokenized with the model's own tokenizer and decoded by
        ``_generate_text``, which receives ``max_new_tokens`` (presumably a
        cap on generated tokens — verify against ``MoondreamModel``).

        Returns:
            A single-element list containing the generated text.
        """
        # The module has no explicit ``import torch``; import it here rather
        # than relying on the name leaking in through a star import.
        import torch

        question = extract_question(prompt)
        if question is not None:
            answer = self.model.query(
                image=image_embeds, question=question, stream=False
            )["answer"]
            return [answer]

        encoded = self.encode_image(image_embeds)
        prompt_tokens = torch.tensor(
            [self.model.tokenizer.encode(prompt).ids],
            device=self.device,
        )
        # _generate_text yields string pieces; concatenate them directly.
        answer = "".join(
            self.model._generate_text(
                prompt_tokens, encoded.kv_cache, encoded.pos, max_new_tokens
            )
        )
        return [answer]

    def get_input_embeddings(self):
        """Delegate to ``PreTrainedModel.get_input_embeddings``."""
        return super().get_input_embeddings()

    def input_embeds(self, *args, **kwargs):
        """Removed legacy method; always raises ``NotImplementedError``."""
        self._unsupported_exception()