import io
import os
from unittest.mock import patch

import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports

import inference


# Strip flash_attn from Florence-2's dynamic imports so the model can load on CPU.
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    imports.remove("flash_attn")
    return imports


# Initialize session state for model loading and to block re-running
if "model_loaded" not in st.session_state:
    st.session_state.model_loaded = False


# Load the Florence-2 processor and model, then dynamically quantize the model.
def load_model():
    model_id = "microsoft/Florence-2-large"

    # Load the processor. (torch.qint8 is a quantization dtype, not a valid
    # torch_dtype for loading, so it is not passed here.)
    st.session_state.processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True
    )

    os.makedirs("temp", exist_ok=True)

    # Workaround for the unnecessary flash_attn requirement: patch get_imports
    # while loading the model with the SDPA attention implementation.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        model = AutoModelForCausalLM.from_pretrained(
            model_id, attn_implementation="sdpa", trust_remote_code=True
        )

    # Apply dynamic quantization: replace all Linear layers with int8
    # equivalents to shrink memory use and speed up CPU inference.
    qmodel = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    del model

    st.session_state.model = qmodel
    st.session_state.model_loaded = True
    st.write("Model loading complete.")


# Load the model only once per session.
if not st.session_state.model_loaded:
    with st.spinner("Loading model..."):
        load_model()

# Initialize session state to block re-running
if "has_run" not in st.session_state:
    st.session_state.has_run = False

# Main UI title
st.markdown("<h1 style='text-align: center;'>VQA</h1>", unsafe_allow_html=True)

# Image upload area
uploaded_image = st.sidebar.file_uploader(
    "Upload your image here", type=["jpg", "jpeg", "png"]
)

# Display the uploaded image and process it if available
if uploaded_image is not None:
    image = Image.open(uploaded_image)
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize((256, 256))

    # Save the image to a BytesIO object. convert()/resize() return new images
    # whose format attribute is None, so this always falls back to PNG.
    image_bytes = io.BytesIO()
    image_format = image.format if image.format else "PNG"
    image.save(image_bytes, format=image_format)
    image_bytes.seek(0)

    # Display the image using Streamlit
    st.image(image, caption="Uploaded Image", use_column_width=True)
    image_binary = image_bytes.getvalue()

    # Task prompt input
    task_prompt = st.sidebar.text_input("Task Prompt", value="")

    # Additional question input (optional); Streamlit requires text areas to be
    # at least 68 pixels tall.
    text_input = st.sidebar.text_area("Input Questions", value="", height=68)

    # Generate Caption button
    if st.sidebar.button("Generate Caption", key="Generate"):
        output = inference.run_example(
            image,
            st.session_state.model,
            st.session_state.processor,
            task_prompt,
            text_input,
        )
        st.write(output)
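

# ---------------------------------------------------------------------------
# The `inference` module imported above is not shown in this file. Below is a
# minimal sketch of what `inference.run_example` might look like, adapted from
# the Florence-2 model card's example usage; the signature matches the call
# above, but the generation settings (max_new_tokens, num_beams) are
# assumptions, not the module's actual contents.
# ---------------------------------------------------------------------------
def run_example(image, model, processor, task_prompt, text_input=""):
    # Florence-2 prompts are a task token (e.g. "<CAPTION>") optionally
    # followed by free-form question text.
    prompt = task_prompt + text_input if text_input else task_prompt
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    # post_process_generation parses the raw decoded string into a dict keyed
    # by the task token.
    parsed_answer = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )
    return parsed_answer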