# app.py — last updated by sergiopaniego in commit a71f62f ("Update app.py", verified).
import gradio as gr
from datasets import load_dataset
import numpy as np
import torch
import random
import time
import spaces
from transformers import AutoProcessor, AutoModelForImageTextToText, AutoModelForVision2Seq
# Question/answer samples with six camera images each: the "day" config of
# the NuScenes-QA mini dataset, train split, read from the day-train Arrow shards.
dataset_vqa = load_dataset(
    "KevinNotSmile/nuscenes-qa-mini",
    name="day",
    data_files="day-train/*.arrow",
    split="train",
)
# Dropdown label -> Hugging Face Hub repository id.
MODEL_VERSIONS: dict[str, str] = {
    # First-generation SmolVLM checkpoints.
    "SmolVLM-256M-Instruct": "HuggingFaceTB/SmolVLM-256M-Instruct",
    "SmolVLM-500M-Instruct": "HuggingFaceTB/SmolVLM-500M-Instruct",
    # The 2.2B first-generation checkpoint lives under the plain "SmolVLM-Instruct" repo.
    "SmolVLM-2.2B-Instruct": "HuggingFaceTB/SmolVLM-Instruct",
    # SmolVLM2 checkpoints.
    "SmolVLM2-256M-Instruct": "HuggingFaceTB/SmolVLM2-256M-Instruct",
    "SmolVLM2-500M-Instruct": "HuggingFaceTB/SmolVLM2-500M-Instruct",
    "SmolVLM2-2.2B-Instruct": "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
}
def load_model_and_processor(version):
    """Load the model and processor that correspond to a dropdown label.

    Args:
        version: A key of ``MODEL_VERSIONS``; an unknown key raises ``KeyError``.

    Returns:
        A ``(model, processor)`` pair. The model is loaded in float16 with
        ``device_map="auto"``; the processor comes from the same repository.
    """
    repo_id = MODEL_VERSIONS[version]
    # First-generation labels ("SmolVLM-...") use the Vision2Seq auto class;
    # everything else ("SmolVLM2-...") uses ImageTextToText.
    model_cls = (
        AutoModelForVision2Seq
        if version.startswith("SmolVLM-")
        else AutoModelForImageTextToText
    )
    model = model_cls.from_pretrained(
        repo_id, torch_dtype=torch.float16, device_map="auto"
    )
    return model, AutoProcessor.from_pretrained(repo_id)
def _normalize_answer(text):
    """Return *text* in a canonical form for exact-match scoring.

    Surrounding whitespace and trailing periods are dropped and the result is
    lower-cased, so e.g. "Yes." and "yes" compare equal.
    """
    return text.strip().rstrip(".").strip().lower()


@spaces.GPU
def predict(model_version):
    """Answer a random NuScenes-QA question with the selected model.

    Picks one random sample from ``dataset_vqa``, loads the model/processor
    for ``model_version`` and asks the sample's question about its six
    camera images.

    Args:
        model_version: A key of ``MODEL_VERSIONS`` (the dropdown label).

    Returns:
        An 11-tuple in the order the Gradio outputs expect: the six camera
        images as numpy arrays (CAM_FRONT_LEFT, CAM_FRONT, CAM_FRONT_RIGHT,
        CAM_BACK_LEFT, CAM_BACK, CAM_BACK_RIGHT), the question, the expected
        answer, the predicted answer, a correctness label and the inference
        time formatted as "<seconds> seconds".
    """
    sample = random.choice(dataset_vqa)
    model, processor = load_model_and_processor(model_version)

    camera_names = (
        "CAM_FRONT_LEFT", "CAM_FRONT", "CAM_FRONT_RIGHT",
        "CAM_BACK_LEFT", "CAM_BACK", "CAM_BACK_RIGHT",
    )
    # Image order must match the order of the {"type": "image"} placeholders below.
    images = [np.array(sample[cam]) for cam in camera_names]

    # Every turn uses the structured content form (a list of typed parts), the
    # documented multimodal chat format.
    # NOTE(review): the SmolVLM chat templates appear to iterate
    # message["content"] as a list of parts, so plain-string contents (as used
    # previously for the system and assistant turns) were likely not rendered.
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": (
                        "You are analyzing real-time camera feed from a self-driving car's multi-camera setup. "
                        "The position of the cameras with respect to the car is: "
                        "CAM_FRONT_LEFT, CAM_FRONT, CAM_FRONT_RIGHT, CAM_BACK_LEFT, CAM_BACK, CAM_BACK_RIGHT. "
                        "Your task is to perform precise visual analysis and answer questions about the scene."
                    ),
                },
            ],
        },
        {
            "role": "user",
            "content": [{"type": "image"} for _ in images]
            + [{"type": "text", "text": f"Answer the following question. {sample['question']}."}],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "Answer: "}],
        },
    ]
    # Image preprocessing options (e.g. rescaling) belong to the processor call
    # below, not to the chat template, so none are passed here.
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device=model.device).to(torch.float16)

    start = time.perf_counter()
    generated_ids = model.generate(**inputs, max_new_tokens=1000)
    inference_time = time.perf_counter() - start

    # generate() returns the prompt tokens followed by the continuation; decode
    # only the continuation instead of splitting the full text on "Assistant: ",
    # which breaks if the model itself emits that marker.
    prompt_length = inputs["input_ids"].shape[1]
    predicted_answer = processor.batch_decode(
        generated_ids[:, prompt_length:], skip_special_tokens=True
    )[0].strip()

    expected_answer = sample["answer"].strip()
    question = sample["question"].strip()
    is_correct = _normalize_answer(predicted_answer) == _normalize_answer(expected_answer)

    return (
        *images,
        question, expected_answer, predicted_answer,
        "✅ Correct" if is_correct else "❌ Incorrect",
        f"{inference_time:.2f} seconds",
    )
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")

# Gradio UI: a model selector and a button that runs predict() on a random
# sample. The outputs are the six camera views plus five text fields, listed
# in the same order as predict()'s return tuple.
with gr.Blocks(theme=theme, title="🔍 SmolVLM2 VQA Demo (NuScenes multimodal QA dataset)") as demo:
    gr.Markdown("# SmolVLM2 VQA Demo (NuScenes multimodal QA dataset)")
    gr.Markdown("This is a demo for the **SmolVLM-SmolVLM2 model family** on the **NuScenes multimodal QA dataset**.")
    gr.Markdown("You can select different model versions and predict answers to questions based on the camera feed.")
    gr.Markdown("[Check out the SmolVLM2 collection](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7)")
    gr.Markdown("[Check out the SmolVLM collection](https://huggingface.co/collections/HuggingFaceTB/smolvlm-6740bd584b2dcbf51ecb1f39)")
    gr.Markdown("[Check out the NuScenes multimodal QA dataset](https://huggingface.co/datasets/KevinNotSmile/nuscenes-qa-mini)")

    model_selector = gr.Dropdown(
        choices=list(MODEL_VERSIONS.keys()),
        # The default must be one of the choices: predict() looks the selected
        # label up in MODEL_VERSIONS, so any other value raises KeyError.
        value="SmolVLM2-2.2B-Instruct",
        label="Select Model Version",
    )
    predict_button = gr.Button("Predict on Random Sample")

    with gr.Row():
        cam_images_front = [
            gr.Image(label=cam) for cam in [
                "CAM_FRONT_LEFT", "CAM_FRONT", "CAM_FRONT_RIGHT"
            ]
        ]
    with gr.Row():
        cam_images_back = [
            gr.Image(label=cam) for cam in [
                "CAM_BACK_LEFT", "CAM_BACK", "CAM_BACK_RIGHT"
            ]
        ]
    cam_images = cam_images_front + cam_images_back

    question_text = gr.Textbox(label="Question")
    expected_text = gr.Textbox(label="Expected Answer")
    predicted_text = gr.Textbox(label="Predicted Answer")
    correctness = gr.Textbox(label="Correct?")
    timing = gr.Textbox(label="Inference Time")

    predict_button.click(
        fn=predict,
        inputs=[model_selector],
        outputs=cam_images + [question_text, expected_text, predicted_text, correctness, timing],
    )

demo.launch()