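# Gradio demo for LLaVA-1.6 (LLaVA-NeXT): upload an image and enter a text prompt,
# and the model generates a response grounded in the image.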
import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaNextForConditionalGeneration
import spaces

# Load the LLaVA-1.6 (Mistral-7B) checkpoint once at startup. The float16 weights
# roughly halve the memory footprint and are intended for GPU inference.
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"

processor = AutoProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


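# The @spaces.GPU() decorator requests GPU time for each call when this app runs
# on Hugging Face Spaces ZeroGPU hardware; outside of Spaces it is effectively a
# no-op wrapper provided by the `spaces` package.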
@spaces.GPU()
def llava_inference(image: Image.Image, prompt: str) -> str:
    # Build a single-turn conversation with an image placeholder and the user's
    # text prompt, then render it with the model's chat template.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    formatted_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    # Keyword arguments make sure the image and prompt map to the right processor
    # inputs regardless of the positional order expected by the installed
    # transformers version.
    inputs = processor(images=image, text=formatted_prompt, return_tensors="pt").to(device)

    output_ids = model.generate(**inputs, max_new_tokens=100)
    # Decode only the newly generated tokens so the echoed prompt and chat
    # template are not shown in the response.
    output_text = processor.decode(
        output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return output_text


demo = gr.Interface(
    fn=llava_inference,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
    ],
    outputs=gr.Text(label="Output Response"),
    title="LLaVA-1.6 Gradio Demo",
    description="Upload an image and enter a prompt. The model will generate a response using LLaVA-1.6.",
)

if __name__ == "__main__":
    demo.launch()
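# Note: depending on the Gradio version, enabling the request queue with
# demo.queue() before launching can help serialize concurrent GPU calls.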