madankn79 commited on
Commit
7d27dff
·
1 Parent(s): d32fa95
Files changed (3) hide show
  1. Dockerfile +23 -0
  2. app.py +347 -0
  3. requirements.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ # Install dependencies for poppler (used by pdf2image)
4
+ RUN apt-get update && \
5
+ apt-get install -y poppler-utils libglib2.0-0 libsm6 libxext6 libxrender-dev && \
6
+ apt-get clean && \
7
+ rm -rf /var/lib/apt/lists/*
8
+
9
+ # Set working directory
10
+ WORKDIR /app
11
+
12
+ # Copy requirements and install
13
+ COPY requirements.txt .
14
+ RUN pip install --no-cache-dir --upgrade pip && pip install --no-cache-dir -r requirements.txt
15
+
16
+ # Copy app files
17
+ COPY . .
18
+
19
+ # Expose port (default Gradio port is 7860)
20
+ EXPOSE 7860
21
+
22
+ # Run Gradio app
23
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import tempfile
4
+ import time
5
+ import uuid
6
+
7
+ import cv2
8
+ import gradio as gr
9
+ import pymupdf
10
+ import spaces
11
+ import torch
12
+ from loguru import logger
13
+ from PIL import Image
14
+ from transformers import AutoProcessor, VisionEncoderDecoderModel
15
+
16
+ # --- Assumed to be in 'utils/utils.py' ---
17
+ # The following utility functions are required from your original project structure.
18
+ # Ensure you have the 'utils.py' file with these functions.
19
+ # Example placeholder for what these functions might do:
20
+ try:
21
+ from utils.utils import prepare_image, parse_layout_string, process_coordinates
22
+ except ImportError:
23
+ logger.error("Could not import from 'utils.utils'. Please ensure utils.py is in the correct path.")
24
+ # Define dummy functions to allow the script to load, but it will fail at runtime.
25
+ def prepare_image(image): return image, None
26
+ def parse_layout_string(s): return []
27
+ def process_coordinates(bbox, img, dims, prev_box): return 0,0,0,0,0,0,0,0,None
28
+ # -----------------------------------------
29
+
30
+
31
+ # --- Global Variables ---
32
+ model = None
33
+ processor = None
34
+ tokenizer = None
35
+
36
+
37
+ @spaces.GPU
38
+ def initialize_model():
39
+ """Initializes the Hugging Face model and processor."""
40
+ global model, processor, tokenizer
41
+
42
+ if model is None:
43
+ logger.info("Loading DOLPHIN model for PDF to JSON conversion...")
44
+ model_id = "ByteDance/Dolphin"
45
+
46
+ try:
47
+ processor = AutoProcessor.from_pretrained(model_id)
48
+ model = VisionEncoderDecoderModel.from_pretrained(model_id)
49
+
50
+ device = "cuda" if torch.cuda.is_available() else "cpu"
51
+ model.to(device)
52
+ # Use half-precision for better performance if on CUDA
53
+ if device == "cuda":
54
+ model = model.half()
55
+
56
+ model.eval()
57
+ tokenizer = processor.tokenizer
58
+ logger.info(f"Model loaded successfully on {device}")
59
+ except Exception as e:
60
+ logger.error(f"Fatal error during model initialization: {e}")
61
+ raise
62
+
63
+
64
+ @spaces.GPU
65
+ def model_inference(prompt, image):
66
+ """
67
+ Performs inference using the Dolphin model. Handles both single and batch processing.
68
+ """
69
+ global model, processor, tokenizer
70
+
71
+ if model is None:
72
+ logger.warning("Model not initialized. Initializing now...")
73
+ initialize_model()
74
+
75
+ is_batch = isinstance(image, list)
76
+ images = image if is_batch else [image]
77
+ prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
78
+
79
+ device = model.device
80
+
81
+ # Prepare image tensors
82
+ batch_inputs = processor(images, return_tensors="pt", padding=True)
83
+ pixel_values_dtype = torch.float16 if device == "cuda" else torch.float32
84
+ batch_pixel_values = batch_inputs.pixel_values.to(device, dtype=pixel_values_dtype)
85
+
86
+ # Prepare prompt tensors
87
+ prompts_with_task = [f"<s>{p} <Answer/>" for p in prompts]
88
+ batch_prompt_inputs = tokenizer(
89
+ prompts_with_task, add_special_tokens=False, return_tensors="pt"
90
+ )
91
+ batch_prompt_ids = batch_prompt_inputs.input_ids.to(device)
92
+ batch_attention_mask = batch_prompt_inputs.attention_mask.to(device)
93
+
94
+ # Generate text sequences
95
+ outputs = model.generate(
96
+ pixel_values=batch_pixel_values,
97
+ decoder_input_ids=batch_prompt_ids,
98
+ decoder_attention_mask=batch_attention_mask,
99
+ max_length=4096,
100
+ pad_token_id=tokenizer.pad_token_id,
101
+ eos_token_id=tokenizer.eos_token_id,
102
+ use_cache=True,
103
+ bad_words_ids=[[tokenizer.unk_token_id]],
104
+ return_dict_in_generate=True,
105
+ )
106
+
107
+ # Decode and clean up the output
108
+ sequences = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
109
+ results = [
110
+ seq.replace(prompts_with_task[i], "").replace("<pad>", "").replace("</s>", "").strip()
111
+ for i, seq in enumerate(sequences)
112
+ ]
113
+
114
+ return results[0] if not is_batch else results
115
+
116
+
117
+ @spaces.GPU
118
+ def process_element_batch(elements, prompt, max_batch_size=16):
119
+ """Processes a batch of elements of the same type (e.g., text or tables)."""
120
+ results = []
121
+ for i in range(0, len(elements), max_batch_size):
122
+ batch_elements = elements[i:i + max_batch_size]
123
+ crops_list = [elem["crop"] for elem in batch_elements]
124
+ prompts_list = [prompt] * len(crops_list)
125
+
126
+ batch_results = model_inference(prompts_list, crops_list)
127
+
128
+ for j, result in enumerate(batch_results):
129
+ elem = batch_elements[j]
130
+ results.append({
131
+ "label": elem["label"],
132
+ "bbox": elem["bbox"],
133
+ "text": result.strip(),
134
+ "reading_order": elem["reading_order"],
135
+ })
136
+ return results
137
+
138
+
139
+ def convert_all_pdf_pages_to_images(file_path, target_size=896):
140
+ """Converts all pages of a PDF file to a list of image file paths."""
141
+ if not file_path or not file_path.lower().endswith('.pdf'):
142
+ logger.warning("Not a PDF file. No pages to convert.")
143
+ return []
144
+
145
+ image_paths = []
146
+ try:
147
+ doc = pymupdf.open(file_path)
148
+ for page_num in range(len(doc)):
149
+ page = doc[page_num]
150
+ scale = target_size / max(page.rect.width, page.rect.height)
151
+ mat = pymupdf.Matrix(scale, scale)
152
+ pix = page.get_pixmap(matrix=mat)
153
+
154
+ img_data = pix.tobytes("png")
155
+ pil_image = Image.open(io.BytesIO(img_data))
156
+
157
+ # Use a unique filename for each temporary page image
158
+ with tempfile.NamedTemporaryFile(suffix=f"_page_{page_num+1}.png", delete=False) as tmp_file:
159
+ pil_image.save(tmp_file.name, "PNG")
160
+ image_paths.append(tmp_file.name)
161
+ doc.close()
162
+ except Exception as e:
163
+ logger.error(f"Error converting PDF pages to images: {e}")
164
+ # Clean up any files that were created before the error
165
+ for path in image_paths:
166
+ cleanup_temp_file(path)
167
+ return []
168
+
169
+ return image_paths
170
+
171
+
172
+ def process_elements(layout_results, padded_image, dims):
173
+ """Crops and recognizes content for all document elements found in the layout."""
174
+ layout_results = parse_layout_string(layout_results)
175
+ text_elements, table_elements, figure_results = [], [], []
176
+ reading_order = 0
177
+ previous_box = None
178
+
179
+ for bbox, label in layout_results:
180
+ try:
181
+ x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
182
+ bbox, padded_image, dims, previous_box
183
+ )
184
+ cropped = padded_image[y1:y2, x1:x2]
185
+
186
+ if cropped.size > 0 and (cropped.shape[0] > 3 and cropped.shape[1] > 3):
187
+ pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
188
+ element_info = {
189
+ "crop": pil_crop, "label": label,
190
+ "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
191
+ "reading_order": reading_order,
192
+ }
193
+ if label == "tab":
194
+ table_elements.append(element_info)
195
+ elif label == "fig":
196
+ figure_results.append({**element_info, "text": "[FIGURE]"}) # Placeholder for figures
197
+ else:
198
+ text_elements.append(element_info)
199
+ reading_order += 1
200
+ except Exception as e:
201
+ logger.error(f"Error processing element with label {label}: {str(e)}")
202
+ continue
203
+
204
+ recognition_results = figure_results.copy()
205
+ if text_elements:
206
+ recognition_results.extend(process_element_batch(text_elements, "Read text in the image."))
207
+ if table_elements:
208
+ recognition_results.extend(process_element_batch(table_elements, "Parse the table in the image."))
209
+
210
+ recognition_results.sort(key=lambda x: x.get("reading_order", 0))
211
+ # Remove the temporary 'crop' key before returning JSON
212
+ for res in recognition_results:
213
+ res.pop('crop', None)
214
+
215
+ return recognition_results
216
+
217
+
218
+ def process_page(image_path):
219
+ """Processes a single page image to extract all content and return structured data."""
220
+ pil_image = Image.open(image_path).convert("RGB")
221
+
222
+ # 1. Get layout and reading order
223
+ layout_output = model_inference("Parse the reading order of this document.", pil_image)
224
+
225
+ # 2. Extract content from each element
226
+ padded_image, dims = prepare_image(pil_image)
227
+ recognition_results = process_elements(layout_output, padded_image, dims)
228
+
229
+ return recognition_results
230
+
231
+
232
+ def cleanup_temp_file(file_path):
233
+ """Safely deletes a temporary file if it exists."""
234
+ try:
235
+ if file_path and os.path.exists(file_path):
236
+ os.unlink(file_path)
237
+ except Exception as e:
238
+ logger.warning(f"Failed to cleanup temp file {file_path}: {e}")
239
+
240
+
241
+ @spaces.GPU(duration=120)
242
+ def pdf_to_json_converter(pdf_file):
243
+ """
244
+ Main function for the Gradio interface. Takes a PDF file, processes all pages,
245
+ and returns the structured data as a JSON object.
246
+ """
247
+ if pdf_file is None:
248
+ return {"error": "No file uploaded. Please upload a PDF file."}
249
+
250
+ start_time = time.time()
251
+ file_path = pdf_file.name
252
+ temp_files_created = []
253
+
254
+ try:
255
+ logger.info(f"Starting processing for document: {os.path.basename(file_path)}")
256
+
257
+ # Convert all PDF pages to images
258
+ image_paths = convert_all_pdf_pages_to_images(file_path)
259
+ if not image_paths:
260
+ raise Exception("Failed to convert PDF to images. The file might be corrupted or not a valid PDF.")
261
+ temp_files_created.extend(image_paths)
262
+
263
+ all_pages_data = []
264
+ # Process each page sequentially
265
+ for page_idx, image_path in enumerate(image_paths):
266
+ logger.info(f"Processing page {page_idx + 1}/{len(image_paths)}")
267
+ page_elements = process_page(image_path)
268
+ all_pages_data.append({
269
+ "page": page_idx + 1,
270
+ "elements": page_elements,
271
+ })
272
+
273
+ processing_time = time.time() - start_time
274
+ logger.info(f"Document processed successfully in {processing_time:.2f}s")
275
+
276
+ # Final JSON output structure
277
+ final_json = {
278
+ "document_info": {
279
+ "file_name": os.path.basename(file_path),
280
+ "total_pages": len(image_paths),
281
+ "processing_time_seconds": round(processing_time, 2),
282
+ },
283
+ "pages": all_pages_data
284
+ }
285
+ return final_json
286
+
287
+ except Exception as e:
288
+ logger.error(f"An error occurred during document processing: {str(e)}")
289
+ return {"error": str(e), "file_name": os.path.basename(file_path)}
290
+
291
+ finally:
292
+ # Cleanup all temporary image files created during processing
293
+ logger.info("Cleaning up temporary files...")
294
+ for temp_file in temp_files_created:
295
+ cleanup_temp_file(temp_file)
296
+
297
+
298
+ # --- Gradio UI ---
299
+ def build_gradio_interface():
300
+ """Builds and returns the simple Gradio UI."""
301
+ with gr.Blocks(title="PDF to JSON Converter") as demo:
302
+ gr.Markdown(
303
+ """
304
+ # PDF to JSON Converter
305
+ Upload a multi-page PDF to extract its content into a structured JSON format using the Dolphin model.
306
+ """
307
+ )
308
+
309
+ with gr.Row():
310
+ with gr.Column(scale=1):
311
+ pdf_input = gr.File(
312
+ label="Upload PDF File",
313
+ file_types=[".pdf"],
314
+ )
315
+ submit_btn = gr.Button("Convert to JSON", variant="primary")
316
+
317
+ with gr.Column(scale=2):
318
+ json_output = gr.JSON(label="JSON Output", scale=2)
319
+
320
+ submit_btn.click(
321
+ fn=pdf_to_json_converter,
322
+ inputs=[pdf_input],
323
+ outputs=[json_output],
324
+ )
325
+
326
+ # Add a clear button for convenience
327
+ clear_btn = gr.ClearButton(
328
+ value="Clear",
329
+ components=[pdf_input, json_output]
330
+ )
331
+
332
+ return demo
333
+
334
+
335
+ # --- Main Execution ---
336
+ if __name__ == "__main__":
337
+ logger.info("Starting Gradio application...")
338
+ try:
339
+ # Initialize the model on startup to avoid delays on the first request
340
+ initialize_model()
341
+
342
+ # Build and launch the Gradio interface
343
+ app_ui = build_gradio_interface()
344
+ app_ui.launch()
345
+
346
+ except Exception as main_exception:
347
+ logger.error(f"Failed to start the application: {main_exception}")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.6.0
2
+ transformers
3
+ pdf2image
4
+ Pillow
5
+ gradio
6
+ sentencepiece