import os

import streamlit as st
import torch
from PIL import Image
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration

# Get the Hugging Face API token from the environment
HF_TOKEN = os.getenv("HF_KEY")

# Ensure the token is available before doing anything else
if not HF_TOKEN:
    st.error("❌ Hugging Face API key not found! Set it as 'HF_KEY' in Spaces secrets.")
    st.stop()


# Load the model and processor with authentication; st.cache_resource ensures
# the model is loaded once per server process, not on every Streamlit rerun.
@st.cache_resource
def load_model():
    model_id = "google/paligemma2-3b-mix-224"
    try:
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id,
            token=HF_TOKEN,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        ).eval()
        processor = PaliGemmaProcessor.from_pretrained(model_id, token=HF_TOKEN)
        return processor, model
    except Exception as e:
        st.error(f"❌ Error loading model: {e}")
        st.stop()


processor, model = load_model()

# Map each UI task to the task prefix the PaliGemma mix checkpoints expect
# (per the PaliGemma model card); "en" requests English output. Previously the
# selected task was never used, so the selectbox had no effect on generation.
TASK_PREFIXES = {
    "Generate a caption": "caption en",
    "Answer a question": "answer en",
    "Detect objects": "detect",
    "Generate segmentation": "segment",
}

# Streamlit UI
st.title("🖼️ Image Understanding with PaliGemma")

uploaded_file = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])

if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Task selection determines the prompt prefix sent to the model
    task = st.selectbox("Select a task:", list(TASK_PREFIXES))

    # Free-form prompt: a question for VQA, or object name(s) for
    # detection/segmentation; captioning works with an empty prompt.
    prompt = st.text_area(
        "Enter a prompt (e.g., 'What objects are present?' or 'car'; optional for captioning)"
    )

    if st.button("Run"):
        if prompt or task == "Generate a caption":
            # Prepend the task prefix, e.g. "answer en What objects are present?"
            full_prompt = f"{TASK_PREFIXES[task]} {prompt}".strip()

            # Cast floating-point inputs to bfloat16 to match the model weights,
            # then move everything to the model's device
            inputs = (
                processor(text=full_prompt, images=image, return_tensors="pt")
                .to(torch.bfloat16)
                .to(model.device)
            )
            input_len = inputs["input_ids"].shape[-1]  # Prompt length in tokens

            with torch.inference_mode():
                generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
                generation = generation[0][input_len:]  # Keep only newly generated tokens

            answer = processor.decode(generation, skip_special_tokens=True)
            st.success(f"✅ Result: {answer}")
        else:
            st.warning("Please enter a prompt for this task.")
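
# To run outside of Hugging Face Spaces (a minimal sketch; assumes this file is
# saved as app.py and that streamlit, torch, and transformers are installed —
# the filename and token value are placeholders, not from the original source):
#
#   export HF_KEY=hf_your_token_here
#   streamlit run app.py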