|
import os |
|
import streamlit as st |
|
from PIL import Image |
|
import torch |
|
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration |
|
|
|
|
|
# Hugging Face access token, read from the HF_KEY environment variable
# (configured under "Secrets" in the Hugging Face Spaces settings).
HF_TOKEN = os.getenv("HF_KEY")

# Fail fast with a visible UI error if the token is missing: the gated
# PaliGemma weights cannot be downloaded without authentication.
if not HF_TOKEN:
    st.error("❌ Hugging Face API key not found! Set it as 'HF_KEY' in Spaces secrets.")
    st.stop()
|
|
|
|
|
@st.cache_resource
def load_model():
    """Download and cache the PaliGemma 2 model and its processor.

    Decorated with ``st.cache_resource`` so the weights are loaded only
    once per server process and shared across Streamlit reruns/sessions.

    Returns:
        tuple: ``(processor, model)`` — the ``PaliGemmaProcessor`` and the
        ``PaliGemmaForConditionalGeneration`` model in eval mode, loaded in
        bfloat16 with ``device_map="auto"`` placement.

    On any loading failure (bad token, network error, missing gated access),
    the error is surfaced in the UI and the script run is halted.
    """
    model_id = "google/paligemma2-3b-mix-224"

    try:
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id,
            token=HF_TOKEN,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        ).eval()

        processor = PaliGemmaProcessor.from_pretrained(model_id, token=HF_TOKEN)

        return processor, model

    # Broad catch is deliberate: this is a top-level UI boundary, and any
    # failure mode should surface to the user rather than crash the app.
    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        st.stop()
|
|
|
# Load (or fetch from the st.cache_resource cache) the processor and model
# once at startup; load_model() halts the script itself on failure.
processor, model = load_model()
|
|
|
|
|
st.title("🖼️ Image Understanding with PaliGemma")

uploaded_file = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])

if uploaded_file:
    # Normalize to RGB so the processor always receives 3-channel input
    # (uploads may be RGBA PNGs or grayscale JPEGs).
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # NOTE(review): `task` is currently unused — only the free-text prompt
    # drives the model. Wire the selection into the prompt or remove it.
    task = st.selectbox(
        "Select a task:",
        ["Generate a caption", "Answer a question", "Detect objects", "Generate segmentation"]
    )

    prompt = st.text_area("Enter a prompt (e.g., 'Describe the image' or 'What objects are present?')")

    if st.button("Run"):
        if prompt:
            # Cast inputs to bfloat16 to match the model weights, then move
            # them onto whichever device device_map="auto" chose.
            inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
            input_len = inputs["input_ids"].shape[-1]

            with torch.inference_mode():
                generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
                # Strip the echoed prompt tokens; keep only newly generated ones.
                generation = generation[0][input_len:]
                answer = processor.decode(generation, skip_special_tokens=True)

            st.success(f"✅ Result: {answer}")
        else:
            # Explicit feedback instead of silently doing nothing on an
            # empty prompt.
            st.warning("⚠️ Please enter a prompt before running.")
|
|