Gabriel committed on
Commit
a987d91
·
verified ·
1 Parent(s): d6e55c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -83
app.py CHANGED
@@ -143,16 +143,11 @@ def process_htr(image: Image.Image, document_type: Literal["letter_english", "le
143
  except Exception as pipeline_error:
144
  return {"success": False, "error": f"Pipeline execution failed: {str(pipeline_error)}", "results": None}
145
 
146
- img_buffer = io.BytesIO()
147
- image.save(img_buffer, format="PNG")
148
- image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
149
-
150
  results = extract_text_results(processed_collection, confidence_threshold)
 
151
 
152
  processing_state = {
153
- "processed_collection": processed_collection,
154
- "image_base64": image_base64,
155
- "image_size": image.size,
156
  "document_type": document_type,
157
  "confidence_threshold": confidence_threshold,
158
  "timestamp": datetime.now().isoformat(),
@@ -161,7 +156,7 @@ def process_htr(image: Image.Image, document_type: Literal["letter_english", "le
161
  return {
162
  "success": True,
163
  "results": results,
164
- "processing_state": json.dumps(processing_state, default=str),
165
  "metadata": {
166
  "total_lines": len(results.get("text_lines", [])),
167
  "average_confidence": results.get("average_confidence", 0),
@@ -175,58 +170,44 @@ def process_htr(image: Image.Image, document_type: Literal["letter_english", "le
175
  if os.path.exists(temp_image_path):
176
  os.unlink(temp_image_path)
177
 
178
- def visualize_results(processing_state: str, visualization_type: Literal["overlay", "confidence_heatmap", "text_regions"] = "overlay", show_confidence: bool = True, highlight_low_confidence: bool = True, image: Optional[Image.Image] = None) -> Dict:
179
  """Generate interactive visualizations of HTR processing results."""
180
  try:
 
 
 
181
  state = json.loads(processing_state)
 
182
 
183
- if image is not None:
184
- original_image = image
185
- else:
186
- image_data = base64.b64decode(state["image_base64"])
187
- original_image = Image.open(io.BytesIO(image_data))
188
-
189
- # Recreate the collection from the stored image
190
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
191
- original_image.save(temp_file.name, "PNG")
192
- temp_image_path = temp_file.name
193
 
194
- try:
195
- collection = Collection([temp_image_path])
196
- pipeline = Pipeline.from_config(PIPELINE_CONFIGS[state["document_type"]])
197
- processed_collection = pipeline.run(collection)
198
-
199
- viz_image = create_visualization(original_image, processed_collection, visualization_type, show_confidence, highlight_low_confidence)
200
-
201
- img_buffer = io.BytesIO()
202
- viz_image.save(img_buffer, format="PNG")
203
- img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
204
 
205
- return {
206
- "success": True,
207
- "visualization": {
208
- "image_base64": img_base64,
209
- "image_format": "PNG",
210
- "visualization_type": visualization_type,
211
- "dimensions": viz_image.size,
212
- },
213
- "metadata": {"visualization_type": visualization_type},
214
- }
215
- finally:
216
- if os.path.exists(temp_image_path):
217
- os.unlink(temp_image_path)
218
 
219
  except Exception as e:
220
  return {"success": False, "error": f"Visualization generation failed: {str(e)}", "visualization": None}
221
 
222
- def export_results(processing_state: str, output_formats: List[Literal["txt", "json", "alto", "page"]] = ["txt"], confidence_filter: float = 0.0) -> Dict:
223
  """Export HTR results to multiple formats using HTRflow's native export functionality."""
224
  try:
 
 
 
225
  state = json.loads(processing_state)
226
 
227
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
228
- image_data = base64.b64decode(state["image_base64"])
229
- image = Image.open(io.BytesIO(image_data))
230
  image.save(temp_file.name, "PNG")
231
  temp_image_path = temp_file.name
232
 
@@ -279,19 +260,33 @@ def extract_text_results(collection: Collection, confidence_threshold: float) ->
279
  results = {"extracted_text": "", "text_lines": [], "confidence_scores": []}
280
  for page in collection.pages:
281
  for node in page.traverse():
282
- if hasattr(node, "text") and node.text and hasattr(node, "confidence") and node.confidence >= confidence_threshold:
283
- results["text_lines"].append({
284
- "text": node.text,
285
- "confidence": node.confidence,
286
- "bbox": getattr(node, "bbox", None),
287
- })
288
- results["extracted_text"] += node.text + "\n"
289
- results["confidence_scores"].append(node.confidence)
 
 
290
 
291
  results["average_confidence"] = sum(results["confidence_scores"]) / len(results["confidence_scores"]) if results["confidence_scores"] else 0
292
  return results
293
 
294
- def create_visualization(image, collection, visualization_type, show_confidence, highlight_low_confidence):
 
 
 
 
 
 
 
 
 
 
 
 
295
  viz_image = image.copy()
296
  draw = ImageDraw.Draw(viz_image)
297
 
@@ -300,34 +295,33 @@ def create_visualization(image, collection, visualization_type, show_confidence,
300
  except:
301
  font = ImageFont.load_default()
302
 
303
- for page in collection.pages:
304
- for node in page.traverse():
305
- if hasattr(node, "bbox") and hasattr(node, "text") and node.bbox and node.text:
306
- bbox = node.bbox
307
- confidence = getattr(node, "confidence", 1.0)
308
-
309
- if visualization_type == "overlay":
310
- color = (255, 165, 0) if highlight_low_confidence and confidence < 0.7 else (0, 255, 0)
311
- draw.rectangle(bbox, outline=color, width=2)
312
- if show_confidence:
313
- draw.text((bbox[0], bbox[1] - 15), f"{confidence:.2f}", fill=color, font=font)
314
-
315
- elif visualization_type == "confidence_heatmap":
316
- if confidence < 0.5:
317
- color = (255, 0, 0, 100)
318
- elif confidence < 0.8:
319
- color = (255, 255, 0, 100)
320
- else:
321
- color = (0, 255, 0, 100)
322
- overlay = Image.new("RGBA", viz_image.size, (0, 0, 0, 0))
323
- overlay_draw = ImageDraw.Draw(overlay)
324
- overlay_draw.rectangle(bbox, fill=color)
325
- viz_image = Image.alpha_composite(viz_image.convert("RGBA"), overlay)
326
-
327
- elif visualization_type == "text_regions":
328
- colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
329
- color = colors[hash(str(bbox)) % len(colors)]
330
- draw.rectangle(bbox, outline=color, width=3)
331
 
332
  return viz_image.convert("RGB") if visualization_type == "confidence_heatmap" else viz_image
333
 
@@ -351,10 +345,10 @@ def create_htrflow_mcp_server():
351
  fn=visualize_results,
352
  inputs=[
353
  gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
 
354
  gr.Dropdown(choices=["overlay", "confidence_heatmap", "text_regions"], value="overlay", label="Visualization Type"),
355
  gr.Checkbox(value=True, label="Show Confidence Scores"),
356
  gr.Checkbox(value=True, label="Highlight Low Confidence"),
357
- gr.Image(type="pil", label="Image (optional)"),
358
  ],
359
  outputs=gr.JSON(label="Visualization Results"),
360
  title="Results Visualization Tool",
@@ -365,6 +359,7 @@ def create_htrflow_mcp_server():
365
  fn=export_results,
366
  inputs=[
367
  gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
 
368
  gr.CheckboxGroup(choices=["txt", "json", "alto", "page"], value=["txt"], label="Output Formats"),
369
  gr.Slider(0.0, 1.0, value=0.0, label="Confidence Filter"),
370
  ],
 
143
  except Exception as pipeline_error:
144
  return {"success": False, "error": f"Pipeline execution failed: {str(pipeline_error)}", "results": None}
145
 
 
 
 
 
146
  results = extract_text_results(processed_collection, confidence_threshold)
147
+ collection_data = serialize_collection_data(processed_collection)
148
 
149
  processing_state = {
150
+ "collection_data": collection_data,
 
 
151
  "document_type": document_type,
152
  "confidence_threshold": confidence_threshold,
153
  "timestamp": datetime.now().isoformat(),
 
156
  return {
157
  "success": True,
158
  "results": results,
159
+ "processing_state": json.dumps(processing_state),
160
  "metadata": {
161
  "total_lines": len(results.get("text_lines", [])),
162
  "average_confidence": results.get("average_confidence", 0),
 
170
  if os.path.exists(temp_image_path):
171
  os.unlink(temp_image_path)
172
 
173
+ def visualize_results(processing_state: str, image: Image.Image, visualization_type: Literal["overlay", "confidence_heatmap", "text_regions"] = "overlay", show_confidence: bool = True, highlight_low_confidence: bool = True) -> Dict:
174
  """Generate interactive visualizations of HTR processing results."""
175
  try:
176
+ if image is None:
177
+ return {"success": False, "error": "Image is required for visualization", "visualization": None}
178
+
179
  state = json.loads(processing_state)
180
+ collection_data = state["collection_data"]
181
 
182
+ viz_image = create_visualization(image, collection_data, visualization_type, show_confidence, highlight_low_confidence)
 
 
 
 
 
 
 
 
 
183
 
184
+ img_buffer = io.BytesIO()
185
+ viz_image.save(img_buffer, format="PNG")
186
+ img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
 
 
 
 
 
 
 
187
 
188
+ return {
189
+ "success": True,
190
+ "visualization": {
191
+ "image_base64": img_base64,
192
+ "image_format": "PNG",
193
+ "visualization_type": visualization_type,
194
+ "dimensions": viz_image.size,
195
+ },
196
+ "metadata": {"total_elements": len(collection_data.get("text_elements", []))},
197
+ }
 
 
 
198
 
199
  except Exception as e:
200
  return {"success": False, "error": f"Visualization generation failed: {str(e)}", "visualization": None}
201
 
202
+ def export_results(processing_state: str, image: Image.Image, output_formats: List[Literal["txt", "json", "alto", "page"]] = ["txt"], confidence_filter: float = 0.0) -> Dict:
203
  """Export HTR results to multiple formats using HTRflow's native export functionality."""
204
  try:
205
+ if image is None:
206
+ return {"success": False, "error": "Image is required for export", "exports": None}
207
+
208
  state = json.loads(processing_state)
209
 
210
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
 
 
211
  image.save(temp_file.name, "PNG")
212
  temp_image_path = temp_file.name
213
 
 
260
  results = {"extracted_text": "", "text_lines": [], "confidence_scores": []}
261
  for page in collection.pages:
262
  for node in page.traverse():
263
+ if hasattr(node, "text") and node.text:
264
+ confidence = getattr(node, "confidence", 1.0)
265
+ if confidence >= confidence_threshold:
266
+ results["text_lines"].append({
267
+ "text": node.text,
268
+ "confidence": confidence,
269
+ "bbox": getattr(node, "bbox", None),
270
+ })
271
+ results["extracted_text"] += node.text + "\n"
272
+ results["confidence_scores"].append(confidence)
273
 
274
  results["average_confidence"] = sum(results["confidence_scores"]) / len(results["confidence_scores"]) if results["confidence_scores"] else 0
275
  return results
276
 
277
+ def serialize_collection_data(collection: Collection) -> Dict:
278
+ text_elements = []
279
+ for page in collection.pages:
280
+ for node in page.traverse():
281
+ if hasattr(node, "text") and node.text:
282
+ text_elements.append({
283
+ "text": node.text,
284
+ "confidence": getattr(node, "confidence", 1.0),
285
+ "bbox": getattr(node, "bbox", None),
286
+ })
287
+ return {"text_elements": text_elements}
288
+
289
+ def create_visualization(image, collection_data, visualization_type, show_confidence, highlight_low_confidence):
290
  viz_image = image.copy()
291
  draw = ImageDraw.Draw(viz_image)
292
 
 
295
  except:
296
  font = ImageFont.load_default()
297
 
298
+ for element in collection_data.get("text_elements", []):
299
+ if element.get("bbox"):
300
+ bbox = element["bbox"]
301
+ confidence = element.get("confidence", 1.0)
302
+
303
+ if visualization_type == "overlay":
304
+ color = (255, 165, 0) if highlight_low_confidence and confidence < 0.7 else (0, 255, 0)
305
+ draw.rectangle(bbox, outline=color, width=2)
306
+ if show_confidence:
307
+ draw.text((bbox[0], bbox[1] - 15), f"{confidence:.2f}", fill=color, font=font)
308
+
309
+ elif visualization_type == "confidence_heatmap":
310
+ if confidence < 0.5:
311
+ color = (255, 0, 0, 100)
312
+ elif confidence < 0.8:
313
+ color = (255, 255, 0, 100)
314
+ else:
315
+ color = (0, 255, 0, 100)
316
+ overlay = Image.new("RGBA", viz_image.size, (0, 0, 0, 0))
317
+ overlay_draw = ImageDraw.Draw(overlay)
318
+ overlay_draw.rectangle(bbox, fill=color)
319
+ viz_image = Image.alpha_composite(viz_image.convert("RGBA"), overlay)
320
+
321
+ elif visualization_type == "text_regions":
322
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
323
+ color = colors[hash(str(bbox)) % len(colors)]
324
+ draw.rectangle(bbox, outline=color, width=3)
 
325
 
326
  return viz_image.convert("RGB") if visualization_type == "confidence_heatmap" else viz_image
327
 
 
345
  fn=visualize_results,
346
  inputs=[
347
  gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
348
+ gr.Image(type="pil", label="Image"),
349
  gr.Dropdown(choices=["overlay", "confidence_heatmap", "text_regions"], value="overlay", label="Visualization Type"),
350
  gr.Checkbox(value=True, label="Show Confidence Scores"),
351
  gr.Checkbox(value=True, label="Highlight Low Confidence"),
 
352
  ],
353
  outputs=gr.JSON(label="Visualization Results"),
354
  title="Results Visualization Tool",
 
359
  fn=export_results,
360
  inputs=[
361
  gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
362
+ gr.Image(type="pil", label="Image"),
363
  gr.CheckboxGroup(choices=["txt", "json", "alto", "page"], value=["txt"], label="Output Formats"),
364
  gr.Slider(0.0, 1.0, value=0.0, label="Confidence Filter"),
365
  ],