raksama19 committed · verified
Commit 84e44dc
1 Parent(s): 3c6cd0f

Update chat.py

Files changed (1):
  1. chat.py +771 -178

chat.py CHANGED
@@ -1,198 +1,791 @@
1
- """
2
- Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3
- SPDX-License-Identifier: MIT
4
  """
5
 
6
  import os
7
- import warnings
8
- from collections import OrderedDict
9
 
10
- from omegaconf import ListConfig
11
 
12
- warnings.filterwarnings("ignore", category=UserWarning)
13
- warnings.filterwarnings("ignore", category=FutureWarning)
14
- os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")
15
 
16
- import torch
17
- from PIL import Image
18
- from transformers import PreTrainedTokenizerFast
19
-
20
- from utils.model import DonutConfig, DonutModel, SwinEncoder
21
- from utils.processor import DolphinProcessor
22
-
23
-
24
- def try_rename_lagacy_weights(ckpt, output_path=""):
25
- if "state_dict" in ckpt.keys():
26
- ckpt = ckpt["state_dict"]
27
- if "module" in ckpt.keys():
28
- ckpt = ckpt["module"]
29
- new_ckpt = OrderedDict()
30
- for k, v in ckpt.items():
31
- if k.startswith("model."):
32
- k = k[len("model.") :]
33
- if k.startswith("encoder"):
34
- new_ckpt["vpm" + k[len("encoder") :]] = v
35
- elif k.startswith("decoder"):
36
- new_ckpt["llm" + k[len("encoder") :]] = v
37
  else:
38
- new_ckpt[k] = v
39
- if output_path:
40
- torch.save(new_ckpt, output_path)
41
- return new_ckpt
42
 
43
 
44
- def convert_listconfig_to_list(config):
45
- new_config = {}
46
- for k, v in config.items():
47
- if isinstance(v, ListConfig):
48
- new_config[k] = list(v)
49
  else:
50
- new_config[k] = v
51
- return new_config
52
 
53
 
54
- class DOLPHIN:
55
- def __init__(self, config, ckpt_path="") -> None:
56
- self.model_args = config.model
57
- self.swin_args = config.model.pop("swin_args")
58
- self.swin_args = convert_listconfig_to_list(self.swin_args)
59
-
60
- vision_tower = SwinEncoder(
61
- input_size=self.swin_args["img_size"],
62
- patch_size=self.swin_args["patch_size"],
63
- embed_dim=self.swin_args["embed_dim"],
64
- window_size=self.swin_args["window_size"],
65
- encoder_layer=self.swin_args["encoder_layer"],
66
- num_heads=self.swin_args["num_heads"],
67
- align_long_axis=self.swin_args["align_long_axis"],
68
- )
69
 
70
- self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.model_args.tokenizer_path)
71
- self.tokenizer.pad_token = "<pad>"
72
- self.tokenizer.bos_token = "<s>"
73
- self.tokenizer.eos_token = "</s>"
74
- self.tokenizer.unk_token = "<unk>"
75
-
76
- if self.model_args.get("extra_answer_tokens", False):
77
- # print("Allowing multitask training: adding <Answer/> to the tokenizer.")
78
- prompt_end_token = " <Answer/>"
79
- self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set([prompt_end_token]))})
80
- self.tokenizer._prompt_end_token = prompt_end_token
81
- self.tokenizer._prompt_end_token_id = self.tokenizer.convert_tokens_to_ids(prompt_end_token)
82
-
83
- donut_config = DonutConfig(
84
- decoder_layer=self.model_args.decoder_layer,
85
- max_length=self.model_args.max_length,
86
- max_position_embeddings=self.model_args.max_position_embeddings,
87
- hidden_dimension=self.model_args.hidden_dimension,
88
  )
89
 
90
- self.model = DonutModel(config=donut_config, vision_tower=vision_tower, tokenizer=self.tokenizer)
91
- if self.model_args.model_name_or_path:
92
- ckpt = torch.load(self.model_args.model_name_or_path)
93
- ckpt = try_rename_lagacy_weights(ckpt)
94
- self.model.load_state_dict(ckpt, strict=True)
95
 
96
- device = "cuda" if torch.cuda.is_available() else "cpu"
97
- self.model.to(device)
98
- self.model.eval()
99
- transform_args = {
100
- "input_size": self.swin_args["img_size"],
101
- "max_length": self.model_args.max_length,
102
- }
103
- self.processor = DolphinProcessor({}, self.tokenizer, transform_args=transform_args)
104
-
105
- def chat(
106
- self,
107
- question,
108
- image,
109
- return_raw=False,
110
- return_score=False,
111
- return_img_size=False,
112
- only_return_img_size=False,
113
- max_batch_size=16,
114
- ):
115
-
116
- def _preprocess_image(image):
117
- if isinstance(image, str):
118
- image = Image.open(image).convert("RGB")
119
- if return_img_size or only_return_img_size:
120
- image_tensor, ori_size = self.processor.process_image_for_inference(image, return_img_size=True)
121
- else:
122
- image_tensor = self.processor.process_image_for_inference(image, return_img_size=False)
123
- ori_size = None
124
- return image_tensor, ori_size
125
-
126
- def _preprocess_prompt(question):
127
- if self.model_args.get("extra_answer_tokens", False):
128
- if self.tokenizer._prompt_end_token not in question:
129
- question = question + self.tokenizer._prompt_end_token
130
- prompt_ids = self.processor.process_prompt_for_inference(question)
131
- return prompt_ids
132
-
133
- def _preprocess_prompt_batch(question):
134
- if self.model_args.get("extra_answer_tokens", False):
135
- for i in range(len(question)):
136
- if self.tokenizer._prompt_end_token not in question[i]:
137
- question[i] = question[i] + self.tokenizer._prompt_end_token
138
- if not question[i].startswith("<s>"):
139
- question[i] = "<s>" + question[i]
140
- return question
141
-
142
- def _postprocess(output, question):
143
- output = output.replace("<s>", "").replace(question, "").replace("</s>", "").replace("<pad>", "")
144
- if self.model_args.get("extra_answer_tokens", False):
145
- output = output.split(self.tokenizer._prompt_end_token)[-1]
146
- return output
147
-
148
- if isinstance(question, list):
149
- image_tensor_list = []
150
- for i in image:
151
- image_tensor, ori_size = _preprocess_image(i)
152
- image_tensor_list.append(image_tensor)
153
- image_tensor = torch.cat(image_tensor_list, dim=0)
154
-
155
- question = _preprocess_prompt_batch(question)
156
- self.processor.tokenizer.padding_side = "left"
157
- prompt_ids = self.processor.tokenizer(
158
- question, add_special_tokens=False, return_tensors="pt", padding=True
159
- ).input_ids
160
  else:
161
- image_tensor, ori_size = _preprocess_image(image)
162
- prompt_ids = _preprocess_prompt(question)
163
-
164
- if only_return_img_size:
165
- return ori_size
166
-
167
- model_output_batch = []
168
- for i in range(0, image_tensor.shape[0], max_batch_size):
169
- image_tensor_batch = image_tensor[i : i + max_batch_size]
170
- prompt_ids_batch = prompt_ids[i : i + max_batch_size]
171
- model_output = self.model.inference(image_tensors=image_tensor_batch, prompt_ids=prompt_ids_batch)
172
- model_output_batch.append(model_output)
173
- model_output = {}
174
- for k, v in model_output_batch[0].items():
175
- if isinstance(v, torch.Tensor):
176
- model_output[k] = sum(
177
- [v_batch[k].cpu().numpy().tolist() for v_batch in model_output_batch],
178
- [],
179
- )
180
- else:
181
- model_output[k] = sum([v_batch[k] for v_batch in model_output_batch], [])
182
 
183
- if return_raw:
184
- if return_img_size:
185
- return model_output, ori_size
186
- return model_output
187
  else:
188
- if isinstance(question, list):
189
- output = [_postprocess(model_output["repetitions"][i], question[i]) for i in range(len(question))]
190
- score = model_output["scores"]
191
  else:
192
- output = _postprocess(model_output["repetitions"][0], question)
193
- score = model_output["scores"][0]
194
- if return_score:
195
- return output, score
196
- if return_img_size:
197
- return output, ori_size
198
- return output
 
1
+ """
2
+ DOLPHIN PDF Document AI - Final Version
3
+ Optimized for HuggingFace Spaces NVIDIA T4 Small deployment
4
  """
5
 
6
+ import gradio as gr
7
+ import json
8
+ import markdown
9
+ import cv2
10
+ import numpy as np
11
+ from PIL import Image
12
+ from transformers import AutoProcessor, VisionEncoderDecoderModel, Gemma3nForConditionalGeneration, pipeline
13
+ import torch
14
+ try:
15
+ from sentence_transformers import SentenceTransformer
16
+ import numpy as np
17
+ from sklearn.metrics.pairwise import cosine_similarity
18
+ import google.generativeai as genai
19
+ RAG_DEPENDENCIES_AVAILABLE = True
20
+ except ImportError as e:
21
+ print(f"RAG dependencies not available: {e}")
22
+ print("Please install: pip install sentence-transformers scikit-learn google-generativeai")
23
+ RAG_DEPENDENCIES_AVAILABLE = False
24
+ SentenceTransformer = None
25
  import os
26
+ import tempfile
27
+ import uuid
28
+ import base64
29
+ import io
30
+ from utils.utils import *
31
+ from utils.markdown_utils import MarkdownConverter
32
 
33
+ # Math extension is optional for enhanced math rendering
34
+ MATH_EXTENSION_AVAILABLE = False
35
+ try:
36
+ from mdx_math import MathExtension
37
+ MATH_EXTENSION_AVAILABLE = True
38
+ except ImportError:
39
+ pass
40
 
41
 
42
+ class DOLPHIN:
43
+ def __init__(self, model_id_or_path):
44
+ """Initialize the Hugging Face model optimized for T4 Small"""
45
+ self.processor = AutoProcessor.from_pretrained(model_id_or_path)
46
+ self.model = VisionEncoderDecoderModel.from_pretrained(
47
+ model_id_or_path,
48
+ torch_dtype=torch.float16,
49
+ device_map="auto" if torch.cuda.is_available() else None
50
+ )
51
+ self.model.eval()
52
+
53
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
54
+ if not torch.cuda.is_available():
55
+ self.model = self.model.float()
56
+
57
+ self.tokenizer = self.processor.tokenizer
58
+
59
+ def chat(self, prompt, image):
60
+ """Process an image or batch of images with the given prompt(s)"""
61
+ is_batch = isinstance(image, list)
62
+
63
+ if not is_batch:
64
+ images = [image]
65
+ prompts = [prompt]
66
+ else:
67
+ images = image
68
+ prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
69
+
70
+ batch_inputs = self.processor(images, return_tensors="pt", padding=True)
71
+ batch_pixel_values = batch_inputs.pixel_values
72
+
73
+ if torch.cuda.is_available():
74
+ batch_pixel_values = batch_pixel_values.half().to(self.device)
75
  else:
76
+ batch_pixel_values = batch_pixel_values.to(self.device)
77
+
78
+ prompts = [f"<s>{p} <Answer/>" for p in prompts]
79
+ batch_prompt_inputs = self.tokenizer(
80
+ prompts,
81
+ add_special_tokens=False,
82
+ return_tensors="pt"
83
+ )
84
 
85
+ batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
86
+ batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)
87
+
88
+ with torch.no_grad():
89
+ outputs = self.model.generate(
90
+ pixel_values=batch_pixel_values,
91
+ decoder_input_ids=batch_prompt_ids,
92
+ decoder_attention_mask=batch_attention_mask,
93
+ min_length=1,
94
+ max_length=1024, # Reduced for T4 Small
95
+ pad_token_id=self.tokenizer.pad_token_id,
96
+ eos_token_id=self.tokenizer.eos_token_id,
97
+ use_cache=True,
98
+ bad_words_ids=[[self.tokenizer.unk_token_id]],
99
+ return_dict_in_generate=True,
100
+ do_sample=False,
101
+ num_beams=1,
102
+ repetition_penalty=1.1,
103
+ temperature=1.0
104
+ )
105
+
106
+ sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
107
+
108
+ results = []
109
+ for i, sequence in enumerate(sequences):
110
+ cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
111
+ results.append(cleaned)
112
+
113
+ if not is_batch:
114
+ return results[0]
115
+ return results
116
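For reference, a minimal usage sketch of the wrapper above, assuming the public checkpoint id "ByteDance/DOLPHIN" and a hypothetical page image on disk; a single image with a single prompt returns one string, while passing lists exercises the batched path.

from PIL import Image

model = DOLPHIN("ByteDance/DOLPHIN")                     # checkpoint id assumed; a local ./hf_model dir also works
page = Image.open("page_1.png").convert("RGB")           # hypothetical page image

layout = model.chat("Parse the reading order of this document.", page)   # single call -> str
texts = model.chat(["Read text in the image."] * 2, [page, page])        # batched call -> list of str
print(layout[:200], texts[0][:200], sep="\n")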
 
117
+
118
+ def convert_pdf_to_images_gradio(pdf_file):
119
+ """Convert uploaded PDF file to list of PIL Images"""
120
+ try:
121
+ import pymupdf
122
+
123
+ if isinstance(pdf_file, str):
124
+ pdf_document = pymupdf.open(pdf_file)
125
  else:
126
+ pdf_bytes = pdf_file.read()
127
+ pdf_document = pymupdf.open(stream=pdf_bytes, filetype="pdf")
128
+
129
+ images = []
130
+ for page_num in range(len(pdf_document)):
131
+ page = pdf_document[page_num]
132
+ mat = pymupdf.Matrix(2.0, 2.0)
133
+ pix = page.get_pixmap(matrix=mat)
134
+ img_data = pix.tobytes("png")
135
+ pil_image = Image.open(io.BytesIO(img_data)).convert("RGB")
136
+ images.append(pil_image)
137
+
138
+ pdf_document.close()
139
+ return images
140
+
141
+ except Exception as e:
142
+ raise Exception(f"Error converting PDF: {str(e)}")
143
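A quick sketch of what the conversion above produces, assuming a hypothetical sample.pdf: pymupdf.Matrix(2.0, 2.0) renders pages at twice the default 72 dpi, i.e. roughly 144 dpi.

import pymupdf

doc = pymupdf.open("sample.pdf")                          # hypothetical input file
pix = doc[0].get_pixmap(matrix=pymupdf.Matrix(2.0, 2.0))  # 2x zoom, about 144 dpi
print(len(doc), "pages;", pix.width, "x", pix.height, "pixels on page 1")
doc.close()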
 
144
 
145
+ def process_pdf_document(pdf_file, model, progress=gr.Progress()):
146
+ """Process uploaded PDF file page by page"""
147
+ if pdf_file is None:
148
+ return "No PDF file uploaded", ""
149
+
150
+ try:
151
+ progress(0.1, desc="Converting PDF to images...")
152
+ images = convert_pdf_to_images_gradio(pdf_file)
153
+
154
+ if not images:
155
+ return "Failed to convert PDF to images", ""
156
+
157
+ all_results = []
158
+
159
+ for page_idx, pil_image in enumerate(images):
160
+ progress((page_idx + 1) / len(images) * 0.8 + 0.1,
161
+ desc=f"Processing page {page_idx + 1}/{len(images)}...")
162
+
163
+ layout_output = model.chat("Parse the reading order of this document.", pil_image)
164
+
165
+ padded_image, dims = prepare_image(pil_image)
166
+ recognition_results = process_elements_optimized(
167
+ layout_output,
168
+ padded_image,
169
+ dims,
170
+ model,
171
+ max_batch_size=2 # Smaller batch for T4 Small
172
+ )
173
+
174
+ try:
175
+ markdown_converter = MarkdownConverter()
176
+ markdown_content = markdown_converter.convert(recognition_results)
177
+ except Exception:
178
+ markdown_content = generate_fallback_markdown(recognition_results)
179
+
180
+ page_result = {
181
+ "page_number": page_idx + 1,
182
+ "markdown": markdown_content
183
+ }
184
+ all_results.append(page_result)
185
+
186
+ progress(1.0, desc="Processing complete!")
187
+
188
+ combined_markdown = "\n\n---\n\n".join([
189
+ f"# Page {result['page_number']}\n\n{result['markdown']}"
190
+ for result in all_results
191
+ ])
192
+
193
+ return combined_markdown, "processing_complete"
194
+
195
+ except Exception as e:
196
+ error_msg = f"Error processing PDF: {str(e)}"
197
+ return error_msg, "error"
198
 
199
+
200
+ def process_elements_optimized(layout_results, padded_image, dims, model, max_batch_size=2):
201
+ """Optimized element processing for T4 Small"""
202
+ layout_results = parse_layout_string(layout_results)
203
+
204
+ text_elements = []
205
+ table_elements = []
206
+ figure_results = []
207
+ previous_box = None
208
+ reading_order = 0
209
+
210
+ for bbox, label in layout_results:
211
+ try:
212
+ x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
213
+ bbox, padded_image, dims, previous_box
214
+ )
215
+
216
+ cropped = padded_image[y1:y2, x1:x2]
217
+ if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
218
+ if label == "fig":
219
+ pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
220
+ pil_crop = crop_margin(pil_crop)
221
+
222
+ buffered = io.BytesIO()
223
+ pil_crop.save(buffered, format="PNG")
224
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
225
+ data_uri = f"data:image/png;base64,{img_base64}"
226
+
227
+ figure_results.append({
228
+ "label": label,
229
+ "text": f"![Figure {reading_order}]({data_uri})",
230
+ "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
231
+ "reading_order": reading_order,
232
+ })
233
+ else:
234
+ pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
235
+ element_info = {
236
+ "crop": pil_crop,
237
+ "label": label,
238
+ "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
239
+ "reading_order": reading_order,
240
+ }
241
+
242
+ if label == "tab":
243
+ table_elements.append(element_info)
244
+ else:
245
+ text_elements.append(element_info)
246
+
247
+ reading_order += 1
248
+
249
+ except Exception as e:
250
+ print(f"Error processing element {label}: {str(e)}")
251
+ continue
252
+
253
+ recognition_results = figure_results.copy()
254
+
255
+ if text_elements:
256
+ text_results = process_element_batch_optimized(
257
+ text_elements, model, "Read text in the image.", max_batch_size
258
+ )
259
+ recognition_results.extend(text_results)
260
+
261
+ if table_elements:
262
+ table_results = process_element_batch_optimized(
263
+ table_elements, model, "Parse the table in the image.", max_batch_size
264
  )
265
+ recognition_results.extend(table_results)
266
+
267
+ recognition_results.sort(key=lambda x: x.get("reading_order", 0))
268
+ return recognition_results
269
 
270
 
271
+ def process_element_batch_optimized(elements, model, prompt, max_batch_size=2):
272
+ """Process elements in small batches for T4 Small"""
273
+ results = []
274
+ batch_size = min(len(elements), max_batch_size)
275
+
276
+ for i in range(0, len(elements), batch_size):
277
+ batch_elements = elements[i:i+batch_size]
278
+ crops_list = [elem["crop"] for elem in batch_elements]
279
+ prompts_list = [prompt] * len(crops_list)
280
+
281
+ batch_results = model.chat(prompts_list, crops_list)
282
+
283
+ for j, result in enumerate(batch_results):
284
+ elem = batch_elements[j]
285
+ results.append({
286
+ "label": elem["label"],
287
+ "bbox": elem["bbox"],
288
+ "text": result.strip(),
289
+ "reading_order": elem["reading_order"],
290
+ })
291
+
292
+ del crops_list, batch_elements
293
+ if torch.cuda.is_available():
294
+ torch.cuda.empty_cache()
295
+
296
+ return results
297
+
298
+
299
+ def generate_fallback_markdown(recognition_results):
300
+ """Generate basic markdown if converter fails"""
301
+ markdown_content = ""
302
+ for element in recognition_results:
303
+ if element["label"] == "tab":
304
+ markdown_content += f"\n\n{element['text']}\n\n"
305
+ elif element["label"] in ["para", "title", "sec", "sub_sec"]:
306
+ markdown_content += f"{element['text']}\n\n"
307
+ elif element["label"] == "fig":
308
+ markdown_content += f"{element['text']}\n\n"
309
+ return markdown_content
310
+
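To illustrate the element dictionaries flowing between the helpers above, a toy input for generate_fallback_markdown; the values are made up, but the keys (label, bbox, text, reading_order) match what process_elements_optimized emits.

sample_results = [
    {"label": "title", "bbox": [40, 30, 560, 70],   "text": "# A Sample Paper",                 "reading_order": 0},
    {"label": "para",  "bbox": [40, 90, 560, 300],  "text": "First paragraph of the abstract.", "reading_order": 1},
    {"label": "tab",   "bbox": [40, 320, 560, 480], "text": "| a | b |\n|---|---|\n| 1 | 2 |",  "reading_order": 2},
]
print(generate_fallback_markdown(sample_results))   # title, paragraph, then the table, separated by blank lines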
311
+
312
+ # Initialize model
313
+ model_path = "./hf_model"
314
+ if not os.path.exists(model_path):
315
+ model_path = "ByteDance/DOLPHIN"
316
+
317
+ # Model paths and configuration
318
+ model_path = "./hf_model" if os.path.exists("./hf_model") else "ByteDance/DOLPHIN"
319
+ hf_token = os.getenv('HF_TOKEN')
320
+
321
+ # Don't load models initially - load them on demand
322
+ model_status = "✅ Models ready (Dynamic loading)"
323
+
324
+ # Initialize embedding model and Gemini API
325
+ if RAG_DEPENDENCIES_AVAILABLE:
326
+ try:
327
+ print("Loading embedding model for RAG...")
328
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
329
+ print("✅ Embedding model loaded successfully (CPU)")
330
+
331
+ # Initialize Gemini API
332
+ gemini_api_key = os.getenv('GEMINI_API_KEY')
333
+ if gemini_api_key:
334
+ genai.configure(api_key=gemini_api_key)
335
+ gemini_model = genai.GenerativeModel('gemma-3n-e4b-it')
336
+ print("✅ Gemini API configured successfully")
337
  else:
338
+ print("❌ GEMINI_API_KEY not found in environment")
339
+ gemini_model = None
340
+ except Exception as e:
341
+ print(f"❌ Error loading models: {e}")
342
+ import traceback
343
+ traceback.print_exc()
344
+ embedding_model = None
345
+ gemini_model = None
346
+ else:
347
+ print("❌ RAG dependencies not available")
348
+ embedding_model = None
349
+ gemini_model = None
350
 
351
+ # Model management functions
352
+ def load_dolphin_model():
353
+ """Load DOLPHIN model for PDF processing"""
354
+ global dolphin_model, current_model
355
+
356
+ if current_model == "dolphin":
357
+ return dolphin_model
358
+
359
+ # No need to unload chatbot model (using API now)
360
+
361
+ try:
362
+ print("Loading DOLPHIN model...")
363
+ dolphin_model = DOLPHIN(model_path)
364
+ current_model = "dolphin"
365
+ print(f"✅ DOLPHIN model loaded (Device: {dolphin_model.device})")
366
+ return dolphin_model
367
+ except Exception as e:
368
+ print(f"❌ Error loading DOLPHIN model: {e}")
369
+ return None
370
+
371
+ def unload_dolphin_model():
372
+ """Unload DOLPHIN model to free memory"""
373
+ global dolphin_model, current_model
374
+
375
+ if dolphin_model is not None:
376
+ print("Unloading DOLPHIN model...")
377
+ del dolphin_model
378
+ dolphin_model = None
379
+ if current_model == "dolphin":
380
+ current_model = None
381
+ if torch.cuda.is_available():
382
+ torch.cuda.empty_cache()
383
+ print("✅ DOLPHIN model unloaded")
384
+
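A minimal sketch of the load/unload lifecycle these helpers implement for the T4's limited memory, assuming the module has been imported so the globals declared further down (dolphin_model, current_model) exist, and assuming a hypothetical page image; load_dolphin_model returns the cached model when it is already loaded, and unload_dolphin_model frees CUDA memory afterwards.

from PIL import Image

model = load_dolphin_model()                              # loads onto GPU when available
if model is not None:
    page = Image.open("page_1.png").convert("RGB")        # hypothetical page image
    print(model.chat("Parse the reading order of this document.", page)[:200])
unload_dolphin_model()                                    # release CUDA memory between jobs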
385
+ def initialize_gemini_model():
386
+ """Initialize Gemini API model"""
387
+ global gemini_model
388
+
389
+ if gemini_model is not None:
390
+ return gemini_model
391
+
392
+ try:
393
+ gemini_api_key = os.getenv('GEMINI_API_KEY')
394
+ if not gemini_api_key:
395
+ print("❌ GEMINI_API_KEY not found in environment")
396
+ return None
397
+
398
+ print("Initializing Gemini API...")
399
+ genai.configure(api_key=gemini_api_key)
400
+ gemini_model = genai.GenerativeModel('gemma-3n-e4b-it')
401
+ print("✅ Gemini API model ready")
402
+ return gemini_model
403
+ except Exception as e:
404
+ print(f"❌ Error initializing Gemini model: {e}")
405
+ import traceback
406
+ traceback.print_exc()
407
+ return None
408
+
409
+
410
+ # Global state for managing tabs
411
+ processed_markdown = ""
412
+ show_results_tab = False
413
+ document_chunks = []
414
+ document_embeddings = None
415
+
416
+ # Global model state
417
+ dolphin_model = None
418
+ gemini_model = None
419
+ current_model = None # Track which model is currently loaded
420
+
421
+
422
+ def chunk_document(text, chunk_size=300, overlap=50):
423
+ """Split document into overlapping chunks for RAG - optimized for API quota"""
424
+ words = text.split()
425
+ chunks = []
426
+
427
+ for i in range(0, len(words), chunk_size - overlap):
428
+ chunk = ' '.join(words[i:i + chunk_size])
429
+ if chunk.strip():
430
+ chunks.append(chunk)
431
+
432
+ return chunks
433
+
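To make the overlap arithmetic concrete: with the defaults chunk_size=300 and overlap=50, the window advances 250 words per step, so neighbouring chunks share 50 words. A toy run with smaller, hypothetical values:

words = " ".join(f"w{i}" for i in range(12))
for c in chunk_document(words, chunk_size=5, overlap=2):
    print(c)   # step is 5 - 2 = 3 words: w0..w4, w3..w7, w6..w10, w9..w11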
434
+ def create_embeddings(chunks):
435
+ """Create embeddings for document chunks"""
436
+ if embedding_model is None:
437
+ return None
438
+
439
+ try:
440
+ # Process in smaller batches on CPU
441
+ batch_size = 32
442
+ embeddings = []
443
+
444
+ for i in range(0, len(chunks), batch_size):
445
+ batch = chunks[i:i + batch_size]
446
+ batch_embeddings = embedding_model.encode(batch, show_progress_bar=False)
447
+ embeddings.extend(batch_embeddings)
448
+
449
+ return np.array(embeddings)
450
+ except Exception as e:
451
+ print(f"Error creating embeddings: {e}")
452
+ return None
453
+
454
+ def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
455
+ """Retrieve most relevant chunks for a question"""
456
+ if embedding_model is None or embeddings is None:
457
+ return chunks[:3] # Fallback to first 3 chunks
458
+
459
+ try:
460
+ question_embedding = embedding_model.encode([question], show_progress_bar=False)
461
+ similarities = cosine_similarity(question_embedding, embeddings)[0]
462
+
463
+ # Get top-k most similar chunks
464
+ top_indices = np.argsort(similarities)[-top_k:][::-1]
465
+ relevant_chunks = [chunks[i] for i in top_indices]
466
+
467
+ return relevant_chunks
468
+ except Exception as e:
469
+ print(f"Error retrieving chunks: {e}")
470
+ return chunks[:3] # Fallback
471
+
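A self-contained sketch of the retrieval step above, assuming sentence-transformers and scikit-learn are installed; it mirrors retrieve_relevant_chunks with an explicit toy corpus in place of the processed document.

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

encoder = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
chunks = ["Training used 2M documents.", "Results appear in Table 3.", "The vision backbone is a Swin encoder."]
chunk_emb = np.array(encoder.encode(chunks, show_progress_bar=False))
q_emb = encoder.encode(["Which encoder does the model use?"], show_progress_bar=False)
scores = cosine_similarity(q_emb, chunk_emb)[0]
top = np.argsort(scores)[-2:][::-1]                       # two most similar chunks, best first
print([chunks[i] for i in top])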
472
+ def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
473
+ """Main processing function for uploaded PDF"""
474
+ global processed_markdown, show_results_tab, document_chunks, document_embeddings
475
+
476
+ if pdf_file is None:
477
+ return "❌ No PDF uploaded", gr.Tabs(visible=False)
478
+
479
+ try:
480
+ # Load DOLPHIN model for PDF processing
481
+ progress(0.1, desc="Loading DOLPHIN model...")
482
+ dolphin = load_dolphin_model()
483
+
484
+ if dolphin is None:
485
+ return "❌ Failed to load DOLPHIN model", gr.Tabs(visible=False)
486
+
487
+ # Process PDF
488
+ progress(0.2, desc="Processing PDF...")
489
+ combined_markdown, status = process_pdf_document(pdf_file, dolphin, progress)
490
+
491
+ if status == "processing_complete":
492
+ processed_markdown = combined_markdown
493
+
494
+ # Create chunks and embeddings for RAG
495
+ progress(0.9, desc="Creating document chunks for RAG...")
496
+ document_chunks = chunk_document(processed_markdown)
497
+ document_embeddings = create_embeddings(document_chunks)
498
+ print(f"Created {len(document_chunks)} chunks")
499
+
500
+ # Keep DOLPHIN model loaded for GPU usage
501
+ progress(0.95, desc="Preparing chatbot...")
502
+
503
+ show_results_tab = True
504
+ progress(1.0, desc="PDF processed successfully!")
505
+ return "✅ PDF processed successfully! Chatbot is ready in the Chat tab.", gr.Tabs(visible=True)
506
  else:
507
+ show_results_tab = False
508
+ return combined_markdown, gr.Tabs(visible=False)
509
+
510
+ except Exception as e:
511
+ show_results_tab = False
512
+ error_msg = f"❌ Error processing PDF: {str(e)}"
513
+ return error_msg, gr.Tabs(visible=False)
514
+
515
+
516
+ def get_processed_markdown():
517
+ """Return the processed markdown content"""
518
+ global processed_markdown
519
+ return processed_markdown if processed_markdown else "No document processed yet."
520
+
521
+
522
+ def clear_all():
523
+ """Clear all data and hide results tab"""
524
+ global processed_markdown, show_results_tab, document_chunks, document_embeddings
525
+ processed_markdown = ""
526
+ show_results_tab = False
527
+ document_chunks = []
528
+ document_embeddings = None
529
+
530
+ # Unload DOLPHIN model
531
+ unload_dolphin_model()
532
+
533
+ return None, "", gr.Tabs(visible=False)
534
+
535
+
536
+ # Create Gradio interface
537
+ with gr.Blocks(
538
+ title="DOLPHIN PDF AI",
539
+ theme=gr.themes.Soft(),
540
+ css="""
541
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
542
+
543
+ * {
544
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
545
+ }
546
+
547
+ .main-container {
548
+ max-width: 1000px;
549
+ margin: 0 auto;
550
+ }
551
+ .upload-container {
552
+ text-align: center;
553
+ padding: 40px 20px;
554
+ border: 2px dashed #e0e0e0;
555
+ border-radius: 15px;
556
+ margin: 20px 0;
557
+ }
558
+ .upload-button {
559
+ font-size: 18px !important;
560
+ padding: 15px 30px !important;
561
+ margin: 20px 0 !important;
562
+ font-weight: 600 !important;
563
+ }
564
+ .status-message {
565
+ text-align: center;
566
+ padding: 15px;
567
+ margin: 10px 0;
568
+ border-radius: 8px;
569
+ font-weight: 500;
570
+ }
571
+ .chatbot-container {
572
+ max-height: 600px;
573
+ }
574
+ h1, h2, h3 {
575
+ font-weight: 700 !important;
576
+ }
577
+ #progress-container {
578
+ margin: 10px 0;
579
+ min-height: 20px;
580
+ }
581
+ """
582
+ ) as demo:
583
+
584
+ with gr.Tabs() as main_tabs:
585
+ # Home Tab
586
+ with gr.TabItem("🏠 Home", id="home"):
587
+ embedding_status = "✅ RAG ready" if embedding_model else "❌ RAG not loaded"
588
+ gemini_status = "✅ Gemini API ready" if gemini_model else "❌ Gemini API not configured"
589
+ current_status = f"Currently loaded: {current_model or 'None'}"
590
+ gr.Markdown(
591
+ "# Scholar Express\n"
592
+ "### Upload a research paper to get a web-friendly version and an AI chatbot powered by the Gemini API. The DOLPHIN model runs on the GPU for optimal performance.\n"
593
+ f"**System:** {model_status}\n"
594
+ f"**RAG System:** {embedding_status}\n"
595
+ f"**Gemini API:** {gemini_status}\n"
596
+ f"**Status:** {current_status}"
597
+ )
598
+
599
+ with gr.Column(elem_classes="upload-container"):
600
+ gr.Markdown("## 📄 Upload Your PDF Document")
601
+
602
+ pdf_input = gr.File(
603
+ file_types=[".pdf"],
604
+ label="",
605
+ height=150,
606
+ elem_id="pdf_upload"
607
+ )
608
+
609
+ process_btn = gr.Button(
610
+ "🚀 Process PDF",
611
+ variant="primary",
612
+ size="lg",
613
+ elem_classes="upload-button"
614
+ )
615
+
616
+ clear_btn = gr.Button(
617
+ "🗑️ Clear",
618
+ variant="secondary"
619
+ )
620
+
621
+ # Dedicated progress space
622
+ progress_space = gr.HTML(
623
+ value="",
624
+ visible=False,
625
+ elem_id="progress-container"
626
+ )
627
+
628
+ # Status output (hidden during processing)
629
+ status_output = gr.Markdown(
630
+ "",
631
+ elem_classes="status-message"
632
+ )
633
+
634
+ # Results Tab (initially hidden)
635
+ with gr.TabItem("📖 Document", id="results", visible=False) as results_tab:
636
+ gr.Markdown("## Processed Document")
637
+
638
+ markdown_display = gr.Markdown(
639
+ value="",
640
+ latex_delimiters=[
641
+ {"left": "$$", "right": "$$", "display": True},
642
+ {"left": "$", "right": "$", "display": False}
643
+ ],
644
+ height=700
645
+ )
646
+
647
+ # Chatbot Tab (initially hidden)
648
+ with gr.TabItem("💬 Chat", id="chat", visible=False) as chat_tab:
649
+ gr.Markdown("## Ask Questions About Your Document")
650
+
651
+ chatbot = gr.Chatbot(
652
+ value=[],
653
+ height=500,
654
+ elem_classes="chatbot-container",
655
+ placeholder="Your conversation will appear here once you process a document..."
656
+ )
657
+
658
+ with gr.Row():
659
+ msg_input = gr.Textbox(
660
+ placeholder="Ask a question about the processed document...",
661
+ scale=4,
662
+ container=False
663
+ )
664
+ send_btn = gr.Button("Send", variant="primary", scale=1)
665
+
666
+ gr.Markdown(
667
+ "*Ask questions about your processed document. The AI uses RAG (Retrieval-Augmented Generation) with the Gemini API to find relevant sections and provide accurate answers.*",
668
+ elem_id="chat-notice"
669
+ )
670
+
671
+ # Event handlers
672
+ process_btn.click(
673
+ fn=process_uploaded_pdf,
674
+ inputs=[pdf_input],
675
+ outputs=[status_output, results_tab],
676
+ show_progress=True
677
+ ).then(
678
+ fn=get_processed_markdown,
679
+ outputs=[markdown_display]
680
+ ).then(
681
+ fn=lambda: gr.TabItem(visible=True),
682
+ outputs=[chat_tab]
683
+ )
684
+
685
+ clear_btn.click(
686
+ fn=clear_all,
687
+ outputs=[pdf_input, status_output, results_tab]
688
+ ).then(
689
+ fn=lambda: gr.HTML(visible=False),
690
+ outputs=[progress_space]
691
+ ).then(
692
+ fn=lambda: gr.TabItem(visible=False),
693
+ outputs=[chat_tab]
694
+ )
695
+
696
+ # Chatbot functionality with Gemini API
697
+ def chatbot_response(message, history):
698
+ if not message.strip():
699
+ return history
700
+
701
+ if not processed_markdown:
702
+ return history + [[message, "❌ Please process a PDF document first before asking questions."]]
703
+
704
+ try:
705
+ # Initialize Gemini model
706
+ model = initialize_gemini_model()
707
+
708
+ if model is None:
709
+ return history + [[message, "❌ Failed to initialize Gemini model. Please check your GEMINI_API_KEY."]]
710
+
711
+ # Use RAG to get relevant chunks from markdown (balanced for performance vs quota)
712
+ if document_chunks and len(document_chunks) > 0:
713
+ relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings, top_k=3)
714
+ context = "\n\n".join(relevant_chunks)
715
+ # Smart truncation: aim for ~1500 chars (good context while staying under quota)
716
+ if len(context) > 1500:
717
+ # Try to cut at sentence boundaries
718
+ sentences = context[:1500].split('.')
719
+ context = '.'.join(sentences[:-1]) + '...' if len(sentences) > 1 else context[:1500] + '...'
720
  else:
721
+ # Fallback to truncated document if RAG fails
722
+ context = processed_markdown[:1200] + "..." if len(processed_markdown) > 1200 else processed_markdown
723
+
724
+ # Create prompt for Gemini
725
+ prompt = f"""You are a helpful assistant that answers questions about documents. Use the provided context to answer questions accurately and concisely.
726
+
727
+ Context from the document:
728
+ {context}
729
+
730
+ Question: {message}
731
+
732
+ Please provide a clear and helpful answer based on the context provided."""
733
+
734
+ # Generate response using Gemini API with retry logic
735
+ import time
736
+ max_retries = 2
737
+
738
+ for attempt in range(max_retries):
739
+ try:
740
+ response = model.generate_content(prompt)
741
+ response_text = response.text if hasattr(response, 'text') else str(response)
742
+ return history + [[message, response_text]]
743
+ except Exception as api_error:
744
+ if "429" in str(api_error) and attempt < max_retries - 1:
745
+ # Rate limit hit, wait and retry
746
+ time.sleep(3)
747
+ continue
748
+ else:
749
+ # Other error or final attempt failed
750
+ if "429" in str(api_error):
751
+ return history + [[message, "❌ API quota exceeded. Please wait a moment and try again, or check your Gemini API billing."]]
752
+ else:
753
+ raise api_error
754
+
755
+ except Exception as e:
756
+ error_msg = f"❌ Error generating response: {str(e)}"
757
+ print(f"Full error: {e}")
758
+ import traceback
759
+ traceback.print_exc()
760
+ return history + [[message, error_msg]]
761
+
762
+ send_btn.click(
763
+ fn=chatbot_response,
764
+ inputs=[msg_input, chatbot],
765
+ outputs=[chatbot]
766
+ ).then(
767
+ lambda: "",
768
+ outputs=[msg_input]
769
+ )
770
+
771
+ # Also allow Enter key to send message
772
+ msg_input.submit(
773
+ fn=chatbot_response,
774
+ inputs=[msg_input, chatbot],
775
+ outputs=[chatbot]
776
+ ).then(
777
+ lambda: "",
778
+ outputs=[msg_input]
779
+ )
780
+
781
+
782
+ if __name__ == "__main__":
783
+ demo.launch(
784
+ server_name="0.0.0.0",
785
+ server_port=7860,
786
+ share=False,
787
+ show_error=True,
788
+ max_threads=1, # Single thread for T4 Small
789
+ inbrowser=False,
790
+ quiet=True
791
+ )