#!/usr/bin/env python
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import PIL.Image
import spaces
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
# Hugging Face model id of the vision-language model served by this Space.
model_id = "sbintuitions/sarashina2-vision-14b"
# trust_remote_code=True: the processor/model classes are defined in the model repo itself.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
# Loaded once at import time, placed directly on the GPU with the checkpoint's native dtype.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="cuda", trust_remote_code=True)
@spaces.GPU
def run(
    message: dict,
    history: list[dict],
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a model reply for one multimodal chat turn.

    Args:
        message: Gradio multimodal message with ``"text"`` and ``"files"`` keys.
        history: Prior turns as Gradio ``messages``-style dicts; an image turn
            arrives with a tuple ``content`` holding the file path.
        max_new_tokens: Generation length cap.
        temperature: Sampling temperature; sampling is disabled when it is 0.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The accumulated response text, re-emitted after every streamed delta.
    """
    has_new_files = len(message["files"]) > 0

    # The very first turn must carry exactly one image; later turns must not.
    if not history and not has_new_files:
        gr.Warning("Please upload an image.")
        yield ""
        return
    if history and has_new_files:
        gr.Warning("Only one image is allowed.")
        yield ""
        return

    conversation: list[dict] = []
    if history:
        # Recover the image from the earlier upload turn; keep text turns as-is.
        for turn in history:
            content = turn["content"]
            if isinstance(content, tuple):
                image = PIL.Image.open(content[0])
            else:
                conversation.append({"role": turn["role"], "content": content})
    else:
        image = PIL.Image.open(message["files"][0])
    conversation.append({"role": "user", "content": message["text"]})

    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    model_inputs = processor(
        text=[prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    # Run generation on a worker thread so we can consume the streamer here.
    worker = Thread(
        target=model.generate,
        kwargs={
            **model_inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "do_sample": temperature > 0.0,
            "repetition_penalty": repetition_penalty,
            "stopping_criteria": processor.get_stopping_criteria(["\n###"]),
        },
    )
    worker.start()

    response = ""
    for delta in streamer:
        response += delta
        yield response
# Clickable example prompts shown under the chat box; each entry is one
# multimodal message (Japanese prompt + bundled sample image).
examples = [
    [
        {
            "text": "この写真に写っているもので、最も有名と考えられる建築物は何でどこに写っていますか?",
            "files": ["assets/sample.jpg"],
        }
    ],
]
# Gradio chat UI wired to run(); the three sliders map positionally onto
# run()'s max_new_tokens / temperature / repetition_penalty parameters.
# NOTE(review): the max-tokens slider default (512) overrides run()'s 256
# default — presumably intentional; confirm which value is wanted.
demo = gr.ChatInterface(
    fn=run,
    type="messages",
    multimodal=True,
    textbox=gr.MultimodalTextbox(file_types=["image"], file_count="single"),
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=10, maximum=1024, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.7),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.2),
    ],
    examples=examples,
    title="sbintuitions/sarashina2-vision-14b",
    cache_examples=False,
    run_examples_on_click=False,
    css_paths="style.css",
)
# Launch the app only when executed directly (not when imported).
if __name__ == "__main__":
    demo.launch()