Kian Kyars committed on
Commit 887133f
1 Parent(s): b0866e4

Remove comments and clean up chat-with-pdf app

Files changed (1)
  1. app.py +293 -83
app.py CHANGED
@@ -1,94 +1,304 @@
  import modal
- import gradio as gr
- import os
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- image = modal.Image.debian_slim().pip_install(
-     "torch",
-     "transformers",
-     "accelerate",
-     "gradio"
- )

- app = modal.App("agentic-demo", image=image)
-
- ALL_MODELS = [
-     "Qwen/Qwen-72B",
-     "deepseek-ai/deepseek-llm-67b-base",
-     "openchat/openchat-3.5-1210",
-     "microsoft/phi-2",
-     "google/gemma-7b",
-     "01-ai/Yi-34B",
-     "upstage/SOLAR-10.7B-v1.0",
-     "microsoft/Orca-2-13b",
-     "lmsys/vicuna-13b-v1.5"
- ]
-
- def debate_agent(topic, agent_a_model, agent_b_model, judge_model):
-     if len({agent_a_model, agent_b_model, judge_model}) < 3:
-         return {"error": "Please select three different models."}
-     # Agent A
-     tokenizer_a = AutoTokenizer.from_pretrained(agent_a_model)
-     model_a = AutoModelForCausalLM.from_pretrained(
-         agent_a_model,
-         load_in_4bit=True,
-         device_map="auto"
-     )
-     prompt_a = f"Debate as Agent A: {topic}"
-     inputs_a = tokenizer_a(prompt_a, return_tensors="pt").to(model_a.device)
-     outputs_a = model_a.generate(**inputs_a, max_new_tokens=10000)
-     arg_a = tokenizer_a.decode(outputs_a[0], skip_special_tokens=True)
-     # Agent B
-     tokenizer_b = AutoTokenizer.from_pretrained(agent_b_model)
-     model_b = AutoModelForCausalLM.from_pretrained(
-         agent_b_model,
-         load_in_4bit=True,
-         device_map="auto"
-     )
-     prompt_b = f"Debate as Agent B: {topic}"
-     inputs_b = tokenizer_b(prompt_b, return_tensors="pt").to(model_b.device)
-     outputs_b = model_b.generate(**inputs_b, max_new_tokens=10000)
-     arg_b = tokenizer_b.decode(outputs_b[0], skip_special_tokens=True)
-     # Judge
-     judge_prompt = (
-         f"You are the judge of a debate.\n"
-         f"Topic: {topic}\n"
-         f"Agent A says: {arg_a}\n"
-         f"Agent B says: {arg_b}\n"
-         f"Summarize both arguments and pick a winner (A or B) with a short justification."
      )
-     tokenizer_j = AutoTokenizer.from_pretrained(judge_model)
-     model_j = AutoModelForCausalLM.from_pretrained(
-         judge_model,
-         load_in_4bit=True,
-         device_map="auto"
      )
-     inputs_j = tokenizer_j(judge_prompt, return_tensors="pt").to(model_j.device)
-     outputs_j = model_j.generate(**inputs_j, max_new_tokens=10000)
-     judge_summary = tokenizer_j.decode(outputs_j[0], skip_special_tokens=True)
-     return {
-         "Agent A": arg_a,
-         "Agent B": arg_b,
-         "Judge": judge_summary
-     }

- @app.function(gpu="B200")
  @modal.asgi_app()
- def fastapi_app():
      import gradio as gr
      from fastapi import FastAPI
      from gradio.routes import mount_gradio_app

-     demo = gr.Interface(
-         fn=debate_agent,
-         inputs=[
-             gr.Textbox(label="Debate Topic"),
-             gr.Dropdown(ALL_MODELS, label="Agent A Model", value=ALL_MODELS[0]),
-             gr.Dropdown(ALL_MODELS, label="Agent B Model", value=ALL_MODELS[1]),
-             gr.Dropdown(ALL_MODELS, label="Judge Model", value=ALL_MODELS[2])
          ],
-         outputs=gr.JSON(label="Debate Results"),
-         title="Agentic Demo: LLM Debate & Judge"
-     )
-     demo.queue(max_size=5)
-     return mount_gradio_app(app=FastAPI(), blocks=demo, path="/")
 
+ from pathlib import Path
+ from urllib.request import urlopen
+ from uuid import uuid4
+
  import modal

+ MINUTES = 60
+
+ app = modal.App("chat-with-pdf")
+
+ CACHE_DIR = "/hf-cache"
+
+ model_image = (
+     modal.Image.debian_slim(python_version="3.12")
+     .apt_install("git")
+     .pip_install(
+         [
+             "git+https://github.com/illuin-tech/colpali.git@782edcd50108d1842d154730ad3ce72476a2d17d",
+             "hf_transfer==0.1.8",
+             "qwen-vl-utils==0.0.8",
+             "torchvision==0.19.1",
+         ]
      )
+     .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": CACHE_DIR})
+ )
+
+ with model_image.imports():
+     import torch
+     from colpali_engine.models import ColQwen2, ColQwen2Processor
+     from qwen_vl_utils import process_vision_info
+     from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
+
+ MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
+ MODEL_REVISION = "aca78372505e6cb469c4fa6a35c60265b00ff5a4"
+
+ sessions = modal.Dict.from_name("colqwen-chat-sessions", create_if_missing=True)
+
+ class Session:
+     def __init__(self):
+         self.images = None
+         self.messages = []
+         self.pdf_embeddings = None
+
+ pdf_volume = modal.Volume.from_name("colqwen-chat-pdfs", create_if_missing=True)
+ PDF_ROOT = Path("/vol/pdfs/")
+
+ cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)
+
+ @app.function(
+     image=model_image, volumes={CACHE_DIR: cache_volume}, timeout=20 * MINUTES
+ )
+ def download_model():
+     from huggingface_hub import snapshot_download
+
+     result = snapshot_download(
+         MODEL_NAME,
+         revision=MODEL_REVISION,
+         ignore_patterns=["*.pt", "*.bin"],
      )
+     print(f"Downloaded model weights to {result}")
+
+ @app.cls(
+     image=model_image,
+     gpu="A100-80GB",
+     scaledown_window=10 * MINUTES,
+     volumes={"/vol/pdfs/": pdf_volume, CACHE_DIR: cache_volume},
+ )
+ class Model:
+     @modal.enter()
+     def load_models(self):
+         self.colqwen2_model = ColQwen2.from_pretrained(
+             "vidore/colqwen2-v0.1",
+             torch_dtype=torch.bfloat16,
+             device_map="cuda:0",
+         )
+         self.colqwen2_processor = ColQwen2Processor.from_pretrained(
+             "vidore/colqwen2-v0.1"
+         )
+         self.qwen2_vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
+             MODEL_NAME,
+             revision=MODEL_REVISION,
+             torch_dtype=torch.bfloat16,
+         )
+         self.qwen2_vl_model.to("cuda:0")
+         self.qwen2_vl_processor = AutoProcessor.from_pretrained(
+             "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True
+         )
+
+     @modal.method()
+     def index_pdf(self, session_id, target: bytes | list):
+         session = sessions.get(session_id)
+         if session is None:
+             session = Session()
+
+         if isinstance(target, bytes):
+             images = convert_pdf_to_images.remote(target)
+         else:
+             images = target
+
+         session_dir = PDF_ROOT / f"{session_id}"
+         session_dir.mkdir(exist_ok=True, parents=True)
+         for ii, image in enumerate(images):
+             filename = session_dir / f"{str(ii).zfill(3)}.jpg"
+             image.save(filename)
+
+         BATCH_SZ = 4
+         pdf_embeddings = []
+         batches = [images[i : i + BATCH_SZ] for i in range(0, len(images), BATCH_SZ)]
+         for batch in batches:
+             batch_images = self.colqwen2_processor.process_images(batch).to(
+                 self.colqwen2_model.device
+             )
+             pdf_embeddings += list(self.colqwen2_model(**batch_images).to("cpu"))
+
+         session.pdf_embeddings = pdf_embeddings
+         sessions[session_id] = session
+
+     @modal.method()
+     def respond_to_message(self, session_id, message):
+         session = sessions.get(session_id)
+         if session is None:
+             session = Session()
+
+         pdf_volume.reload()
+
+         images = (PDF_ROOT / str(session_id)).glob("*.jpg")
+         images = list(sorted(images, key=lambda p: int(p.stem)))
+
+         if not images:
+             return "Please upload a PDF first"
+         elif session.pdf_embeddings is None:
+             return "Indexing PDF..."
+
+         relevant_image = self.get_relevant_image(message, session, images)
+         output_text = self.generate_response(message, session, relevant_image)
+
+         append_to_messages(message, session, user_type="user")
+         append_to_messages(output_text, session, user_type="assistant")
+         sessions[session_id] = session
+
+         return output_text
 
+     def get_relevant_image(self, message, session, images):
+         import PIL
+
+         batch_queries = self.colqwen2_processor.process_queries([message]).to(
+             self.colqwen2_model.device
+         )
+         query_embeddings = self.colqwen2_model(**batch_queries)
+
+         scores = self.colqwen2_processor.score_multi_vector(
+             query_embeddings, session.pdf_embeddings
+         )[0]
+
+         max_index = max(range(len(scores)), key=lambda index: scores[index])
+         return PIL.Image.open(images[max_index])
+
+     def generate_response(self, message, session, image):
+         chatbot_message = get_chatbot_message_with_image(message, image)
+         query = self.qwen2_vl_processor.apply_chat_template(
+             [*session.messages, chatbot_message],
+             tokenize=False,
+             add_generation_prompt=True,
+         )
+         image_inputs, _ = process_vision_info([chatbot_message])
+         inputs = self.qwen2_vl_processor(
+             text=[query],
+             images=image_inputs,
+             padding=True,
+             return_tensors="pt",
+         )
+         inputs = inputs.to("cuda:0")
+
+         generated_ids = self.qwen2_vl_model.generate(**inputs, max_new_tokens=512)
+         generated_ids_trimmed = [
+             out_ids[len(in_ids) :]
+             for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+         ]
+         output_text = self.qwen2_vl_processor.batch_decode(
+             generated_ids_trimmed,
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=False,
+         )[0]
+         return output_text
+
+ pdf_image = (
+     modal.Image.debian_slim(python_version="3.12")
+     .apt_install("poppler-utils")
+     .pip_install("pdf2image==1.17.0", "pillow==10.4.0")
+ )
+
+ @app.function(image=pdf_image)
+ def convert_pdf_to_images(pdf_bytes):
+     from pdf2image import convert_from_bytes
+
+     images = convert_from_bytes(pdf_bytes, fmt="jpeg")
+     return images
+
+ @app.local_entrypoint()
+ def main(question: str = None, pdf_path: str = None, session_id: str = None):
+     model = Model()
+     if session_id is None:
+         session_id = str(uuid4())
+         print("Starting a new session with id", session_id)
+
+         if pdf_path is None:
+             pdf_path = "https://arxiv.org/pdf/1706.03762"
+
+         if pdf_path.startswith("http"):
+             pdf_bytes = urlopen(pdf_path).read()
+         else:
+             pdf_path = Path(pdf_path)
+             pdf_bytes = pdf_path.read_bytes()
+
+         print("Indexing PDF from", pdf_path)
+         model.index_pdf.remote(session_id, pdf_bytes)
+     else:
+         if pdf_path is not None:
+             raise ValueError("Start a new session to chat with a new PDF")
+         print("Resuming session with id", session_id)
+
+     if question is None:
+         question = "What is this document about?"
+
+     print("QUESTION:", question)
+     print(model.respond_to_message.remote(session_id, question))
+
+ web_image = pdf_image.pip_install(
+     "fastapi[standard]==0.115.4",
+     "pydantic==2.9.2",
+     "starlette==0.41.2",
+     "gradio==4.44.1",
+     "pillow==10.4.0",
+     "gradio-pdf==0.0.15",
+     "pdf2image==1.17.0",
+ )
+
+ @app.function(
+     image=web_image,
+     max_containers=1,
+ )
+ @modal.concurrent(max_inputs=1000)
  @modal.asgi_app()
+ def ui():
+     import uuid
+
      import gradio as gr
      from fastapi import FastAPI
      from gradio.routes import mount_gradio_app
+     from gradio_pdf import PDF
+     from pdf2image import convert_from_path
+
+     web_app = FastAPI()
+     model = Model()
+
+     def upload_pdf(path, session_id):
+         if session_id == "" or session_id is None:
+             session_id = str(uuid.uuid4())
+
+         images = convert_from_path(path)
+         model.index_pdf.remote(session_id, images)
+
+         return session_id

+     def respond_to_message(message, _, session_id):
+         return model.respond_to_message.remote(session_id, message)
+
+     with gr.Blocks(theme="soft") as demo:
+         session_id = gr.State("")
+
+         gr.Markdown("# Chat with PDF")
+         with gr.Row():
+             with gr.Column(scale=1):
+                 gr.ChatInterface(
+                     fn=respond_to_message,
+                     additional_inputs=[session_id],
+                     retry_btn=None,
+                     undo_btn=None,
+                     clear_btn=None,
+                 )
+             with gr.Column(scale=1):
+                 pdf = PDF(
+                     label="Upload a PDF",
+                 )
+                 pdf.upload(upload_pdf, [pdf, session_id], session_id)
+
+     return mount_gradio_app(app=web_app, blocks=demo, path="/")
+
+ def get_chatbot_message_with_image(message, image):
+     return {
+         "role": "user",
+         "content": [
+             {"type": "image", "image": image},
+             {"type": "text", "text": message},
          ],
+     }
+
+ def append_to_messages(message, session, user_type="user"):
+     session.messages.append(
+         {
+             "role": user_type,
+             "content": {"type": "text", "text": message},
+         },
+     )
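
For reference, the commands below are a minimal sketch of how the entrypoints in the new app.py are typically exercised with the Modal CLI. The flag names assume Modal's default mapping of the main() local_entrypoint parameters to CLI options, and modal serve assumes a temporary development URL rather than a deployment.

    modal run app.py::download_model                               # pre-populate the hf-hub-cache volume with the Qwen2-VL weights
    modal run app.py --question "What is this document about?"     # index the default arXiv PDF and ask one question over it
    modal serve app.py                                              # serve the Gradio ui() web app while developing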