raksama19 committed on
Commit 6c3a257 · verified · 1 Parent(s): 6c7ef32

Update app.py

Files changed (1)
  1. app.py +185 -226
app.py CHANGED
@@ -1,7 +1,7 @@
  """
- DOLPHIN PDF Document AI - Alt Text Enhanced Version
- Optimized for HuggingFace Spaces NVIDIA T4 Small deployment
- Features: AI-generated alt text for accessibility using Gemma 3n
  """
 
  import gradio as gr
@@ -10,18 +10,16 @@ import markdown
  import cv2
  import numpy as np
  from PIL import Image
- from transformers import AutoProcessor, VisionEncoderDecoderModel, Gemma3nForConditionalGeneration, pipeline
  import torch
  try:
  from sentence_transformers import SentenceTransformer
  import numpy as np
  from sklearn.metrics.pairwise import cosine_similarity
- import google.generativeai as genai
- from google.generativeai import types
  RAG_DEPENDENCIES_AVAILABLE = True
  except ImportError as e:
  print(f"RAG dependencies not available: {e}")
- print("Please install: pip install sentence-transformers scikit-learn google-generativeai")
  RAG_DEPENDENCIES_AVAILABLE = False
  SentenceTransformer = None
  import os
@@ -43,7 +41,7 @@ except ImportError:
 
  class DOLPHIN:
  def __init__(self, model_id_or_path):
- """Initialize the Hugging Face model optimized for T4 Small"""
  self.processor = AutoProcessor.from_pretrained(model_id_or_path)
  self.model = VisionEncoderDecoderModel.from_pretrained(
  model_id_or_path,
@@ -93,7 +91,7 @@ class DOLPHIN:
  decoder_input_ids=batch_prompt_ids,
  decoder_attention_mask=batch_attention_mask,
  min_length=1,
- max_length=1024, # Reduced for T4 Small
  pad_token_id=self.tokenizer.pad_token_id,
  eos_token_id=self.tokenizer.eos_token_id,
  use_cache=True,
@@ -117,6 +115,139 @@ class DOLPHIN:
  return results
 
 
  def convert_pdf_to_images_gradio(pdf_file):
  """Convert uploaded PDF file to list of PIL Images"""
  try:
@@ -170,7 +301,7 @@ def process_pdf_document(pdf_file, model, progress=gr.Progress()):
  padded_image,
  dims,
  model,
- max_batch_size=2 # Smaller batch for T4 Small
  )
 
  try:
@@ -199,8 +330,8 @@ def process_pdf_document(pdf_file, model, progress=gr.Progress()):
  return error_msg, "error"
 
 
- def process_elements_optimized(layout_results, padded_image, dims, model, max_batch_size=2):
- """Optimized element processing for T4 Small"""
  layout_results = parse_layout_string(layout_results)
 
  text_elements = []
@@ -221,8 +352,8 @@ def process_elements_optimized(layout_results, padded_image, dims, model, max_batch_size=2):
  pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
  pil_crop = crop_margin(pil_crop)
 
- # Generate alt text for accessibility
- alt_text = generate_alt_text_for_image(pil_crop)
 
  buffered = io.BytesIO()
  pil_crop.save(buffered, format="PNG")
@@ -274,8 +405,8 @@ def process_elements_optimized(layout_results, padded_image, dims, model, max_batch_size=2):
  return recognition_results
 
 
- def process_element_batch_optimized(elements, model, prompt, max_batch_size=2):
- """Process elements in small batches for T4 Small"""
  results = []
  batch_size = min(len(elements), max_batch_size)
 
@@ -316,7 +447,7 @@ def generate_fallback_markdown(recognition_results):
  return markdown_content
 
 
- # Initialize model
  model_path = "./hf_model"
  if not os.path.exists(model_path):
  model_path = "ByteDance/DOLPHIN"
@@ -324,164 +455,30 @@ if not os.path.exists(model_path):
  # Model paths and configuration
  model_path = "./hf_model" if os.path.exists("./hf_model") else "ByteDance/DOLPHIN"
  hf_token = os.getenv('HF_TOKEN')
 
- # Don't load models initially - load them on demand
- model_status = "✅ Models ready (Dynamic loading)"
 
- # Initialize embedding model and Gemini API
  if RAG_DEPENDENCIES_AVAILABLE:
  try:
  print("Loading embedding model for RAG...")
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
  print("✅ Embedding model loaded successfully (CPU)")
-
- # Initialize Gemini API
- gemini_api_key = os.getenv('GEMINI_API_KEY')
- if gemini_api_key:
- genai.configure(api_key=gemini_api_key)
- gemini_client = True # Just mark as configured
- print("✅ Gemini API configured successfully")
- else:
- print("❌ GEMINI_API_KEY not found in environment")
- gemini_client = None
  except Exception as e:
- print(f"❌ Error loading models: {e}")
- import traceback
- traceback.print_exc()
  embedding_model = None
- gemini_client = None
  else:
  print("❌ RAG dependencies not available")
  embedding_model = None
- gemini_client = None
-
- # Model management functions
- def load_dolphin_model():
- """Load DOLPHIN model for PDF processing"""
- global dolphin_model, current_model
-
- if current_model == "dolphin":
- return dolphin_model
-
- # No need to unload chatbot model (using API now)
-
- try:
- print("Loading DOLPHIN model...")
- dolphin_model = DOLPHIN(model_path)
- current_model = "dolphin"
- print(f"✅ DOLPHIN model loaded (Device: {dolphin_model.device})")
- return dolphin_model
- except Exception as e:
- print(f"❌ Error loading DOLPHIN model: {e}")
- return None
-
- def unload_dolphin_model():
- """Unload DOLPHIN model to free memory"""
- global dolphin_model, current_model
-
- if dolphin_model is not None:
- print("Unloading DOLPHIN model...")
- del dolphin_model
- dolphin_model = None
- if current_model == "dolphin":
- current_model = None
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- print("✅ DOLPHIN model unloaded")
-
- def initialize_gemini_client():
- """Initialize Gemini API client"""
- global gemini_client
-
- if gemini_client is not None:
- return gemini_client
-
- try:
- gemini_api_key = os.getenv('GEMINI_API_KEY')
- if not gemini_api_key:
- print("❌ GEMINI_API_KEY not found in environment")
- return None
-
- print("Initializing Gemini API client...")
- gemini_client = genai.configure(api_key=gemini_api_key)
- print("✅ Gemini API client ready for gemma-3n-e4b-it")
- return gemini_client
- except Exception as e:
- print(f"❌ Error initializing Gemini client: {e}")
- import traceback
- traceback.print_exc()
- return None
-
-
- def generate_alt_text_for_image(pil_image):
- """Generate alt text for an image using Gemma 3n model via Google AI API"""
- try:
- # Initialize Gemini client
- client = initialize_gemini_client()
- if client is None:
- print("❌ Gemini client not initialized for alt text generation")
- return "Image description unavailable"
-
- # Debug: Check image format and properties
- print(f"🔍 Image format: {pil_image.format}, mode: {pil_image.mode}, size: {pil_image.size}")
-
- # Ensure image is in RGB mode
- if pil_image.mode != 'RGB':
- print(f"Converting image from {pil_image.mode} to RGB")
- pil_image = pil_image.convert('RGB')
-
- # Convert PIL image to bytes
- buffered = io.BytesIO()
- pil_image.save(buffered, format="JPEG")
- image_bytes = buffered.getvalue()
-
- print(f"🔍 Generating alt text for image with Gemma 3n...")
-
- # Create a detailed prompt for alt text generation
- prompt = """You are an accessibility expert creating alt text for images to help visually impaired users understand visual content. Analyze this image and provide a clear, concise description that captures the essential visual information.
-
- Focus on:
- - Main subject or content of the image
- - Important details, text, or data shown
- - Layout and structure if relevant (charts, diagrams, tables)
- - Context that would help someone understand the image's purpose
-
- Provide a descriptive alt text in 1-2 sentences that is informative but not overly verbose. Start directly with the description without saying "This image shows" or similar phrases."""
-
- # Use the Google AI API client with proper format
- response = genai.GenerativeModel('gemma-3n-e4b-it').generate_content([
- types.Part.from_bytes(
- data=image_bytes,
- mime_type='image/jpeg',
- ),
- prompt
- ])
-
- print(f"📡 API response received: {type(response)}")
-
- if hasattr(response, 'text') and response.text:
- alt_text = response.text.strip()
- print(f"✅ Alt text generated: {alt_text[:100]}...")
- else:
- print(f"❌ No text in response. Response: {response}")
- return "Image description unavailable"
-
- # Clean up the alt text
- alt_text = alt_text.replace('\n', ' ').replace('\r', ' ')
- # Remove common prefixes if they appear
- prefixes_to_remove = ["This image shows", "The image shows", "This shows", "The figure shows"]
- for prefix in prefixes_to_remove:
- if alt_text.startswith(prefix):
- alt_text = alt_text[len(prefix):].strip()
- break
-
- return alt_text if alt_text else "Image description unavailable"
-
- except Exception as e:
- print(f"❌ Error generating alt text: {e}")
- import traceback
- traceback.print_exc()
- return "Image description unavailable"
 
 
  # Global state for managing tabs
@@ -490,14 +487,9 @@ show_results_tab = False
  document_chunks = []
  document_embeddings = None
 
- # Global model state
- dolphin_model = None
- gemini_client = None
- current_model = None # Track which model is currently loaded
-
 
  def chunk_document(text, chunk_size=1024, overlap=100):
- """Split document into overlapping chunks for RAG - optimized for API quota"""
  words = text.split()
  chunks = []
 
@@ -554,16 +546,9 @@ def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
  return "❌ No PDF uploaded", gr.Tabs(visible=False)
 
  try:
- # Load DOLPHIN model for PDF processing
- progress(0.1, desc="Loading DOLPHIN model...")
- dolphin = load_dolphin_model()
-
- if dolphin is None:
- return "❌ Failed to load DOLPHIN model", gr.Tabs(visible=False)
-
  # Process PDF
- progress(0.2, desc="Processing PDF...")
- combined_markdown, status = process_pdf_document(pdf_file, dolphin, progress)
 
  if status == "processing_complete":
  processed_markdown = combined_markdown
@@ -574,9 +559,6 @@ def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
  document_embeddings = create_embeddings(document_chunks)
  print(f"Created {len(document_chunks)} chunks")
 
- # Keep DOLPHIN model loaded for GPU usage
- progress(0.95, desc="Preparing chatbot...")
-
  show_results_tab = True
  progress(1.0, desc="PDF processed successfully!")
  return "✅ PDF processed successfully! Chatbot is ready in the Chat tab.", gr.Tabs(visible=True)
@@ -604,15 +586,16 @@ def clear_all():
  document_chunks = []
  document_embeddings = None
 
- # Unload DOLPHIN model
- unload_dolphin_model()
 
  return None, "", gr.Tabs(visible=False)
 
 
  # Create Gradio interface
  with gr.Blocks(
- title="DOLPHIN PDF AI",
  theme=gr.themes.Soft(),
  css="""
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
@@ -662,16 +645,15 @@ with gr.Blocks(
  # Home Tab
  with gr.TabItem("🏠 Home", id="home"):
  embedding_status = "✅ RAG ready" if embedding_model else "❌ RAG not loaded"
- gemini_status = "✅ Gemini API ready" if gemini_client else "❌ Gemini API not configured"
- current_status = f"Currently loaded: {current_model or 'None'}"
  gr.Markdown(
- "# Scholar Express - Alt Text Enhanced\n"
- "### Upload a research paper to get a web-friendly version with AI-generated alt text for accessibility. Includes an AI chatbot powered by Gemini API.\n"
  f"**System:** {model_status}\n"
  f"**RAG System:** {embedding_status}\n"
- f"**Gemini API:** {gemini_status}\n"
  f"**Alt Text:** Gemma 3n generates descriptive alt text for images\n"
- f"**Status:** {current_status}"
  )
 
  with gr.Column(elem_classes="upload-container"):
@@ -742,7 +724,7 @@ with gr.Blocks(
  send_btn = gr.Button("Send", variant="primary", scale=1)
 
  gr.Markdown(
- "*Ask questions about your processed document. The AI uses RAG (Retrieval-Augmented Generation) with Gemini API to find relevant sections and provide accurate answers.*",
  elem_id="chat-notice"
  )
 
@@ -771,7 +753,7 @@ with gr.Blocks(
  outputs=[chat_tab]
  )
 
- # Chatbot functionality with Gemini API
  def chatbot_response(message, history):
  if not message.strip():
  return history
@@ -780,26 +762,20 @@ with gr.Blocks(
  return history + [[message, "❌ Please process a PDF document first before asking questions."]]
 
  try:
- # Initialize Gemini client
- client = initialize_gemini_client()
-
- if client is None:
- return history + [[message, "❌ Failed to initialize Gemini client. Please check your GEMINI_API_KEY."]]
-
- # Use RAG to get relevant chunks from markdown (balanced for performance vs quota)
  if document_chunks and len(document_chunks) > 0:
  relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings, top_k=3)
  context = "\n\n".join(relevant_chunks)
- # Smart truncation: aim for ~4000 chars (good context while staying under quota)
- if len(context) > 4000:
  # Try to cut at sentence boundaries
- sentences = context[:4000].split('.')
- context = '.'.join(sentences[:-1]) + '...' if len(sentences) > 1 else context[:4000] + '...'
  else:
  # Fallback to truncated document if RAG fails
- context = processed_markdown[:4000] + "..." if len(processed_markdown) > 4000 else processed_markdown
 
- # Create prompt for Gemini
  prompt = f"""You are a helpful assistant that answers questions about documents. Use the provided context to answer questions accurately and concisely.
 
  Context from the document:
@@ -809,26 +785,9 @@ Question: {message}
 
  Please provide a clear and helpful answer based on the context provided."""
 
- # Generate response using Gemini API with retry logic
- import time
- max_retries = 2
-
- for attempt in range(max_retries):
- try:
- response = genai.GenerativeModel('gemma-3n-e4b-it').generate_content(prompt)
- response_text = response.text if hasattr(response, 'text') else str(response)
- return history + [[message, response_text]]
- except Exception as api_error:
- if "429" in str(api_error) and attempt < max_retries - 1:
- # Rate limit hit, wait and retry
- time.sleep(3)
- continue
- else:
- # Other error or final attempt failed
- if "429" in str(api_error):
- return history + [[message, "❌ API quota exceeded. Please wait a moment and try again, or check your Gemini API billing."]]
- else:
- raise api_error
 
  except Exception as e:
  error_msg = f"❌ Error generating response: {str(e)}"
@@ -863,7 +822,7 @@ if __name__ == "__main__":
  server_port=7860,
  share=False,
  show_error=True,
- max_threads=1, # Single thread for T4 Small
  inbrowser=False,
  quiet=True
  )
 
  """
+ DOLPHIN PDF Document AI - Local Gemma 3n Version
+ Optimized for powerful GPU deployment with local models
+ Features: AI-generated alt text for accessibility using local Gemma 3n
  """
 
  import gradio as gr
 
  import cv2
  import numpy as np
  from PIL import Image
+ from transformers import AutoProcessor, VisionEncoderDecoderModel, AutoModelForImageTextToText
  import torch
  try:
  from sentence_transformers import SentenceTransformer
  import numpy as np
  from sklearn.metrics.pairwise import cosine_similarity
  RAG_DEPENDENCIES_AVAILABLE = True
  except ImportError as e:
  print(f"RAG dependencies not available: {e}")
+ print("Please install: pip install sentence-transformers scikit-learn")
  RAG_DEPENDENCIES_AVAILABLE = False
  SentenceTransformer = None
  import os
 
 
  class DOLPHIN:
  def __init__(self, model_id_or_path):
+ """Initialize the Hugging Face model optimized for powerful GPU"""
  self.processor = AutoProcessor.from_pretrained(model_id_or_path)
  self.model = VisionEncoderDecoderModel.from_pretrained(
  model_id_or_path,
 
  decoder_input_ids=batch_prompt_ids,
  decoder_attention_mask=batch_attention_mask,
  min_length=1,
+ max_length=2048,
  pad_token_id=self.tokenizer.pad_token_id,
  eos_token_id=self.tokenizer.eos_token_id,
  use_cache=True,
 
  return results
 
 
+ class Gemma3nModel:
+ def __init__(self, model_id="google/gemma-3n-E4B-it"):
+ """Initialize the Gemma 3n model for text generation and image description"""
+ self.model_id = model_id
+ self.processor = AutoProcessor.from_pretrained(model_id)
+ self.model = AutoModelForImageTextToText.from_pretrained(
+ model_id,
+ torch_dtype="auto",
+ device_map="auto"
+ )
+ self.model.eval()
+ print(f"✅ Gemma 3n loaded (Device: {self.model.device}, DType: {self.model.dtype})")
+
+ def generate_alt_text(self, pil_image):
+ """Generate alt text for an image using local Gemma 3n"""
+ try:
+ # Ensure image is in RGB mode
+ if pil_image.mode != 'RGB':
+ pil_image = pil_image.convert('RGB')
+
+ # Create a detailed prompt for alt text generation
+ prompt = """You are an accessibility expert creating alt text for images to help visually impaired users understand visual content. Analyze this image and provide a clear, concise description that captures the essential visual information.
+
+ Focus on:
+ - Main subject or content of the image
+ - Important details, text, or data shown
+ - Layout and structure if relevant (charts, diagrams, tables)
+ - Context that would help someone understand the image's purpose
+
+ Provide a descriptive alt text in 1-2 sentences that is informative but not overly verbose. Start directly with the description without saying "This image shows" or similar phrases."""
+
+ # Prepare the message format
+ message = {
+ "role": "user",
+ "content": [
+ {"type": "image", "image": pil_image},
+ {"type": "text", "text": prompt}
+ ]
+ }
+
+ # Apply chat template and generate
+ input_ids = self.processor.apply_chat_template(
+ [message],
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ )
+ input_len = input_ids["input_ids"].shape[-1]
+
+ input_ids = input_ids.to(self.model.device, dtype=self.model.dtype)
+ outputs = self.model.generate(
+ **input_ids,
+ max_new_tokens=256,
+ disable_compile=True,
+ do_sample=False,
+ temperature=0.1
+ )
+
+ text = self.processor.batch_decode(
+ outputs[:, input_len:],
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=True
+ )
+
+ alt_text = text[0].strip()
+
+ # Clean up the alt text
+ alt_text = alt_text.replace('\n', ' ').replace('\r', ' ')
+ # Remove common prefixes if they appear
+ prefixes_to_remove = ["This image shows", "The image shows", "This shows", "The figure shows"]
+ for prefix in prefixes_to_remove:
+ if alt_text.startswith(prefix):
+ alt_text = alt_text[len(prefix):].strip()
+ break
+
+ return alt_text if alt_text else "Image description unavailable"
+
+ except Exception as e:
+ print(f"❌ Error generating alt text: {e}")
+ import traceback
+ traceback.print_exc()
+ return "Image description unavailable"
+
+ def chat(self, prompt, history=None):
+ """Chat functionality using Gemma 3n for text-only conversations"""
+ try:
+ # Create message format
+ message = {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": prompt}
+ ]
+ }
+
+ # If history exists, include it
+ conversation = history if history else []
+ conversation.append(message)
+
+ # Apply chat template and generate
+ input_ids = self.processor.apply_chat_template(
+ conversation,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ )
+ input_len = input_ids["input_ids"].shape[-1]
+
+ input_ids = input_ids.to(self.model.device, dtype=self.model.dtype)
+ outputs = self.model.generate(
+ **input_ids,
+ max_new_tokens=1024,
+ disable_compile=True,
+ do_sample=True,
+ temperature=0.7
+ )
+
+ text = self.processor.batch_decode(
+ outputs[:, input_len:],
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=True
+ )
+
+ return text[0].strip()
+
+ except Exception as e:
+ print(f"❌ Error in chat: {e}")
+ import traceback
+ traceback.print_exc()
+ return f"Error generating response: {str(e)}"
+
+
  def convert_pdf_to_images_gradio(pdf_file):
  """Convert uploaded PDF file to list of PIL Images"""
  try:
 
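For reference, a minimal usage sketch of the Gemma3nModel wrapper added above, assuming the google/gemma-3n-E4B-it weights can be fetched (the figure path below is a hypothetical placeholder, not part of the commit):

    # Hypothetical standalone usage of the Gemma3nModel wrapper from app.py
    from PIL import Image

    gemma = Gemma3nModel("google/gemma-3n-E4B-it")  # loads processor + weights with device_map="auto"
    figure = Image.open("figure_1.png")             # placeholder path for a cropped figure from a PDF
    print(gemma.generate_alt_text(figure))          # 1-2 sentence accessibility description
    print(gemma.chat("Summarize the figure's role in the paper in one sentence."))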
  padded_image,
  dims,
  model,
+ max_batch_size=4
  )
 
  try:
 
  return error_msg, "error"
 
 
+ def process_elements_optimized(layout_results, padded_image, dims, model, max_batch_size=4):
+ """Optimized element processing for powerful GPU"""
  layout_results = parse_layout_string(layout_results)
 
  text_elements = []
 
  pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
  pil_crop = crop_margin(pil_crop)
 
+ # Generate alt text for accessibility using local Gemma 3n
+ alt_text = gemma_model.generate_alt_text(pil_crop)
 
  buffered = io.BytesIO()
  pil_crop.save(buffered, format="PNG")
 
  return recognition_results
 
 
+ def process_element_batch_optimized(elements, model, prompt, max_batch_size=4):
+ """Process elements in batches for powerful GPU"""
  results = []
  batch_size = min(len(elements), max_batch_size)
 
 
  return markdown_content
 
 
+ # Initialize models
  model_path = "./hf_model"
  if not os.path.exists(model_path):
  model_path = "ByteDance/DOLPHIN"
 
  # Model paths and configuration
  model_path = "./hf_model" if os.path.exists("./hf_model") else "ByteDance/DOLPHIN"
  hf_token = os.getenv('HF_TOKEN')
+ gemma_model_id = "google/gemma-3n-E4B-it"
 
+ # Initialize models
+ print("Loading DOLPHIN model...")
+ dolphin_model = DOLPHIN(model_path)
+ print(f"✅ DOLPHIN model loaded (Device: {dolphin_model.device})")
 
+ print("Loading Gemma 3n model...")
+ gemma_model = Gemma3nModel(gemma_model_id)
+
+ model_status = "✅ Both models loaded successfully"
+
+ # Initialize embedding model
  if RAG_DEPENDENCIES_AVAILABLE:
  try:
  print("Loading embedding model for RAG...")
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
  print("✅ Embedding model loaded successfully (CPU)")
  except Exception as e:
+ print(f"❌ Error loading embedding model: {e}")
  embedding_model = None
  else:
  print("❌ RAG dependencies not available")
  embedding_model = None
 
 
  # Global state for managing tabs
  document_chunks = []
  document_embeddings = None
 
 
  def chunk_document(text, chunk_size=1024, overlap=100):
+ """Split document into overlapping chunks for RAG"""
  words = text.split()
  chunks = []
 
 
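The retrieval helpers called around chunk_document (create_embeddings and retrieve_relevant_chunks) are untouched by this commit and therefore not shown in the diff. For orientation only, a rough sketch of how such helpers could be built from the imports the file already uses (SentenceTransformer plus scikit-learn's cosine_similarity); this is an assumption about the unchanged code, not the actual implementation:

    # Illustrative sketch only - the real implementations live in the unchanged part of app.py
    import numpy as np
    from sentence_transformers import SentenceTransformer
    from sklearn.metrics.pairwise import cosine_similarity

    embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')

    def create_embeddings(chunks):
        # Encode every chunk into a dense vector (all-MiniLM-L6-v2 -> 384 dims per chunk)
        return embedding_model.encode(chunks, convert_to_numpy=True)

    def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
        # Embed the question and rank chunks by cosine similarity, highest first
        query_vec = embedding_model.encode([question], convert_to_numpy=True)
        scores = cosine_similarity(query_vec, embeddings)[0]
        best = np.argsort(scores)[::-1][:top_k]
        return [chunks[i] for i in best]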
  return "❌ No PDF uploaded", gr.Tabs(visible=False)
 
  try:
  # Process PDF
+ progress(0.1, desc="Processing PDF...")
+ combined_markdown, status = process_pdf_document(pdf_file, dolphin_model, progress)
 
  if status == "processing_complete":
  processed_markdown = combined_markdown
 
  document_embeddings = create_embeddings(document_chunks)
  print(f"Created {len(document_chunks)} chunks")
 
  show_results_tab = True
  progress(1.0, desc="PDF processed successfully!")
  return "✅ PDF processed successfully! Chatbot is ready in the Chat tab.", gr.Tabs(visible=True)
 
  document_chunks = []
  document_embeddings = None
 
+ # Clear GPU cache
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
 
  return None, "", gr.Tabs(visible=False)
 
 
  # Create Gradio interface
  with gr.Blocks(
+ title="DOLPHIN PDF AI - Local Gemma 3n",
  theme=gr.themes.Soft(),
  css="""
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
 
  # Home Tab
  with gr.TabItem("🏠 Home", id="home"):
  embedding_status = "✅ RAG ready" if embedding_model else "❌ RAG not loaded"
  gr.Markdown(
+ "# Scholar Express - Local Gemma 3n Version\n"
+ "### Upload a research paper to get a web-friendly version with AI-generated alt text for accessibility. Includes an AI chatbot powered by local Gemma 3n.\n"
  f"**System:** {model_status}\n"
  f"**RAG System:** {embedding_status}\n"
+ f"**DOLPHIN:** Local model for PDF processing\n"
+ f"**Gemma 3n:** Local model for alt text generation and chat\n"
  f"**Alt Text:** Gemma 3n generates descriptive alt text for images\n"
+ f"**GPU:** {'CUDA available' if torch.cuda.is_available() else 'CPU only'}"
  )
 
  with gr.Column(elem_classes="upload-container"):
 
  send_btn = gr.Button("Send", variant="primary", scale=1)
 
  gr.Markdown(
+ "*Ask questions about your processed document. The AI uses RAG (Retrieval-Augmented Generation) with local Gemma 3n to find relevant sections and provide accurate answers.*",
  elem_id="chat-notice"
  )
 
 
  outputs=[chat_tab]
  )
 
+ # Chatbot functionality with local Gemma 3n
  def chatbot_response(message, history):
  if not message.strip():
  return history
 
  return history + [[message, "❌ Please process a PDF document first before asking questions."]]
 
  try:
+ # Use RAG to get relevant chunks from markdown
  if document_chunks and len(document_chunks) > 0:
  relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings, top_k=3)
  context = "\n\n".join(relevant_chunks)
+ # Smart truncation: aim for ~6000 chars for local model
+ if len(context) > 6000:
  # Try to cut at sentence boundaries
+ sentences = context[:6000].split('.')
+ context = '.'.join(sentences[:-1]) + '...' if len(sentences) > 1 else context[:6000] + '...'
  else:
  # Fallback to truncated document if RAG fails
+ context = processed_markdown[:6000] + "..." if len(processed_markdown) > 6000 else processed_markdown
 
+ # Create prompt for Gemma 3n
  prompt = f"""You are a helpful assistant that answers questions about documents. Use the provided context to answer questions accurately and concisely.
 
  Context from the document:
 
  Please provide a clear and helpful answer based on the context provided."""
 
+ # Generate response using local Gemma 3n
+ response_text = gemma_model.chat(prompt)
+ return history + [[message, response_text]]
 
  except Exception as e:
  error_msg = f"❌ Error generating response: {str(e)}"
 
  server_port=7860,
  share=False,
  show_error=True,
+ max_threads=4,
  inbrowser=False,
  quiet=True
  )