# Hugging Face Spaces status (captured from the Space page): Build error
import json
import os
from urllib.parse import urlparse

# `spaces` stays ahead of torch/transformers so it is imported before anything
# can touch CUDA (required by Hugging Face ZeroGPU).
import spaces
import requests
import torch
import gradio as gr
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from pdf2image import convert_from_path
from PyPDF2 import PdfReader
# Load the vision-language model and its processor once, at import time;
# the request handler reuses these module-level objects for every call.
model_id = "miike-ai/r1-11b-vision"
processor = AutoProcessor.from_pretrained(model_id)
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",           # let accelerate decide weight placement
    torch_dtype=torch.bfloat16,  # load weights in bfloat16
)
# File download function (for remote images or PDFs)
def download_file(url, save_dir="downloads", timeout=30):
    """Download ``url`` into ``save_dir`` and return the local file path.

    Args:
        url: HTTP(S) URL of the file to fetch.
        save_dir: Directory to save into; created if it does not exist.
        timeout: Seconds ``requests`` waits for the connection or for more
            data before raising a timeout error. Without a timeout a stalled
            server would block the caller indefinitely.

    Returns:
        The path of the saved file, or ``None`` if the server answered with
        any status other than 200.

    Raises:
        requests.RequestException: On connection errors or timeouts.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Name the file after the URL *path* only, so query strings / fragments
    # ("doc.pdf?dl=1") don't leak into the name and defeat extension checks.
    filename = os.path.basename(urlparse(url).path)
    if filename in ("", ".", ".."):
        # URLs like "https://host/" would otherwise resolve to a directory.
        filename = "downloaded_file"
    local_filename = os.path.join(save_dir, filename)

    # The context manager releases the streamed connection on every exit path,
    # including the early non-200 return.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            return None
        with open(local_filename, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    return local_filename
# Extracts text and images from a PDF
def extract_pdf_content(pdf_path):
    """Return the PDF's extracted text and a rendering of its first page.

    Args:
        pdf_path: Path to a local PDF file.

    Returns:
        ``(text, images)``. ``text`` is the non-empty per-page text joined
        with single spaces. ``images`` is the list returned by
        ``convert_from_path`` for page 1 only (at most one image).
    """
    # Rasterise only page 1 instead of rendering every page and then
    # discarding all but the first, which is slow and memory-hungry on long PDFs.
    images = convert_from_path(pdf_path, first_page=1, last_page=1)

    reader = PdfReader(pdf_path)
    page_texts = (page.extract_text() for page in reader.pages)
    # extract_text() can return an empty string (or None) for image-only pages; skip those.
    return " ".join(text for text in page_texts if text), images
# Core multimodal processing function
# spaces.GPU requests a GPU for each call on ZeroGPU hardware and is a no-op elsewhere.
# NOTE(review): with max_new_tokens=8192 a call may exceed ZeroGPU's default
# time budget; consider @spaces.GPU(duration=...) — confirm on the target hardware.
@spaces.GPU
def multimodal_chat(text_prompt, file_input=None):
    """Answer ``text_prompt``, optionally grounded on an attached image or PDF.

    Args:
        text_prompt: The user's question.
        file_input: Optional attachment. It may be a local path string, an
            ``http://``/``https://`` URL string (downloaded first), or an
            object whose ``.name`` attribute holds the path. A ``.pdf``
            contributes its extracted text plus a rendering of page 1. A
            ``.png``/``.jpg``/``.jpeg``/``.webp`` contributes the image. Other
            extensions are ignored.

    Returns:
        A 4-space-indented JSON string with keys ``user_input``,
        ``file_path`` and ``response`` (the model's reply only, without the
        prompt). If a URL attachment cannot be downloaded, ``response`` is
        replaced by an ``error`` key.
    """
    images = []
    extracted_text = ""
    file_path = None

    if file_input:
        # Accept both file objects (path in .name) and plain path/URL strings.
        file_path = file_input.name if hasattr(file_input, "name") else file_input
        if isinstance(file_path, str) and file_path.startswith(("http://", "https://")):
            source_url = file_path
            file_path = download_file(source_url)
            if file_path is None:
                # Previously this crashed on None.lower(); report it instead.
                return json.dumps(
                    {
                        "user_input": text_prompt,
                        "file_path": None,
                        "error": f"Could not download file from {source_url}",
                    },
                    indent=4,
                    ensure_ascii=False,
                )

        lowered = file_path.lower()
        if lowered.endswith(".pdf"):
            extracted_text, images = extract_pdf_content(file_path)
        elif lowered.endswith((".png", ".jpg", ".jpeg", ".webp")):
            # convert() loads the pixels, so the `with` can close the file
            # handle; it also normalises RGBA/palette images to plain RGB.
            with Image.open(file_path) as img:
                images.append(img.convert("RGB"))

    # Message layout: optional image placeholder, the question, then any PDF text.
    content = [{"type": "text", "text": text_prompt}]
    if extracted_text:
        content.append({"type": "text", "text": extracted_text})
    if images:
        content.insert(0, {"type": "image"})
    conversation = [{"role": "user", "content": content}]

    input_text = processor.apply_chat_template(conversation, add_generation_prompt=True)
    # The chat-templated text already carries the BOS token (Mllama templates),
    # so the processor must not add special tokens again (duplicate BOS).
    processor_kwargs = {"text": [input_text], "add_special_tokens": False, "return_tensors": "pt"}
    if images:
        processor_kwargs["images"] = images
    inputs = processor(**processor_kwargs).to(model.device)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=8192)
    # generate() returns prompt + completion; decode only the new tokens so the
    # response does not echo the templated prompt back to the user.
    prompt_length = inputs["input_ids"].shape[-1]
    response_text = processor.decode(output[0][prompt_length:], skip_special_tokens=True)

    # ensure_ascii=False keeps non-English model output readable in the textbox.
    return json.dumps(
        {"user_input": text_prompt, "file_path": file_path, "response": response_text},
        indent=4,
        ensure_ascii=False,
    )
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Multimodal AI Chatbot")
    gr.Markdown("Type a message and optionally upload an **image or PDF** to chat with the AI.")
    text_input = gr.Textbox(label="Enter your question")
    # gr.File only accepts uploads (it has no URL text entry), so the label no
    # longer promises one; file_types limits uploads to what multimodal_chat handles.
    file_input = gr.File(
        label="Upload an image/PDF",
        type="filepath",
        file_types=[".pdf", ".png", ".jpg", ".jpeg", ".webp"],
        interactive=True,
    )
    chat_button = gr.Button("Submit")
    output_json = gr.Textbox(label="Response (JSON Output)", interactive=False)
    chat_button.click(multimodal_chat, inputs=[text_input, file_input], outputs=output_json)

# Run the Gradio app when executed as a script (Spaces runs app.py this way);
# importing the module no longer starts a server as a side effect.
if __name__ == "__main__":
    demo.launch()