import base64
import os
from io import BytesIO

import gradio as gr
import requests
from huggingface_hub import login
from PIL import Image
from vllm import LLM, SamplingParams

login(os.environ["HF_TOKEN"])

repo_id = "mistralai/Pixtral-12B-2409"  # Name or path of the model you would like to use

sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)

# Cap the multimodal budget: at most 5 images per message,
# each image contributing up to 4096 tokens.
max_tokens_per_img = 4096
max_img_per_msg = 5

llm = LLM(
    model=repo_id,
    tokenizer_mode="mistral",
    max_model_len=65536,
    max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
    limit_mm_per_prompt={"image": max_img_per_msg},
)


def encode_image(image: Image.Image, image_format="PNG") -> str:
    """Encode a PIL image as a base64 string for use in a data URL."""
    im_file = BytesIO()
    image.save(im_file, format=image_format)
    im_bytes = im_file.getvalue()
    im_64 = base64.b64encode(im_bytes).decode("utf-8")
    return im_64


# @spaces.GPU  # [uncomment to use ZeroGPU]
def infer(image_url, prompt, progress=gr.Progress(track_tqdm=True)):
    # Download the image, resize it to the demo's fixed resolution,
    # and embed it as a base64 data URL in the chat message.
    image = Image.open(BytesIO(requests.get(image_url).content))
    image = image.resize((3844, 2408))
    new_image_url = f"data:image/png;base64,{encode_image(image, image_format='PNG')}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": new_image_url}},
            ],
        },
    ]
    outputs = llm.chat(messages, sampling_params=sampling_params)
    return outputs[0].outputs[0].text


examples = [["https://picsum.photos/id/237/200/300", "What do you see in this image?"]]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Mistral Pixtral 12B")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
            )
        with gr.Row():
            image_url = gr.Text(
                label="Image URL",
                show_label=False,
                max_lines=1,
                placeholder="Enter your image URL",
                container=False,
            )
        with gr.Row():
            run_button = gr.Button("Run", scale=0)
        result = gr.Textbox(show_label=False)
        gr.Examples(examples=examples, inputs=[image_url, prompt])

    gr.on(
        triggers=[run_button.click, image_url.submit, prompt.submit],
        fn=infer,
        inputs=[image_url, prompt],
        outputs=[result],
    )

demo.queue().launch()