import gradio as gr
import spaces
import torch
from PIL import Image

# Set random seeds for reproducibility
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

from models.vision_language_model import VisionLanguageModel
from data.processors import get_tokenizer, get_image_processor


@spaces.GPU
def generate_outputs(image, query):
    # Determine device
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Load model
    hf_model = "lusxvr/nanoVLM-222M"
    try:
        model = VisionLanguageModel.from_pretrained(hf_model).to(device)
        model.eval()
    except Exception as e:
        return f"Error loading model: {str(e)}", None, None, None, None

    # Load tokenizer and image processor
    try:
        tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
        image_processor = get_image_processor(model.cfg.vit_img_size)
    except Exception as e:
        return f"Error loading tokenizer or image processor: {str(e)}", None, None, None, None

    # Prepare text input
    template = f"Question: {query} Answer:"
    encoded = tokenizer.batch_encode_plus([template], return_tensors="pt")
    tokens = encoded["input_ids"].to(device)

    # Process image
    try:
        img = image.convert("RGB")
        img_t = image_processor(img).unsqueeze(0).to(device)
    except Exception as e:
        return f"Error processing image: {str(e)}", None, None, None, None

    # Generate four outputs
    outputs = []
    max_new_tokens = 50  # Fixed value from provided script
    try:
        for _ in range(4):
            gen = model.generate(tokens, img_t, max_new_tokens=max_new_tokens)
            out = tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
            outputs.append(out)
    except Exception as e:
        return f"Error during generation: {str(e)}", None, None, None, None

    return None, outputs[0], outputs[1], outputs[2], outputs[3]


def main():
    # Define minimal CSS for subtle aesthetic enhancements
    css = """
    .gradio-container {
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
        padding: 20px;
    }
    h1 {
        color: #333;
        text-align: center;
        margin-bottom: 20px;
    }
    .description {
        margin-bottom: 20px;
        line-height: 1.6;
    }
    .gr-button {
        padding: 10px 20px;
    }
    """

    # Define Gradio interface
    with gr.Blocks(css=css, title="nanoVLM Image-to-Text Generator") as app:
        gr.Markdown(
            "# nanoVLM Image-to-Text Generator"
        )
        gr.Markdown(
            """