Gabriel committed on
Commit
f094617
·
verified ·
1 Parent(s): 1ec4316

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -606
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
- import yaml
3
  import json
4
  import base64
5
  import tempfile
@@ -9,7 +8,8 @@ from datetime import datetime
9
  from PIL import Image, ImageDraw, ImageFont
10
  import io
11
  import spaces
12
-
 
13
  from htrflow.volume.volume import Collection
14
  from htrflow.pipeline.pipeline import Pipeline
15
 
@@ -20,9 +20,7 @@ PIPELINE_CONFIGS = {
20
  "step": "Segmentation",
21
  "settings": {
22
  "model": "yolo",
23
- "model_settings": {
24
- "model": "Riksarkivet/yolov9-lines-within-regions-1"
25
- },
26
  "generation_settings": {"batch_size": 8},
27
  },
28
  },
@@ -43,9 +41,7 @@ PIPELINE_CONFIGS = {
43
  "step": "Segmentation",
44
  "settings": {
45
  "model": "yolo",
46
- "model_settings": {
47
- "model": "Riksarkivet/yolov9-lines-within-regions-1"
48
- },
49
  "generation_settings": {"batch_size": 8},
50
  },
51
  },
@@ -53,9 +49,7 @@ PIPELINE_CONFIGS = {
53
  "step": "TextRecognition",
54
  "settings": {
55
  "model": "TrOCR",
56
- "model_settings": {
57
- "model": "Riksarkivet/trocr-base-handwritten-hist-swe-2"
58
- },
59
  "generation_settings": {"batch_size": 16},
60
  },
61
  },
@@ -76,9 +70,7 @@ PIPELINE_CONFIGS = {
76
  "step": "Segmentation",
77
  "settings": {
78
  "model": "yolo",
79
- "model_settings": {
80
- "model": "Riksarkivet/yolov9-lines-within-regions-1"
81
- },
82
  "generation_settings": {"batch_size": 8},
83
  },
84
  },
@@ -107,9 +99,7 @@ PIPELINE_CONFIGS = {
107
  "step": "Segmentation",
108
  "settings": {
109
  "model": "yolo",
110
- "model_settings": {
111
- "model": "Riksarkivet/yolov9-lines-within-regions-1"
112
- },
113
  "generation_settings": {"batch_size": 8},
114
  },
115
  },
@@ -117,9 +107,7 @@ PIPELINE_CONFIGS = {
117
  "step": "TextRecognition",
118
  "settings": {
119
  "model": "TrOCR",
120
- "model_settings": {
121
- "model": "Riksarkivet/trocr-base-handwritten-hist-swe-2"
122
- },
123
  "generation_settings": {"batch_size": 16},
124
  },
125
  },
@@ -129,30 +117,8 @@ PIPELINE_CONFIGS = {
129
  }
130
 
131
  @spaces.GPU
132
- def process_htr(
133
- image: Image.Image,
134
- document_type: Literal[
135
- "letter_english", "letter_swedish", "spread_english", "spread_swedish"
136
- ] = "spread_swedish",
137
- confidence_threshold: float = 0.8,
138
- custom_settings: Optional[str] = None,
139
- ) -> Dict:
140
- """
141
- Process handwritten text recognition on uploaded images using HTRflow pipelines.
142
-
143
- Supports templates for different document types (letters vs spreads) and
144
- languages (English vs Swedish). Uses HTRflow's modular pipeline system with
145
- configurable segmentation and text recognition models.
146
-
147
- Args:
148
- image (Image.Image): PIL Image object to process
149
- document_type (str): Type of document processing template to use
150
- confidence_threshold (float): Minimum confidence threshold for text recognition
151
- custom_settings (str, optional): JSON string with custom pipeline settings
152
-
153
- Returns:
154
- dict: Processing results including extracted text, metadata, and processing state
155
- """
156
  try:
157
  if image is None:
158
  return {"success": False, "error": "No image provided", "results": None}
@@ -166,30 +132,22 @@ def process_htr(
166
  try:
167
  config = json.loads(custom_settings)
168
  except json.JSONDecodeError:
169
- return {
170
- "success": False,
171
- "error": "Invalid JSON in custom_settings parameter",
172
- "results": None,
173
- }
174
  else:
175
  config = PIPELINE_CONFIGS[document_type]
176
 
177
  collection = Collection([temp_image_path])
178
-
179
  pipeline = Pipeline.from_config(config)
180
  processed_collection = pipeline.run(collection)
181
 
182
- results = extract_processing_results(
183
- processed_collection, confidence_threshold
184
- )
185
-
186
  img_buffer = io.BytesIO()
187
  image.save(img_buffer, format="PNG")
188
  image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
189
 
 
 
190
  processing_state = {
191
- "collection": serialize_collection(processed_collection),
192
- "config": config,
193
  "image_base64": image_base64,
194
  "image_size": image.size,
195
  "document_type": document_type,
@@ -203,54 +161,22 @@ def process_htr(
203
  "processing_state": json.dumps(processing_state),
204
  "metadata": {
205
  "total_lines": len(results.get("text_lines", [])),
206
- "average_confidence": calculate_average_confidence(results),
207
  "document_type": document_type,
208
  "image_dimensions": image.size,
209
  },
210
  }
211
-
212
  finally:
213
  if os.path.exists(temp_image_path):
214
  os.unlink(temp_image_path)
215
-
216
  except Exception as e:
217
- return {
218
- "success": False,
219
- "error": f"HTR processing failed: {str(e)}",
220
- "results": None,
221
- }
222
 
223
-
224
- def visualize_results(
225
- processing_state: str,
226
- visualization_type: Literal[
227
- "overlay", "confidence_heatmap", "text_regions"
228
- ] = "overlay",
229
- show_confidence: bool = True,
230
- highlight_low_confidence: bool = True,
231
- image: Optional[Image.Image] = None,
232
- ) -> Dict:
233
- """
234
- Generate interactive visualizations of HTR processing results.
235
-
236
- Creates visual representations of text recognition results including bounding box
237
- overlays, confidence heatmaps, and region segmentation displays. Supports multiple
238
- visualization modes for different analysis needs.
239
-
240
- Args:
241
- processing_state (str): JSON string containing HTR processing results and metadata
242
- visualization_type (str): Type of visualization to generate
243
- show_confidence (bool): Whether to display confidence scores on visualization
244
- highlight_low_confidence (bool): Whether to highlight low-confidence regions
245
- image (Image.Image, optional): PIL Image object to use instead of state image
246
-
247
- Returns:
248
- dict: Visualization data including base64-encoded images and metadata
249
- """
250
  try:
251
  state = json.loads(processing_state)
252
- collection = deserialize_collection(state["collection"])
253
- confidence_threshold = state["confidence_threshold"]
254
 
255
  if image is not None:
256
  original_image = image
@@ -258,23 +184,12 @@ def visualize_results(
258
  image_data = base64.b64decode(state["image_base64"])
259
  original_image = Image.open(io.BytesIO(image_data))
260
 
261
- if visualization_type == "overlay":
262
- viz_image = create_text_overlay_visualization(
263
- original_image, collection, show_confidence, highlight_low_confidence
264
- )
265
- elif visualization_type == "confidence_heatmap":
266
- viz_image = create_confidence_heatmap(
267
- original_image, collection, confidence_threshold
268
- )
269
- elif visualization_type == "text_regions":
270
- viz_image = create_region_visualization(original_image, collection)
271
 
272
  img_buffer = io.BytesIO()
273
  viz_image.save(img_buffer, format="PNG")
274
  img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
275
 
276
- viz_metadata = generate_visualization_metadata(collection, visualization_type)
277
-
278
  return {
279
  "success": True,
280
  "visualization": {
@@ -283,521 +198,139 @@ def visualize_results(
283
  "visualization_type": visualization_type,
284
  "dimensions": viz_image.size,
285
  },
286
- "metadata": viz_metadata,
287
- "interactive_elements": extract_interactive_elements(collection),
288
  }
289
-
290
  except Exception as e:
291
- return {
292
- "success": False,
293
- "error": f"Visualization generation failed: {str(e)}",
294
- "visualization": None,
295
- }
296
-
297
-
298
- def export_results(
299
- processing_state: str,
300
- output_formats: List[Literal["txt", "json", "alto", "page"]] = ["txt"],
301
- include_metadata: bool = True,
302
- confidence_filter: float = 0.0,
303
- ) -> Dict:
304
- """
305
- Export HTR results to multiple formats including plain text, structured JSON, ALTO XML, and PAGE XML.
306
 
307
- Supports HTRflow's native export functionality with configurable output formats and
308
- filtering options. Maintains document structure and metadata across all export formats.
309
-
310
- Args:
311
- processing_state (str): JSON string containing HTR processing results
312
- output_formats (List[str]): List of output formats to generate
313
- include_metadata (bool): Whether to include processing metadata in exports
314
- confidence_filter (float): Minimum confidence threshold for included text
315
-
316
- Returns:
317
- dict: Export results with content for each requested format
318
- """
319
  try:
320
- # Parse processing state
321
  state = json.loads(processing_state)
322
- collection = deserialize_collection(state["collection"])
323
- config = state["config"]
324
-
325
- # Generate exports for each requested format
326
- exports = {}
327
-
328
- for format_type in output_formats:
329
- if format_type == "txt":
330
- exports["txt"] = export_plain_text(
331
- collection, confidence_filter, include_metadata
332
- )
333
- elif format_type == "json":
334
- exports["json"] = export_structured_json(
335
- collection, confidence_filter, include_metadata
336
- )
337
- elif format_type == "alto":
338
- exports["alto"] = export_alto_xml(
339
- collection, confidence_filter, include_metadata
340
- )
341
- elif format_type == "page":
342
- exports["page"] = export_page_xml(
343
- collection, confidence_filter, include_metadata
344
- )
345
 
346
- # Calculate export statistics
347
- export_stats = calculate_export_statistics(collection, confidence_filter)
 
 
348
 
349
- return {
350
- "success": True,
351
- "exports": exports,
352
- "statistics": export_stats,
353
- "export_metadata": {
354
- "formats_generated": output_formats,
355
- "confidence_filter": confidence_filter,
356
- "include_metadata": include_metadata,
357
- "timestamp": datetime.now().isoformat(),
358
- },
359
- }
 
 
 
 
 
 
 
360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  except Exception as e:
362
- return {
363
- "success": False,
364
- "error": f"Export generation failed: {str(e)}",
365
- "exports": None,
366
- }
367
-
368
-
369
- # Helper Functions
370
- def extract_processing_results(
371
- collection: Collection, confidence_threshold: float
372
- ) -> Dict:
373
- """Extract structured results from processed HTRflow Collection."""
374
- results = {
375
- "extracted_text": "",
376
- "text_lines": [],
377
- "regions": [],
378
- "confidence_scores": [],
379
- }
380
 
381
- # Traverse collection hierarchy to extract text and metadata
 
382
  for page in collection.pages:
383
  for node in page.traverse():
384
- if hasattr(node, "text") and node.text:
385
- if (
386
- hasattr(node, "confidence")
387
- and node.confidence >= confidence_threshold
388
- ):
389
- results["text_lines"].append(
390
- {
391
- "text": node.text,
392
- "confidence": node.confidence,
393
- "bbox": getattr(node, "bbox", None),
394
- "node_id": getattr(node, "id", None),
395
- }
396
- )
397
- results["extracted_text"] += node.text + "\n"
398
- results["confidence_scores"].append(node.confidence)
399
-
400
  return results
401
 
402
-
403
- def serialize_collection(collection: Collection) -> str:
404
- """Serialize HTRflow Collection to JSON string for state storage."""
405
- serialized_data = {"pages": [], "metadata": getattr(collection, "metadata", {})}
406
-
407
  for page in collection.pages:
408
- page_data = {
409
- "nodes": [],
410
- "image_path": getattr(page, "image_path", None),
411
- "dimensions": getattr(page, "dimensions", None),
412
- }
413
-
414
  for node in page.traverse():
415
- node_data = {
416
- "text": getattr(node, "text", ""),
417
- "confidence": getattr(node, "confidence", 1.0),
418
- "bbox": getattr(node, "bbox", None),
419
- "node_id": getattr(node, "id", None),
420
- "node_type": type(node).__name__,
421
- }
422
- page_data["nodes"].append(node_data)
423
-
424
- serialized_data["pages"].append(page_data)
425
-
426
- return json.dumps(serialized_data)
427
-
428
-
429
- def deserialize_collection(serialized_data: str):
430
- """Deserialize JSON string back to HTRflow Collection."""
431
- data = json.loads(serialized_data)
432
-
433
- # Mock collection classes for state reconstruction
434
- class MockCollection:
435
- def __init__(self, data):
436
- self.pages = []
437
- for page_data in data.get("pages", []):
438
- page = MockPage(page_data)
439
- self.pages.append(page)
440
-
441
- class MockPage:
442
- def __init__(self, page_data):
443
- self.nodes = []
444
- for node_data in page_data.get("nodes", []):
445
- node = MockNode(node_data)
446
- self.nodes.append(node)
447
-
448
- def traverse(self):
449
- return self.nodes
450
-
451
- class MockNode:
452
- def __init__(self, node_data):
453
- self.text = node_data.get("text", "")
454
- self.confidence = node_data.get("confidence", 1.0)
455
- self.bbox = node_data.get("bbox")
456
- self.id = node_data.get("node_id")
457
-
458
- return MockCollection(data)
459
-
460
-
461
- def calculate_average_confidence(results: Dict) -> float:
462
- """Calculate average confidence score from processing results."""
463
- confidence_scores = results.get("confidence_scores", [])
464
- if not confidence_scores:
465
- return 0.0
466
- return sum(confidence_scores) / len(confidence_scores)
467
-
468
-
469
- def create_text_overlay_visualization(
470
- image, collection, show_confidence, highlight_low_confidence
471
- ):
472
- """Create image with text bounding boxes and recognition results overlaid."""
473
  viz_image = image.copy()
474
  draw = ImageDraw.Draw(viz_image)
475
-
476
- # Define visualization styles
477
- bbox_color = (0, 255, 0) # Green for normal confidence
478
- low_conf_color = (255, 165, 0) # Orange for low confidence
479
- text_color = (255, 255, 255) # White text
480
-
481
  try:
482
  font = ImageFont.truetype("arial.ttf", 12)
483
  except:
484
  font = ImageFont.load_default()
485
 
486
- # Draw bounding boxes and text for each recognized element
487
- for page in collection.pages:
488
- for node in page.traverse():
489
- if (
490
- hasattr(node, "bbox")
491
- and hasattr(node, "text")
492
- and node.bbox
493
- and node.text
494
- ):
495
- bbox = node.bbox
496
- confidence = getattr(node, "confidence", 1.0)
497
-
498
- # Choose color based on confidence
499
- if highlight_low_confidence and confidence < 0.7:
500
- color = low_conf_color
501
- else:
502
- color = bbox_color
503
-
504
- # Draw bounding box
505
  draw.rectangle(bbox, outline=color, width=2)
506
-
507
- # Add confidence score if requested
508
  if show_confidence:
509
- conf_text = f"{confidence:.2f}"
510
- draw.text((bbox[0], bbox[1] - 15), conf_text, fill=color, font=font)
511
-
512
- return viz_image
513
-
514
-
515
- def create_confidence_heatmap(image, collection, confidence_threshold):
516
- """Create confidence heatmap visualization."""
517
- viz_image = image.copy()
518
-
519
- # Create heatmap overlay based on confidence scores
520
- for page in collection.pages:
521
- for node in page.traverse():
522
- if hasattr(node, "bbox") and hasattr(node, "confidence") and node.bbox:
523
- confidence = node.confidence
524
- # Color mapping: red (low) -> yellow (medium) -> green (high)
525
  if confidence < 0.5:
526
- color = (255, 0, 0, 100) # Red with transparency
527
  elif confidence < 0.8:
528
- color = (255, 255, 0, 100) # Yellow with transparency
529
  else:
530
- color = (0, 255, 0, 100) # Green with transparency
531
-
532
- # Create overlay image for transparency
533
  overlay = Image.new("RGBA", viz_image.size, (0, 0, 0, 0))
534
  overlay_draw = ImageDraw.Draw(overlay)
535
- overlay_draw.rectangle(node.bbox, fill=color)
536
  viz_image = Image.alpha_composite(viz_image.convert("RGBA"), overlay)
 
 
 
 
 
537
 
538
- return viz_image.convert("RGB")
539
-
540
-
541
- def create_region_visualization(image, collection):
542
- """Create region segmentation visualization."""
543
- viz_image = image.copy()
544
- draw = ImageDraw.Draw(viz_image)
545
-
546
- # Draw different colors for different region types
547
- region_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
548
- region_count = 0
549
-
550
- for page in collection.pages:
551
- for node in page.traverse():
552
- if hasattr(node, "bbox") and node.bbox:
553
- color = region_colors[region_count % len(region_colors)]
554
- draw.rectangle(node.bbox, outline=color, width=3)
555
- region_count += 1
556
-
557
- return viz_image
558
-
559
-
560
- def generate_visualization_metadata(collection, visualization_type):
561
- """Generate metadata for visualization results."""
562
- total_elements = 0
563
- confidence_stats = []
564
-
565
- for page in collection.pages:
566
- for node in page.traverse():
567
- if hasattr(node, "text") and node.text:
568
- total_elements += 1
569
- if hasattr(node, "confidence"):
570
- confidence_stats.append(node.confidence)
571
-
572
- return {
573
- "total_elements": total_elements,
574
- "visualization_type": visualization_type,
575
- "confidence_stats": {
576
- "min": min(confidence_stats) if confidence_stats else 0,
577
- "max": max(confidence_stats) if confidence_stats else 0,
578
- "avg": sum(confidence_stats) / len(confidence_stats)
579
- if confidence_stats
580
- else 0,
581
- },
582
- }
583
-
584
-
585
- def extract_interactive_elements(collection):
586
- """Extract interactive elements for visualization."""
587
- elements = []
588
-
589
- for page in collection.pages:
590
- for node in page.traverse():
591
- if (
592
- hasattr(node, "bbox")
593
- and hasattr(node, "text")
594
- and node.bbox
595
- and node.text
596
- ):
597
- elements.append(
598
- {
599
- "bbox": node.bbox,
600
- "text": node.text,
601
- "confidence": getattr(node, "confidence", 1.0),
602
- "node_id": getattr(node, "id", None),
603
- }
604
- )
605
-
606
- return elements
607
-
608
-
609
- def export_plain_text(
610
- collection, confidence_filter: float, include_metadata: bool
611
- ) -> str:
612
- """Export recognition results as plain text."""
613
- text_lines = []
614
-
615
- if include_metadata:
616
- text_lines.append(f"# HTR Export Results")
617
- text_lines.append(f"# Confidence Filter: {confidence_filter}")
618
- text_lines.append(f"# Export Time: {datetime.now().isoformat()}")
619
- text_lines.append("")
620
-
621
- # Extract text from collection hierarchy
622
- for page in collection.pages:
623
- for node in page.traverse():
624
- if hasattr(node, "text") and node.text:
625
- confidence = getattr(node, "confidence", 1.0)
626
- if confidence >= confidence_filter:
627
- text_lines.append(node.text)
628
-
629
- return "\n".join(text_lines)
630
-
631
 
632
- def export_structured_json(
633
- collection, confidence_filter: float, include_metadata: bool
634
- ) -> str:
635
- """Export results as structured JSON with full hierarchy."""
636
- result = {"document": {"pages": []}}
637
-
638
- if include_metadata:
639
- result["metadata"] = {
640
- "confidence_filter": confidence_filter,
641
- "export_time": datetime.now().isoformat(),
642
- "total_pages": len(collection.pages),
643
- }
644
-
645
- # Build hierarchical structure
646
- for page_idx, page in enumerate(collection.pages):
647
- page_data = {"page_id": page_idx, "regions": []}
648
-
649
- for node in page.traverse():
650
- if hasattr(node, "text") and node.text:
651
- confidence = getattr(node, "confidence", 1.0)
652
- if confidence >= confidence_filter:
653
- node_data = {
654
- "text": node.text,
655
- "confidence": confidence,
656
- "bbox": getattr(node, "bbox", None),
657
- "node_id": getattr(node, "id", None),
658
- }
659
- page_data["regions"].append(node_data)
660
-
661
- result["document"]["pages"].append(page_data)
662
-
663
- return json.dumps(result, indent=2, ensure_ascii=False)
664
-
665
-
666
- def export_alto_xml(
667
- collection, confidence_filter: float, include_metadata: bool
668
- ) -> str:
669
- """Export results as ALTO XML format."""
670
- # Simplified ALTO XML generation
671
- xml_lines = ['<?xml version="1.0" encoding="UTF-8"?>']
672
- xml_lines.append('<alto xmlns="http://www.loc.gov/standards/alto/ns-v4#">')
673
- xml_lines.append(" <Description>")
674
- if include_metadata:
675
- xml_lines.append(f" <sourceImageInformation>")
676
- xml_lines.append(f" <fileName>htr_processed_image</fileName>")
677
- xml_lines.append(f" </sourceImageInformation>")
678
- xml_lines.append(" </Description>")
679
- xml_lines.append(" <Layout>")
680
- xml_lines.append(" <Page>")
681
-
682
- for page in collection.pages:
683
- for node in page.traverse():
684
- if hasattr(node, "text") and node.text:
685
- confidence = getattr(node, "confidence", 1.0)
686
- if confidence >= confidence_filter:
687
- bbox = getattr(node, "bbox", [0, 0, 100, 20])
688
- xml_lines.append(
689
- f' <TextLine HPOS="{bbox[0]}" VPOS="{bbox[1]}" WIDTH="{bbox[2] - bbox[0]}" HEIGHT="{bbox[3] - bbox[1]}">'
690
- )
691
- xml_lines.append(
692
- f' <String CONTENT="{node.text}" WC="{confidence:.3f}"/>'
693
- )
694
- xml_lines.append(" </TextLine>")
695
-
696
- xml_lines.append(" </Page>")
697
- xml_lines.append(" </Layout>")
698
- xml_lines.append("</alto>")
699
-
700
- return "\n".join(xml_lines)
701
-
702
-
703
- def export_page_xml(
704
- collection, confidence_filter: float, include_metadata: bool
705
- ) -> str:
706
- """Export results as PAGE XML format."""
707
- # Simplified PAGE XML generation
708
- xml_lines = ['<?xml version="1.0" encoding="UTF-8"?>']
709
- xml_lines.append(
710
- '<PcGts xmlns="http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15">'
711
- )
712
- if include_metadata:
713
- xml_lines.append(" <Metadata>")
714
- xml_lines.append(f" <Created>{datetime.now().isoformat()}</Created>")
715
- xml_lines.append(" </Metadata>")
716
- xml_lines.append(" <Page>")
717
-
718
- for page in collection.pages:
719
- for node in page.traverse():
720
- if hasattr(node, "text") and node.text:
721
- confidence = getattr(node, "confidence", 1.0)
722
- if confidence >= confidence_filter:
723
- bbox = getattr(node, "bbox", [0, 0, 100, 20])
724
- xml_lines.append(f" <TextRegion>")
725
- xml_lines.append(
726
- f' <Coords points="{bbox[0]},{bbox[1]} {bbox[2]},{bbox[1]} {bbox[2]},{bbox[3]} {bbox[0]},{bbox[3]}"/>'
727
- )
728
- xml_lines.append(f" <TextLine>")
729
- xml_lines.append(f' <TextEquiv conf="{confidence:.3f}">')
730
- xml_lines.append(f" <Unicode>{node.text}</Unicode>")
731
- xml_lines.append(" </TextEquiv>")
732
- xml_lines.append(" </TextLine>")
733
- xml_lines.append(" </TextRegion>")
734
-
735
- xml_lines.append(" </Page>")
736
- xml_lines.append("</PcGts>")
737
-
738
- return "\n".join(xml_lines)
739
-
740
-
741
- def calculate_export_statistics(collection, confidence_filter: float) -> Dict:
742
- """Calculate statistics for export results."""
743
- total_text_elements = 0
744
- filtered_text_elements = 0
745
- confidence_scores = []
746
- total_characters = 0
747
-
748
- for page in collection.pages:
749
- for node in page.traverse():
750
- if hasattr(node, "text") and node.text:
751
- total_text_elements += 1
752
- confidence = getattr(node, "confidence", 1.0)
753
- confidence_scores.append(confidence)
754
-
755
- if confidence >= confidence_filter:
756
- filtered_text_elements += 1
757
- total_characters += len(node.text)
758
-
759
- return {
760
- "total_text_elements": total_text_elements,
761
- "filtered_text_elements": filtered_text_elements,
762
- "filter_retention_rate": filtered_text_elements / total_text_elements
763
- if total_text_elements > 0
764
- else 0,
765
- "total_characters": total_characters,
766
- "average_confidence": sum(confidence_scores) / len(confidence_scores)
767
- if confidence_scores
768
- else 0,
769
- "confidence_range": {
770
- "min": min(confidence_scores) if confidence_scores else 0,
771
- "max": max(confidence_scores) if confidence_scores else 0,
772
- },
773
- }
774
-
775
-
776
- # Main Gradio Application with MCP Server
777
  def create_htrflow_mcp_server():
778
- """Create the complete HTRflow MCP server with all three tools."""
779
-
780
  demo = gr.TabbedInterface(
781
  [
782
  gr.Interface(
783
  fn=process_htr,
784
  inputs=[
785
  gr.Image(type="pil", label="Upload Image"),
786
- gr.Dropdown(
787
- choices=[
788
- "letter_english",
789
- "letter_swedish",
790
- "spread_english",
791
- "spread_swedish",
792
- ],
793
- value="letter_english",
794
- label="Document Type",
795
- ),
796
  gr.Slider(0.0, 1.0, value=0.8, label="Confidence Threshold"),
797
- gr.Textbox(
798
- label="Custom Settings (JSON)",
799
- placeholder="Optional custom pipeline settings",
800
- ),
801
  ],
802
  outputs=gr.JSON(label="Processing Results"),
803
  title="HTR Processing Tool",
@@ -807,21 +340,11 @@ def create_htrflow_mcp_server():
807
  gr.Interface(
808
  fn=visualize_results,
809
  inputs=[
810
- gr.Textbox(
811
- label="Processing State (JSON)",
812
- placeholder="Paste processing results from HTR tool",
813
- ),
814
- gr.Dropdown(
815
- choices=["overlay", "confidence_heatmap", "text_regions"],
816
- value="overlay",
817
- label="Visualization Type",
818
- ),
819
  gr.Checkbox(value=True, label="Show Confidence Scores"),
820
  gr.Checkbox(value=True, label="Highlight Low Confidence"),
821
- gr.Image(
822
- type="pil",
823
- label="Image (optional - will use image from processing state if not provided)",
824
- ),
825
  ],
826
  outputs=gr.JSON(label="Visualization Results"),
827
  title="Results Visualization Tool",
@@ -831,16 +354,8 @@ def create_htrflow_mcp_server():
831
  gr.Interface(
832
  fn=export_results,
833
  inputs=[
834
- gr.Textbox(
835
- label="Processing State (JSON)",
836
- placeholder="Paste processing results from HTR tool",
837
- ),
838
- gr.CheckboxGroup(
839
- choices=["txt", "json", "alto", "page"],
840
- value=["txt"],
841
- label="Output Formats",
842
- ),
843
- gr.Checkbox(value=True, label="Include Metadata"),
844
  gr.Slider(0.0, 1.0, value=0.0, label="Confidence Filter"),
845
  ],
846
  outputs=gr.JSON(label="Export Results"),
@@ -852,11 +367,8 @@ def create_htrflow_mcp_server():
852
  ["HTR Processing", "Results Visualization", "Export Results"],
853
  title="HTRflow MCP Server",
854
  )
855
-
856
  return demo
857
 
858
-
859
- # Launch MCP Server
860
  if __name__ == "__main__":
861
  demo = create_htrflow_mcp_server()
862
- demo.launch(mcp_server=True)
 
1
  import gradio as gr
 
2
  import json
3
  import base64
4
  import tempfile
 
8
  from PIL import Image, ImageDraw, ImageFont
9
  import io
10
  import spaces
11
+ import shutil
12
+ from pathlib import Path
13
  from htrflow.volume.volume import Collection
14
  from htrflow.pipeline.pipeline import Pipeline
15
 
 
20
  "step": "Segmentation",
21
  "settings": {
22
  "model": "yolo",
23
+ "model_settings": {"model": "Riksarkivet/yolov9-lines-within-regions-1"},
 
 
24
  "generation_settings": {"batch_size": 8},
25
  },
26
  },
 
41
  "step": "Segmentation",
42
  "settings": {
43
  "model": "yolo",
44
+ "model_settings": {"model": "Riksarkivet/yolov9-lines-within-regions-1"},
 
 
45
  "generation_settings": {"batch_size": 8},
46
  },
47
  },
 
49
  "step": "TextRecognition",
50
  "settings": {
51
  "model": "TrOCR",
52
+ "model_settings": {"model": "Riksarkivet/trocr-base-handwritten-hist-swe-2"},
 
 
53
  "generation_settings": {"batch_size": 16},
54
  },
55
  },
 
70
  "step": "Segmentation",
71
  "settings": {
72
  "model": "yolo",
73
+ "model_settings": {"model": "Riksarkivet/yolov9-lines-within-regions-1"},
 
 
74
  "generation_settings": {"batch_size": 8},
75
  },
76
  },
 
99
  "step": "Segmentation",
100
  "settings": {
101
  "model": "yolo",
102
+ "model_settings": {"model": "Riksarkivet/yolov9-lines-within-regions-1"},
 
 
103
  "generation_settings": {"batch_size": 8},
104
  },
105
  },
 
107
  "step": "TextRecognition",
108
  "settings": {
109
  "model": "TrOCR",
110
+ "model_settings": {"model": "Riksarkivet/trocr-base-handwritten-hist-swe-2"},
 
 
111
  "generation_settings": {"batch_size": 16},
112
  },
113
  },
 
117
  }
118
 
119
  @spaces.GPU
120
+ def process_htr(image: Image.Image, document_type: Literal["letter_english", "letter_swedish", "spread_english", "spread_swedish"] = "spread_swedish", confidence_threshold: float = 0.8, custom_settings: Optional[str] = None) -> Dict:
121
+ """Process handwritten text recognition on uploaded images using HTRflow pipelines."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  try:
123
  if image is None:
124
  return {"success": False, "error": "No image provided", "results": None}
 
132
  try:
133
  config = json.loads(custom_settings)
134
  except json.JSONDecodeError:
135
+ return {"success": False, "error": "Invalid JSON in custom_settings parameter", "results": None}
 
 
 
 
136
  else:
137
  config = PIPELINE_CONFIGS[document_type]
138
 
139
  collection = Collection([temp_image_path])
 
140
  pipeline = Pipeline.from_config(config)
141
  processed_collection = pipeline.run(collection)
142
 
 
 
 
 
143
  img_buffer = io.BytesIO()
144
  image.save(img_buffer, format="PNG")
145
  image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
146
 
147
+ results = extract_text_results(processed_collection, confidence_threshold)
148
+
149
  processing_state = {
150
+ "collection_data": serialize_collection_data(processed_collection),
 
151
  "image_base64": image_base64,
152
  "image_size": image.size,
153
  "document_type": document_type,
 
161
  "processing_state": json.dumps(processing_state),
162
  "metadata": {
163
  "total_lines": len(results.get("text_lines", [])),
164
+ "average_confidence": results.get("average_confidence", 0),
165
  "document_type": document_type,
166
  "image_dimensions": image.size,
167
  },
168
  }
 
169
  finally:
170
  if os.path.exists(temp_image_path):
171
  os.unlink(temp_image_path)
 
172
  except Exception as e:
173
+ return {"success": False, "error": f"HTR processing failed: {str(e)}", "results": None}
 
 
 
 
174
 
175
+ def visualize_results(processing_state: str, visualization_type: Literal["overlay", "confidence_heatmap", "text_regions"] = "overlay", show_confidence: bool = True, highlight_low_confidence: bool = True, image: Optional[Image.Image] = None) -> Dict:
176
+ """Generate interactive visualizations of HTR processing results."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  try:
178
  state = json.loads(processing_state)
179
+ collection_data = state["collection_data"]
 
180
 
181
  if image is not None:
182
  original_image = image
 
184
  image_data = base64.b64decode(state["image_base64"])
185
  original_image = Image.open(io.BytesIO(image_data))
186
 
187
+ viz_image = create_visualization(original_image, collection_data, visualization_type, show_confidence, highlight_low_confidence)
 
 
 
 
 
 
 
 
 
188
 
189
  img_buffer = io.BytesIO()
190
  viz_image.save(img_buffer, format="PNG")
191
  img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
192
 
 
 
193
  return {
194
  "success": True,
195
  "visualization": {
 
198
  "visualization_type": visualization_type,
199
  "dimensions": viz_image.size,
200
  },
201
+ "metadata": {"total_elements": len(collection_data.get("text_elements", []))},
 
202
  }
 
203
  except Exception as e:
204
+ return {"success": False, "error": f"Visualization generation failed: {str(e)}", "visualization": None}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
def export_results(processing_state: str, output_formats: List[Literal["txt", "json", "alto", "page"]] = ["txt"], confidence_filter: float = 0.0) -> Dict:
    """Export HTR results to multiple formats using HTRflow's native export functionality.

    The page image embedded in ``processing_state`` is written to a temporary
    PNG, run through the pipeline registered in ``PIPELINE_CONFIGS`` for the
    state's ``document_type``, and the processed collection is saved once per
    requested format. The text of every file written by the serializer is
    returned inline.

    Args:
        processing_state: JSON string containing at least the keys
            ``image_base64`` and ``document_type``.
        output_formats: Serializer names handed to ``Collection.save``.
        confidence_filter: Echoed back in ``export_metadata`` only; it is not
            applied to the exported content.

    Returns:
        ``{"success": True, "exports": {fmt: [{"filename", "content"}, ...]},
        "export_metadata": {...}}`` on success, otherwise
        ``{"success": False, "error": <message>, "exports": None}``.
    """
    try:
        # Validate and decode everything up front so that a malformed state
        # fails before any temporary file has been created.
        state = json.loads(processing_state)
        image_data = base64.b64decode(state["image_base64"])
        # NOTE(review): the pipeline is re-run from the document type alone;
        # a custom config given to process_htr is presumably not reproduced
        # here - confirm against the processing state contents.
        config = PIPELINE_CONFIGS[state["document_type"]]
        # Copy so the shared default list is never handed back to callers.
        formats = list(output_formats)

        # Reserve a path and close the handle before writing to it; saving
        # through a second handle while the first is open fails on Windows.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            temp_image_path = temp_file.name

        try:
            with Image.open(io.BytesIO(image_data)) as image:
                image.save(temp_image_path, "PNG")

            collection = Collection([temp_image_path])
            pipeline = Pipeline.from_config(config)
            processed_collection = pipeline.run(collection)

            exports = {}
            # The context manager removes the export tree even when a
            # serializer or a file read raises.
            with tempfile.TemporaryDirectory() as temp_dir:
                for fmt in formats:
                    export_dir = Path(temp_dir) / fmt
                    processed_collection.save(directory=str(export_dir), serializer=fmt)

                    export_files = []
                    for root, dirs, files in os.walk(export_dir):
                        # Sort in place so the listing order is deterministic.
                        dirs.sort()
                        for file in sorted(files):
                            with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                                export_files.append({"filename": file, "content": f.read()})

                    exports[fmt] = export_files

            return {
                "success": True,
                "exports": exports,
                "export_metadata": {
                    "formats_generated": formats,
                    "confidence_filter": confidence_filter,
                    "timestamp": datetime.now().isoformat(),
                },
            }
        finally:
            if os.path.exists(temp_image_path):
                os.unlink(temp_image_path)

    except Exception as e:
        return {"success": False, "error": f"Export generation failed: {str(e)}", "exports": None}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
def extract_text_results(collection: Collection, confidence_threshold: float) -> Dict:
    """Collect recognised text whose confidence meets ``confidence_threshold``.

    Every node yielded by ``page.traverse()`` for each page in
    ``collection.pages`` is inspected. A node is kept when it has non-empty
    ``text`` and a ``confidence`` that is greater than or equal to the
    threshold. Nodes whose confidence is missing or ``None`` are skipped
    rather than compared, so one unscored node cannot abort the extraction.

    Args:
        collection: Object exposing ``pages``, each with a ``traverse()`` method.
        confidence_threshold: Minimum confidence a node needs to be included.

    Returns:
        Dict with ``extracted_text`` (kept texts, each followed by a newline),
        ``text_lines`` (dicts of ``text``/``confidence``/``bbox``),
        ``confidence_scores`` and ``average_confidence`` (0 when nothing
        was kept).
    """
    text_lines = []
    for page in collection.pages:
        for node in page.traverse():
            text = getattr(node, "text", None)
            confidence = getattr(node, "confidence", None)
            if not text or confidence is None:
                continue
            if confidence >= confidence_threshold:
                text_lines.append({
                    "text": text,
                    "confidence": confidence,
                    "bbox": getattr(node, "bbox", None),
                })

    scores = [line["confidence"] for line in text_lines]
    return {
        # Built with join instead of repeated string concatenation.
        "extracted_text": "".join(line["text"] + "\n" for line in text_lines),
        "text_lines": text_lines,
        "confidence_scores": scores,
        "average_confidence": sum(scores) / len(scores) if scores else 0,
    }
272
 
273
def serialize_collection_data(collection: Collection) -> Dict:
    """Flatten a processed collection into a plain dict of text elements.

    Every node yielded by ``page.traverse()`` that has non-empty ``text``
    becomes one entry holding its ``text``, ``confidence`` and ``bbox``.

    Args:
        collection: Object exposing ``pages``, each with a ``traverse()`` method.

    Returns:
        ``{"text_elements": [{"text", "confidence", "bbox"}, ...]}``.
    """
    text_elements = []
    for page in collection.pages:
        for node in page.traverse():
            text = getattr(node, "text", None)
            if not text:
                continue
            confidence = getattr(node, "confidence", None)
            text_elements.append({
                "text": text,
                # A confidence of None gets the same 1.0 fallback as a missing
                # attribute, so consumers can always compare and format it.
                "confidence": 1.0 if confidence is None else confidence,
                "bbox": getattr(node, "bbox", None),
            })
    return {"text_elements": text_elements}
284
+
285
def create_visualization(image, collection_data, visualization_type, show_confidence, highlight_low_confidence):
    """Draw recognised text regions onto a copy of ``image``.

    Args:
        image: PIL image to annotate; the input itself is not modified.
        collection_data: Dict with a ``"text_elements"`` list. Elements
            without a truthy ``"bbox"`` are ignored. The bbox is passed
            straight to ``ImageDraw.rectangle`` - presumably
            ``[x0, y0, x1, y1]``; confirm against the producer.
        visualization_type: ``"overlay"`` (outlined boxes),
            ``"confidence_heatmap"`` (translucent fills coloured by
            confidence) or ``"text_regions"`` (outlines cycling through a
            fixed palette). Any other value returns an unannotated copy.
        show_confidence: In overlay mode, label each box with its confidence.
        highlight_low_confidence: In overlay mode, outline boxes with
            confidence below 0.7 in orange instead of green.

    Returns:
        A new RGB PIL image.
    """
    is_heatmap = visualization_type == "confidence_heatmap"
    # The colours below are RGB(A) tuples, which cannot be drawn on
    # single-band images, so normalise the mode. convert() also returns a
    # fresh image, which keeps the caller's image untouched.
    viz_image = image.convert("RGBA" if is_heatmap else "RGB")
    draw = ImageDraw.Draw(viz_image)

    try:
        font = ImageFont.truetype("arial.ttf", 12)
    except (OSError, ImportError):
        # Font file not found or FreeType unavailable: use the built-in font.
        font = ImageFont.load_default()

    region_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]

    for index, element in enumerate(collection_data.get("text_elements", [])):
        bbox = element.get("bbox")
        if not bbox:
            continue
        confidence = element.get("confidence")
        if confidence is None:
            confidence = 1.0

        if visualization_type == "overlay":
            color = (255, 165, 0) if highlight_low_confidence and confidence < 0.7 else (0, 255, 0)
            draw.rectangle(bbox, outline=color, width=2)
            if show_confidence:
                # Clamp so labels of boxes near the top edge stay on the image.
                draw.text((bbox[0], max(0, bbox[1] - 15)), f"{confidence:.2f}", fill=color, font=font)

        elif is_heatmap:
            if confidence < 0.5:
                color = (255, 0, 0, 100)
            elif confidence < 0.8:
                color = (255, 255, 0, 100)
            else:
                color = (0, 255, 0, 100)
            # Composite one translucent layer per element so overlapping
            # boxes blend instead of overwriting each other.
            overlay = Image.new("RGBA", viz_image.size, (0, 0, 0, 0))
            ImageDraw.Draw(overlay).rectangle(bbox, fill=color)
            viz_image = Image.alpha_composite(viz_image, overlay)

        elif visualization_type == "text_regions":
            # Cycle by position: str hashes are salted per process, so a
            # hash-based pick would change colours between runs.
            draw.rectangle(bbox, outline=region_colors[index % len(region_colors)], width=3)

    return viz_image.convert("RGB") if is_heatmap else viz_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  def create_htrflow_mcp_server():
 
 
325
  demo = gr.TabbedInterface(
326
  [
327
  gr.Interface(
328
  fn=process_htr,
329
  inputs=[
330
  gr.Image(type="pil", label="Upload Image"),
331
+ gr.Dropdown(choices=["letter_english", "letter_swedish", "spread_english", "spread_swedish"], value="letter_english", label="Document Type"),
 
 
 
 
 
 
 
 
 
332
  gr.Slider(0.0, 1.0, value=0.8, label="Confidence Threshold"),
333
+ gr.Textbox(label="Custom Settings (JSON)", placeholder="Optional custom pipeline settings"),
 
 
 
334
  ],
335
  outputs=gr.JSON(label="Processing Results"),
336
  title="HTR Processing Tool",
 
340
  gr.Interface(
341
  fn=visualize_results,
342
  inputs=[
343
+ gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
344
+ gr.Dropdown(choices=["overlay", "confidence_heatmap", "text_regions"], value="overlay", label="Visualization Type"),
 
 
 
 
 
 
 
345
  gr.Checkbox(value=True, label="Show Confidence Scores"),
346
  gr.Checkbox(value=True, label="Highlight Low Confidence"),
347
+ gr.Image(type="pil", label="Image (optional)"),
 
 
 
348
  ],
349
  outputs=gr.JSON(label="Visualization Results"),
350
  title="Results Visualization Tool",
 
354
  gr.Interface(
355
  fn=export_results,
356
  inputs=[
357
+ gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
358
+ gr.CheckboxGroup(choices=["txt", "json", "alto", "page"], value=["txt"], label="Output Formats"),
 
 
 
 
 
 
 
 
359
  gr.Slider(0.0, 1.0, value=0.0, label="Confidence Filter"),
360
  ],
361
  outputs=gr.JSON(label="Export Results"),
 
367
  ["HTR Processing", "Results Visualization", "Export Results"],
368
  title="HTRflow MCP Server",
369
  )
 
370
  return demo
371
 
 
 
372
  if __name__ == "__main__":
373
  demo = create_htrflow_mcp_server()
374
+ demo.launch(mcp_server=True)