import spaces
import os
import json
import requests
import torch
import gradio as gr
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from pdf2image import convert_from_path
from PyPDF2 import PdfReader
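# Assumed dependencies (e.g. in requirements.txt): spaces, torch, transformers,
# gradio, pillow, requests, pdf2image, PyPDF2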

# Load the multimodal model
model_id = "miike-ai/r1-11b-vision"

model = MllamaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)
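# Note: an 11B-parameter model in bfloat16 needs roughly 22 GB for the weights
# alone, so a large GPU is assumed; device_map="auto" spreads layers across
# whatever devices are available.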

# File download function (for remote images or PDFs)
def download_file(url, save_dir="downloads"):
    os.makedirs(save_dir, exist_ok=True)
    local_filename = os.path.join(save_dir, url.split("/")[-1])
    response = requests.get(url, stream=True, timeout=30)
    if response.status_code == 200:
        with open(local_filename, "wb") as f:
            for chunk in response.iter_content(chunk_size=1024):
                f.write(chunk)
        return local_filename
    return None  # Download failed; callers must handle None

# Extracts text and images from a PDF
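# Note: pdf2image's convert_from_path relies on the system poppler utilities
# (e.g. poppler-utils in a Space's packages.txt, or installed via apt-get).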
def extract_pdf_content(pdf_path):
    extracted_text = []
    images = convert_from_path(pdf_path)[:1]  # Keep the first page image

    pdf_reader = PdfReader(pdf_path)
    for page in pdf_reader.pages:
        text = page.extract_text()
        if text:
            extracted_text.append(text)

    return " ".join(extracted_text), images

# Core multimodal processing function
@spaces.GPU
def multimodal_chat(text_prompt, file_input=None):
    conversation = []
    images = []
    extracted_text = ""
    file_path = None

    # Handle file input (if any)
    if file_input:
        file_path = file_input.name if hasattr(file_input, "name") else file_input  # Handle both file objects & paths

        if isinstance(file_path, str) and file_path.startswith("http"):
            file_path = download_file(file_path)

        # Guard against a failed download (download_file can return None)
        if file_path and file_path.lower().endswith(".pdf"):
            extracted_text, images = extract_pdf_content(file_path)
        elif file_path and file_path.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
            images.append(Image.open(file_path))

    # Prepare user input
    user_message = {"role": "user", "content": [{"type": "text", "text": text_prompt}]}
    
    if extracted_text:
        user_message["content"].append({"type": "text", "text": extracted_text})
    if images:
        user_message["content"].insert(0, {"type": "image"})

    conversation.append(user_message)

    # Apply chat template and process input
    input_text = processor.apply_chat_template(conversation, add_generation_prompt=True)
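    # (When an image is present, the template inserts the model's <|image|>
    # token so the processor can align the pixel inputs with the text.)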

    if images:
        inputs = processor(images=images, text=[input_text], add_special_tokens=True, return_tensors="pt").to(model.device)
    else:
        inputs = processor(text=[input_text], add_special_tokens=True, return_tensors="pt").to(model.device)

    # Generate response
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=8192)

    # Decode only the newly generated tokens; output[0] also echoes the prompt
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    response_text = processor.decode(new_tokens, skip_special_tokens=True)

    # Format JSON response
    response_json = {
        "user_input": text_prompt,
        "file_path": file_path if file_input else None,
        "response": response_text
    }

    return json.dumps(response_json, indent=4)
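
# Example (hypothetical file names, for quick local testing without the UI):
#   print(multimodal_chat("Describe this image.", "sample.png"))
#   print(multimodal_chat("Summarize this document.", "sample.pdf"))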

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Multimodal AI Chatbot")
    gr.Markdown("Type a message and optionally upload an **image or PDF** to chat with the AI.")

    text_input = gr.Textbox(label="Enter your question")
    file_input = gr.File(label="Upload an image or PDF", type="filepath", interactive=True)
    chat_button = gr.Button("Submit")
    output_json = gr.Textbox(label="Response (JSON Output)", interactive=False)

    chat_button.click(multimodal_chat, inputs=[text_input, file_input], outputs=output_json)

# Run the Gradio app
demo.launch()
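# (Outside Hugging Face Spaces, demo.launch(share=True) also provides a
# temporary public URL.)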