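"""Gradio Space for chatting with the miike-ai/r1-11b-vision multimodal model.

Users type a question and may optionally attach an image or a PDF; the app
feeds the text (plus the first PDF page image and extracted PDF text, when a
PDF is supplied) to the model and returns the answer as a JSON string.
"""
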
import spaces
import os
import json
import requests
import torch
import gradio as gr
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from pdf2image import convert_from_path
from PyPDF2 import PdfReader
# Load the multimodal model
model_id = "miike-ai/r1-11b-vision"
model = MllamaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)
# File download function (for remote images or PDFs)
def download_file(url, save_dir="downloads"):
    os.makedirs(save_dir, exist_ok=True)
    local_filename = os.path.join(save_dir, url.split("/")[-1])
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        # Stream the file to disk in 1 KB chunks
        with open(local_filename, "wb") as f:
            for chunk in response.iter_content(1024):
                f.write(chunk)
        return local_filename
    return None

# Extracts text and images from a PDF
def extract_pdf_content(pdf_path):
    extracted_text = []
    images = convert_from_path(pdf_path)[:1]  # Keep only the first page as an image
    pdf_reader = PdfReader(pdf_path)
    for page in pdf_reader.pages:
        text = page.extract_text()
        if text:
            extracted_text.append(text)
    return " ".join(extracted_text), images

# Core multimodal processing function
@spaces.GPU
def multimodal_chat(text_prompt, file_input=None):
    conversation = []
    images = []
    extracted_text = ""
    file_path = None

    # Handle file input (if any); accept both file objects and plain paths
    if file_input:
        file_path = file_input.name if hasattr(file_input, "name") else file_input
        if isinstance(file_path, str) and file_path.startswith("http"):
            file_path = download_file(file_path)
        if file_path and file_path.lower().endswith(".pdf"):
            extracted_text, images = extract_pdf_content(file_path)
        elif file_path and file_path.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
            images.append(Image.open(file_path))
    # Prepare user input
    user_message = {"role": "user", "content": [{"type": "text", "text": text_prompt}]}
    if extracted_text:
        user_message["content"].append({"type": "text", "text": extracted_text})
    if images:
        user_message["content"].insert(0, {"type": "image"})
    conversation.append(user_message)

    # Apply chat template and process input
    input_text = processor.apply_chat_template(conversation, add_generation_prompt=True)
    if images:
        inputs = processor(images=images, text=[input_text], add_special_tokens=True, return_tensors="pt").to(model.device)
    else:
        inputs = processor(text=[input_text], add_special_tokens=True, return_tensors="pt").to(model.device)
    # Generate response; decode only the newly generated tokens so the prompt is not echoed back
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=8192)
    generated_tokens = output[0][inputs["input_ids"].shape[-1]:]
    response_text = processor.decode(generated_tokens, skip_special_tokens=True)

    # Format JSON response
    response_json = {
        "user_input": text_prompt,
        "file_path": file_path,
        "response": response_text,
    }
    return json.dumps(response_json, indent=4)

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Multimodal AI Chatbot")
    gr.Markdown("Type a message and optionally upload an **image or PDF** to chat with the AI.")
    text_input = gr.Textbox(label="Enter your question")
    file_input = gr.File(label="Upload an image or PDF", type="filepath", interactive=True)
    chat_button = gr.Button("Submit")
    output_json = gr.Textbox(label="Response (JSON Output)", interactive=False)
    chat_button.click(multimodal_chat, inputs=[text_input, file_input], outputs=output_json)
# Run the Gradio app
demo.launch()
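
# Example client call (a sketch, not part of the app): assuming this Space is
# deployed under the name "miike-ai/r1-11b-vision" and that Gradio exposes the
# click handler under its default auto-generated endpoint name, the app can be
# queried programmatically with gradio_client. Verify the endpoint name with
# client.view_api() before relying on it.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("miike-ai/r1-11b-vision")      # assumed Space name
#   result = client.predict(
#       "Summarize this document.",                # text_input
#       handle_file("example.pdf"),                # file_input (local path or URL)
#       api_name="/multimodal_chat",               # assumed endpoint name
#   )
#   print(result)                                  # JSON string from multimodal_chat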