#!/usr/bin/env python
"""Gradio chat demo for the sbintuitions/sarashina2-vision-14b vision-language model.

The first user turn must carry exactly one image; follow-up turns reuse that image
together with the accumulated text history. The reply is streamed back to the UI.
"""

from collections.abc import Iterator
from threading import Thread
from typing import Optional

import gradio as gr
import PIL.Image
import spaces
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer

model_id = "sbintuitions/sarashina2-vision-14b"
# trust_remote_code=True: presumably the model repo ships its own model/processor code — confirm.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", device_map="cuda", trust_remote_code=True
)


def _extract_image_and_messages(
    message: dict, history: list[dict]
) -> tuple[Optional[PIL.Image.Image], list[dict]]:
    """Collect the conversation image and the text-only chat messages.

    On the first turn (empty ``history``) the image is taken from ``message["files"]``.
    On later turns it is reopened from the history entry whose content is a tuple,
    which is how this demo recognises file messages. Every other history entry is
    kept as a ``{"role", "content"}`` message, and the new user text is appended last.

    Args:
        message: Multimodal textbox value with ``"text"`` and ``"files"`` keys.
        history: Prior turns in Gradio "messages" format.

    Returns:
        ``(image, messages)``; ``image`` is ``None`` when no image could be found.
    """
    image = None
    if not history and message["files"]:
        image = PIL.Image.open(message["files"][0])

    messages = []
    for past_message in history:
        content = past_message["content"]
        if isinstance(content, tuple):
            # File entries carry the file path as the first tuple element.
            image = PIL.Image.open(content[0])
        else:
            messages.append({"role": past_message["role"], "content": content})
    messages.append({"role": "user", "content": message["text"]})
    return image, messages


@spaces.GPU
def run(
    message: dict,
    history: list[dict],
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the chat ``history``.

    Args:
        message: Multimodal textbox value with ``"text"`` and ``"files"`` keys.
        history: Prior turns in Gradio "messages" format.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        repetition_penalty: Passed through to ``model.generate``.

    Yields:
        The accumulated reply text after each streamed chunk, or a single empty
        string when the input is rejected (a warning toast is shown instead).
    """
    if not history and not message["files"]:
        gr.Warning("Please upload an image.")
        yield ""
        return
    if history and message["files"]:
        gr.Warning("Only one image is allowed.")
        yield ""
        return

    image, messages = _extract_image_and_messages(message, history)
    if image is None:
        # A history without any file entry leaves nothing to pass to the processor.
        gr.Warning("Please upload an image.")
        yield ""
        return

    text_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    do_sample = temperature > 0.0
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        stopping_criteria=processor.get_stopping_criteria(["\n###"]),
    )
    if do_sample:
        # Temperature only matters when sampling; with greedy decoding transformers
        # ignores it and its generation-config validation emits a warning.
        generate_kwargs["temperature"] = temperature

    # generate() runs in a background thread so tokens can be yielded as they arrive.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    output = ""
    for delta in streamer:
        output += delta
        yield output
    thread.join()


examples = [
    [
        {
            "text": "この写真に写っているもので、最も有名と考えられる建築物は何でどこに写っていますか?",
            "files": ["assets/sample.jpg"],
        }
    ],
]

demo = gr.ChatInterface(
    fn=run,
    type="messages",
    multimodal=True,
    textbox=gr.MultimodalTextbox(file_types=["image"], file_count="single"),
    # Order matches run's extra parameters: max_new_tokens, temperature, repetition_penalty.
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=10, maximum=1024, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.7),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.2),
    ],
    examples=examples,
    title="sbintuitions/sarashina2-vision-14b",
    cache_examples=False,
    run_examples_on_click=False,
    css_paths="style.css",
)

if __name__ == "__main__":
    demo.launch()