File size: 3,067 Bytes
feb8b85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea6ffcd
 
feb8b85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea6ffcd
 
 
feb8b85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea6ffcd
 
 
 
 
feb8b85
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#!/usr/bin/env python

from collections.abc import Iterator
from threading import Thread

import gradio as gr
import PIL.Image
import spaces
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer

# Load the Sarashina2-Vision 14B multimodal model at import time.
# trust_remote_code=True is required because this repo ships custom
# processor/model classes; device_map="cuda" places the model on GPU
# (the @spaces.GPU decorator below provisions one per request).
model_id = "sbintuitions/sarashina2-vision-14b"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="cuda", trust_remote_code=True)


@spaces.GPU
def run(
    message: dict,
    history: list[dict],
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a model response for one multimodal chat turn.

    Args:
        message: Gradio multimodal message, ``{"text": str, "files": list[str]}``.
        history: Prior turns in Gradio "messages" format; a past image upload
            appears as a tuple ``content`` (filepath first).
        max_new_tokens: Generation cap. NOTE: the UI slider defaults to 512;
            this function default (256) only applies to direct calls.
        temperature: Sampling temperature; 0.0 selects greedy decoding.
        repetition_penalty: Values > 1.0 discourage repeated tokens.

    Yields:
        The accumulated response text after each streamed token delta.
    """
    # Exactly one image per conversation: required on the first turn,
    # forbidden on later turns.
    if not history and len(message["files"]) == 0:
        gr.Warning("Please upload an image.")
        yield ""
        return
    if history and len(message["files"]) > 0:
        gr.Warning("Only one image is allowed.")
        yield ""
        return

    # Fix: initialize so a history without an image tuple cannot raise
    # UnboundLocalError at `images=[image]` below.
    image = None
    if not history:
        image = PIL.Image.open(message["files"][0])
    messages = []
    for past_message in history:
        content = past_message["content"]
        if isinstance(content, tuple):
            # A past file upload is encoded as a (filepath,) tuple.
            image = PIL.Image.open(content[0])
        else:
            messages.append({"role": past_message["role"], "content": past_message["content"]})
    if image is None:
        # Defensive: history existed but carried no image. The guards above
        # should make this unreachable, but the original code crashed here
        # with UnboundLocalError instead of reporting a usable message.
        gr.Warning("Please upload an image.")
        yield ""
        return
    messages.append({"role": "user", "content": message["text"]})

    text_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    do_sample = temperature > 0.0
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        stopping_criteria=processor.get_stopping_criteria(["\n###"]),
    )
    # Fix: only pass temperature when sampling — transformers warns about
    # (and ignores) a temperature supplied alongside do_sample=False.
    if do_sample:
        generate_kwargs["temperature"] = temperature

    # Run generation on a worker thread so this generator can stream
    # partial output from the TextIteratorStreamer as tokens arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    output = ""
    for delta in streamer:
        output += delta
        yield output


# Example prompt shown in the UI (Japanese: "What is the most famous
# building visible in this photo, and where does it appear?").
examples = [
    [
        {
            "text": "この写真に写っているもので、最も有名と考えられる建築物は何でどこに写っていますか?",
            "files": ["assets/sample.jpg"],
        }
    ],
]

# Multimodal chat UI wired to run(); sliders map positionally onto run()'s
# max_new_tokens / temperature / repetition_penalty parameters.
demo = gr.ChatInterface(
    fn=run,
    type="messages",  # history passed as openai-style message dicts
    multimodal=True,
    # Restrict uploads to a single image per message.
    textbox=gr.MultimodalTextbox(file_types=["image"], file_count="single"),
    additional_inputs=[
        # NOTE(review): slider default (512) differs from run()'s
        # max_new_tokens default (256); the slider value is what the UI
        # actually passes — confirm which default is intended.
        gr.Slider(label="Max new tokens", minimum=10, maximum=1024, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.7),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.2),
    ],
    examples=examples,
    title="sbintuitions/sarashina2-vision-14b",
    cache_examples=False,  # avoids running the model at startup to cache outputs
    run_examples_on_click=False,
    css_paths="style.css",
)
if __name__ == "__main__":
    demo.launch()