Spaces: Running on T4
cheesyFishes committed
Commit · 5e5d2bc
1 Parent(s): 2557dbe
add initial source files
- README.md +4 -4
- app.py +243 -0
- requirements.txt +3 -0
README.md
CHANGED
@@ -1,12 +1,12 @@
 ---
-title: Multimodal
+title: Multimodal VDR Demo
-emoji:
+emoji: π¦π
-colorFrom:
+colorFrom: pink
 colorTo: blue
 sdk: gradio
 sdk_version: 5.11.0
 app_file: app.py
-pinned:
+pinned: true
 license: mit
 short_description: Multimodal retrieval using llamaindex/vdr-2b-multi-v1
 ---
app.py
ADDED
@@ -0,0 +1,243 @@
import gradio as gr
import os
import torch

from llama_parse import LlamaParse
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.core.indices import MultiModalVectorStoreIndex
from llama_index.core.schema import Document, ImageDocument
from llama_index.embeddings.huggingface import HuggingFaceEmbedding


example_indexes = {
    "IONIQ 2024": "./iconiq_report_index",
    "Uber 10k 2021": "./uber_index",
}

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

image_embed_model = HuggingFaceEmbedding(
    model_name="llamaindex/vdr-2b-multi-v1",
    device=device,
    trust_remote_code=True,
    token=os.getenv("HUGGINGFACE_TOKEN"),
    model_kwargs={"torch_dtype": torch.float16},
    embed_batch_size=4,
)

text_embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en",
    device=device,
    trust_remote_code=True,
    token=os.getenv("HUGGINGFACE_TOKEN"),
    embed_batch_size=4,
)


def load_index(index_path: str) -> MultiModalVectorStoreIndex:
    storage_context = StorageContext.from_defaults(persist_dir=index_path)
    return load_index_from_storage(
        storage_context,
        embed_model=text_embed_model,
        image_embed_model=image_embed_model,
    )


def create_index(file, llama_parse_key, progress=gr.Progress()):
    if not file or not llama_parse_key:
        return None, "Please provide both a file and a LlamaParse API key"

    try:
        progress(0, desc="Initializing LlamaParse...")
        parser = LlamaParse(
            api_key=llama_parse_key,
            take_screenshot=True,
        )

        # Process document
        progress(0.2, desc="Processing document with LlamaParse...")
        md_json_obj = parser.get_json_result(file.name)[0]

        progress(0.4, desc="Downloading and processing images...")
        image_dicts = parser.get_images(
            [md_json_obj],
            download_path=os.path.join(os.path.dirname(file.name), f"{file.name}_images"),
        )

        # Create text document
        progress(0.6, desc="Creating text documents...")
        text = ""
        for page in md_json_obj["pages"]:
            text += page["md"] + "\n\n"
        text_docs = [Document(text=text.strip())]

        # Create image documents
        progress(0.8, desc="Creating image documents...")
        image_docs = []
        for image_dict in image_dicts:
            image_docs.append(ImageDocument(text=image_dict["name"], image_path=image_dict["path"]))

        # Create index
        progress(0.9, desc="Creating final index...")
        index = MultiModalVectorStoreIndex.from_documents(
            text_docs + image_docs,
            embed_model=text_embed_model,
            image_embed_model=image_embed_model,
        )

        progress(1.0, desc="Complete!")
        return index, "Index created successfully!"

    except Exception as e:
        return None, f"Error creating index: {str(e)}"


def run_search(index, query, text_top_k, image_top_k):
    if not index:
        return "Please create or select an index first.", [], []

    retriever = index.as_retriever(
        similarity_top_k=text_top_k,
        image_similarity_top_k=image_top_k,
    )

    image_nodes = retriever.text_to_image_retrieve(query)
    text_nodes = retriever.text_retrieve(query)

    # Extract text and scores from nodes
    text_results = [{"text": node.text, "score": f"{node.score:.3f}"} for node in text_nodes]

    # Load images and scores
    image_results = []
    for node in image_nodes:
        if hasattr(node.node, "image_path") and os.path.exists(node.node.image_path):
            try:
                image_results.append((
                    node.node.image_path,
                    f"Similarity: {node.score:.3f}",
                ))
            except Exception as e:
                print(f"Error loading image {node.node.image_path}: {e}")

    return "Search completed!", text_results, image_results


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Modal Retrieval with LlamaIndex and llamaindex/vdr-2b-multi-v1")
    gr.Markdown("""
    This demo shows how to use the new `llamaindex/vdr-2b-multi-v1` model for multi-modal document search.

    Using this model, we can index images and perform text-to-image retrieval.

    This demo compares against pure text retrieval using the `BAAI/bge-small-en` model. Is this a fair comparison? Not really,
    but it is the easiest thing to run in a limited Hugging Face Space, and it shows the strengths of screenshot-based retrieval.
    """
    )

    with gr.Row():
        with gr.Column():
            # Index selection/creation
            with gr.Tab("Use Existing Index"):
                existing_index_dropdown = gr.Dropdown(
                    choices=list(example_indexes.keys()),
                    label="Select Pre-made Index",
                    value=list(example_indexes.keys())[0],
                )

            with gr.Tab("Create New Index"):
                gr.Markdown(
                    """
                    To create a new index, enter your LlamaParse API key and upload a PDF.

                    You can get a free API key by signing up [here](https://cloud.llamaindex.ai).

                    Processing will take a few minutes when creating a new index, depending on the size of the document.
                    """
                )
                file_upload = gr.File(label="Upload PDF")
                llama_parse_key = gr.Textbox(
                    label="LlamaParse API Key",
                    type="password",
                )
                create_btn = gr.Button("Create Index")
                create_status = gr.Textbox(label="Status", interactive=False)

            # Search controls
            query_input = gr.Textbox(label="Search Query", value="What is the Executive Summary?")
            with gr.Row():
                text_top_k = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=2,
                    step=1,
                    label="Text Top-K",
                )
                image_top_k = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=2,
                    step=1,
                    label="Image Top-K",
                )
            search_btn = gr.Button("Search")

        with gr.Column():
            # Results display
            status_output = gr.Textbox(label="Search Status")
            image_output = gr.Gallery(
                label="Retrieved Images",
                show_label=True,  # This will show the similarity score captions
                elem_id="gallery",
            )
            text_output = gr.JSON(
                label="Retrieved Text with Similarity Scores",
                elem_id="text_results",
            )

    # State
    index_state = gr.State()

    # Load default index on startup
    default_index = load_index(example_indexes["IONIQ 2024"])
    index_state.value = default_index

    # Event handlers
    def load_existing_index(index_name):
        if index_name:
            try:
                index = load_index(example_indexes[index_name])
                return index, f"Loaded index: {index_name}"
            except Exception as e:
                return None, f"Error loading index: {str(e)}"
        return None, "No index selected"

    existing_index_dropdown.change(
        fn=load_existing_index,
        inputs=[existing_index_dropdown],
        outputs=[index_state, create_status],
        api_name=False,
    )

    create_btn.click(
        fn=create_index,
        inputs=[file_upload, llama_parse_key],
        outputs=[index_state, create_status],
        api_name=False,
        show_progress=True,  # Enable progress bar
    )

    search_btn.click(
        fn=run_search,
        inputs=[index_state, query_input, text_top_k, image_top_k],
        outputs=[status_output, text_output, image_output],
        api_name=False,
    )

    gr.Markdown("""
    This demo was built with [LlamaIndex](https://docs.llamaindex.ai) and [LlamaParse](https://cloud.llamaindex.ai). To see more multi-modal demos, check out the [llama parse examples](https://github.com/run-llama/llama_parse/tree/main/examples/multimodal).
    """
    )

if __name__ == "__main__":
    demo.launch()
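Note: app.py loads the two pre-made indexes from local persist directories ("./iconiq_report_index" and "./uber_index"), but this commit does not include the script that produced those folders. The following is a minimal sketch of how one of them could have been built offline with the same LlamaParse and LlamaIndex calls used in app.py; the PDF filename, output paths, and API-key environment variable are assumptions, not part of the commit.

import os

from llama_parse import LlamaParse
from llama_index.core.indices import MultiModalVectorStoreIndex
from llama_index.core.schema import Document, ImageDocument
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Hypothetical offline build script for the "IONIQ 2024" index; paths and the
# API-key variable are assumptions, the library calls mirror app.py.
parser = LlamaParse(api_key=os.environ["LLAMA_CLOUD_API_KEY"], take_screenshot=True)

# Parse the source PDF and download the page screenshots.
md_json_obj = parser.get_json_result("iconiq_report.pdf")[0]
image_dicts = parser.get_images([md_json_obj], download_path="./iconiq_report_images")

# One text document with all page markdown, plus one ImageDocument per screenshot.
text_docs = [Document(text="\n\n".join(page["md"] for page in md_json_obj["pages"]))]
image_docs = [ImageDocument(text=d["name"], image_path=d["path"]) for d in image_dicts]

index = MultiModalVectorStoreIndex.from_documents(
    text_docs + image_docs,
    embed_model=HuggingFaceEmbedding(model_name="BAAI/bge-small-en"),
    image_embed_model=HuggingFaceEmbedding(
        model_name="llamaindex/vdr-2b-multi-v1", trust_remote_code=True
    ),
)

# Write the index to disk so app.py can reload it with load_index_from_storage.
index.storage_context.persist(persist_dir="./iconiq_report_index")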
requirements.txt
ADDED
@@ -0,0 +1,3 @@
llama-index-core==0.12.10
llama-index-embeddings-huggingface==0.5.0
llama-parse==0.5.19
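Note: requirements.txt pins only the LlamaIndex packages; gradio itself is supplied by the Space runtime declared in the README front-matter (sdk: gradio, sdk_version: 5.11.0). When running outside the Space, the retrieval path in run_search can also be exercised without the UI. A sketch, assuming the persisted "./iconiq_report_index" directory from the repo is available locally and gradio plus torch are installed:

# Standalone retrieval check mirroring run_search in app.py, without Gradio.
# Assumes ./iconiq_report_index exists locally (an assumption about your setup).
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

storage_context = StorageContext.from_defaults(persist_dir="./iconiq_report_index")
index = load_index_from_storage(
    storage_context,
    embed_model=HuggingFaceEmbedding(model_name="BAAI/bge-small-en"),
    image_embed_model=HuggingFaceEmbedding(
        model_name="llamaindex/vdr-2b-multi-v1", trust_remote_code=True
    ),
)

retriever = index.as_retriever(similarity_top_k=2, image_similarity_top_k=2)
for node in retriever.text_to_image_retrieve("What is the Executive Summary?"):
    print(f"{node.score:.3f}  {node.node.image_path}")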