#!/usr/bin/env python
# Gradio demo for sbintuitions/sarashina2-vision-14b, built to run on ZeroGPU Spaces hardware.

from collections.abc import Iterator
from threading import Thread

import gradio as gr
import PIL.Image
import spaces
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
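
# Runtime dependencies implied by the imports above (a sketch, not necessarily the
# Space's exact requirements.txt): gradio, spaces, transformers, torch, accelerate, Pillow.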
model_id = "sbintuitions/sarashina2-vision-14b"

# Load the processor and model once at startup. trust_remote_code is required because
# this model ships custom processing and modeling code.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="cuda",
    trust_remote_code=True,
)
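
# Outside a CUDA-equipped Space, one hedged alternative is to let Accelerate pick the
# placement (assumption: this remote-code model accepts device_map="auto" the same way
# it accepts device_map="cuda"):
#
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id, torch_dtype="auto", device_map="auto", trust_remote_code=True
#     )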

@spaces.GPU  # Allocate a ZeroGPU device for each call (this is what the `spaces` import is for).
def run(
    message: dict,
    history: list[dict],
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # The first turn must include exactly one image; later turns reuse it.
    if not history and len(message["files"]) == 0:
        gr.Warning("Please upload an image.")
        yield ""
        return
    if history and len(message["files"]) > 0:
        gr.Warning("Only one image is allowed.")
        yield ""
        return
    if not history:
        image = PIL.Image.open(message["files"][0])

    # Rebuild the text-only conversation; the image comes from the file uploaded in the
    # current turn or from a file entry earlier in the history.
    messages = []
    for past_message in history:
        content = past_message["content"]
        if isinstance(content, tuple):
            # Gradio stores uploaded files in the message history as (filepath,) tuples.
            image = PIL.Image.open(content[0])
        else:
            messages.append({"role": past_message["role"], "content": content})
    messages.append({"role": "user", "content": message["text"]})

    text_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)
    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,  # Fall back to greedy decoding when the temperature is 0.
        repetition_penalty=repetition_penalty,
        stopping_criteria=processor.get_stopping_criteria(["\n###"]),
    )
    # Run generation in a background thread and stream partial output to the UI.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    output = ""
    for delta in streamer:
        output += delta
        yield output
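
# Minimal non-streaming sanity check (a sketch, not part of the Space's UI flow): it
# reuses `processor` and `model` from above, assumes a local image at "sample.jpg", and
# assumes the processor exposes `batch_decode` (most transformers processors delegate
# it to their tokenizer).
#
#     image = PIL.Image.open("sample.jpg")
#     prompt = processor.apply_chat_template(
#         [{"role": "user", "content": "この写真を説明してください。"}],
#         add_generation_prompt=True,
#     )
#     batch = processor(text=[prompt], images=[image], padding=True, return_tensors="pt").to(model.device)
#     ids = model.generate(
#         **batch,
#         max_new_tokens=128,
#         stopping_criteria=processor.get_stopping_criteria(["\n###"]),
#     )
#     new_tokens = ids[:, batch["input_ids"].shape[1]:]
#     print(processor.batch_decode(new_tokens, skip_special_tokens=True)[0])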

examples = [
    [
        {
            # "Of the things in this photo, what building would be considered the most
            # famous, and where in the photo is it?"
            "text": "この写真に写っているもので、最も有名と考えられる建築物は何でどこに写っていますか?",
            "files": ["assets/sample.jpg"],
        }
    ],
]
demo = gr.ChatInterface(
    fn=run,
    type="messages",
    multimodal=True,
    textbox=gr.MultimodalTextbox(file_types=["image"], file_count="single"),
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=10, maximum=1024, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.7),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.2),
    ],
    examples=examples,
    title="sbintuitions/sarashina2-vision-14b",
    cache_examples=False,
    run_examples_on_click=False,
    css_paths="style.css",
)
if __name__ == "__main__":
    demo.launch()